296 lines
11 KiB
Python
296 lines
11 KiB
Python
from copy import copy
|
|
from discord.ext import commands
|
|
from enum import Enum
|
|
import shlex
|
|
import typing
|
|
|
|
from .utils import Embeds, role_names
|
|
|
|
|
|
class AlternativeStoreType:
    """Namespace for class-based store-type markers.

    Each marker is a small class deriving from ``_StoreType``; ready-made
    instances are exposed as ``store_bool``, ``store_int`` and ``store_value``.

    NOTE(review): nothing in the visible code consumes these markers yet
    (``Flag`` compares against the ``StoreType`` enum) - confirm intended use.
    """

    class _StoreType:
        """Common base class of all store-type markers."""

    class Bool(_StoreType):
        """Marker for a flag that stores a boolean."""

    class Int(_StoreType):
        """Marker for a flag that stores a number within ``[min, max]``."""

        def __init__(self, min: float = float('-inf'), max: float = float('inf')):
            """Store the bounds; by default the range is unbounded on both sides.

            :raises ValueError: if ``min`` is greater than ``max``.
            """
            if min > max:
                # The previous message claimed the opposite of the rule being enforced.
                raise ValueError('min must not be higher than max')
            self.min = min
            self.max = max

    class Value(_StoreType):
        """Marker for a flag that stores an arbitrary value."""

    # Shared default instances of the markers above.
    store_bool = Bool()
    store_int = Int()
    store_value = Value()
|
|
|
|
|
|
class StoreType(Enum):
    """How a registered flag stores its value.

    ``Flag.parse`` sets a ``store_bool`` flag to ``True`` when it is present and
    rejects a value for it; a ``store_value`` flag must be given a value in the
    form ``flag=value``.
    """

    # Presence-only flag.
    store_bool = 'store_bool'

    # Flag carrying a value.
    store_value = 'store_value'
|
|
|
|
|
|
class _Parsed:
|
|
def __init__(self):
|
|
self.normal_args = ()
|
|
self.ctx: commands.Context = None
|
|
|
|
|
|
class FlagErrors:
    """Namespace grouping the exceptions used by the flag parser."""

    class _FlagError(Exception):
        """Base class for errors tied to a single flag.

        :param message: human-readable description, forwarded to ``Exception``.
        :param flag: the flag the error refers to, stored as ``self.flag``.
        """

        def __init__(self, message: str = None, flag: str = None):
            super().__init__(message)
            self.flag = flag

    class DoubleFlagError(_FlagError):
        """The same flag was given more than once."""

    class FlagParseError(_FlagError):
        """A flag was configured with an invalid parser setup."""

    class FlagPositionError(_FlagError):
        """A normal argument followed a flag."""
        # Uses the inherited ``(message, flag)`` constructor unchanged.

    class FlagStoreError(_FlagError):
        """A flag was used in a way that contradicts its store type.

        Additionally records the ``store_type`` and the offending ``value``.
        """

        def __init__(self, message: str = None, flag: str = None, store_type: StoreType = None, value=None):
            super().__init__(message=message, flag=flag)
            self.store_type = store_type
            self.value = value

    class RoleNotAllowedError(_FlagError):
        """The user's roles do not permit using the flag.

        Additionally records the ``role`` that triggered the error, if any.
        """

        def __init__(self, message: str = None, flag: str = None, role: str = None):
            super().__init__(message=message, flag=flag)
            self.role = role

    class QuotationError(Exception):
        """The argument string has an unclosed quotation mark."""

        def __init__(self, message: str = None):
            super().__init__(message)
|
|
|
|
|
|
class _ParsedFlag:
|
|
|
|
def __init__(self, flag: str, value):
|
|
self.flag = flag
|
|
self._value = value
|
|
|
|
def __new__(cls, flag: str, value):
|
|
return value
|
|
|
|
def __add__(self, other):
|
|
return self._value + other
|
|
|
|
def __bool__(self):
|
|
return True if self._value else False
|
|
|
|
def __ge__(self, other):
|
|
return self._value >= other
|
|
|
|
def __getitem__(self, item):
|
|
return self._value
|
|
|
|
def __gt__(self, other):
|
|
return self._value < other
|
|
|
|
def __hash__(self):
|
|
return hash((self.flag, self._value))
|
|
|
|
def __le__(self, other):
|
|
return self._value <= other
|
|
|
|
def __len__(self):
|
|
return len(self._value)
|
|
|
|
def __lt__(self, other):
|
|
return self._value < other
|
|
|
|
def __repr__(self):
|
|
return self._value
|
|
|
|
|
|
class Flag:
    """Registry and parser for the flags of a command.

    Flags are registered with :meth:`add_flag` and extracted from a raw
    argument string with :meth:`parse`. Depending on ``raise_errors`` a problem
    is either raised as an exception from ``FlagErrors`` or sent to the
    invoking context as an error embed.
    """

    def __init__(self, raise_errors=True,
                 double_flag_error='The flag `{flag}` was already given',
                 flag_position_error='Flags must be called at the end of a command',
                 quotation_error='At least one quotation mark is missing (`"` or `\'`)'):
        """Configure error behaviour and the generic error messages.

        :param raise_errors: raise exceptions if true, otherwise send the
            message via ``ctx.send`` and make :meth:`parse` return ``None``.
        :param double_flag_error: message for a repeated flag; formatted with ``flag``.
        :param flag_position_error: message for a normal argument after a flag.
        :param quotation_error: message for an unclosed quotation mark.
        """
        self.double_flag_error = double_flag_error
        self.raise_errors = raise_errors
        self.flag_position_error = flag_position_error
        self.quotation_error = quotation_error
        self._flags = {}

    def add_flag(self, *flags,
                 store_type: typing.Union[StoreType, typing.Tuple[StoreType, typing.Any]],
                 parser: typing.Callable = None,
                 allowed_roles: typing.Union[typing.List[str], typing.Tuple[str]] = (),
                 disallowed_roles: typing.Union[typing.List[str], typing.Tuple[str], typing.Dict[str, str]] = (),
                 not_allowed_role_message='User has no allowed role for flag {flag}',
                 wrong_store_type_message=None,
                 help: str = '', show_help=True):
        """Register one flag under one or more names.

        :param flags: the names (aliases) of the flag; each is converted with
            ``str``. :meth:`parse` drops the first two characters of a name to
            build the result attribute, so names presumably start with a
            two-character prefix such as ``--``.
        :param store_type: a ``StoreType`` or a ``(StoreType, default)`` tuple.
            Without an explicit default, ``store_bool`` flags default to
            ``False`` and all others to ``None``.
        :param parser: awaited as ``parser(ctx, flag, value)`` for
            ``store_value`` flags; not permitted for ``store_bool`` flags.
        :param allowed_roles: if non-empty, the user needs one of these roles.
        :param disallowed_roles: roles that may not use the flag; a dict maps
            each role to its own message (formatted with ``role`` and ``flag``).
        :param not_allowed_role_message: message used when no allowed role
            matches; formatted with ``flag``.
        :param wrong_store_type_message: optional replacement for the default
            store-type messages; formatted with ``flag``.
        :param help: help text stored with the flag.
        :param show_help: stored with the flag.
        :raises FlagErrors.FlagParseError: parser given for a ``store_bool`` flag.
        :raises ValueError: a role is both allowed and disallowed.
        """
        # Normalise the two accepted forms first so the validation below also
        # covers the ``(StoreType, default)`` tuple form.
        if isinstance(store_type, tuple):
            kind, default_value = store_type[0], store_type[1]
        else:
            kind = store_type
            default_value = False if kind == StoreType.store_bool else None

        if kind == StoreType.store_bool and parser is not None:
            raise FlagErrors.FlagParseError('The flag parser cannot be set if the store type is \'store_bool\'')

        for allowed in allowed_roles:
            if allowed in disallowed_roles:
                raise ValueError(f'Role `{allowed}` cannot be allowed and disallowed at the same time')

        flag_information = {
            'store_type': kind,
            'default': default_value,
            'parser': parser,
            'allowed_roles': allowed_roles,
            'disallowed_roles': disallowed_roles,
            'not_allowed_role_message': not_allowed_role_message,
            'wrong_store_type_message': wrong_store_type_message,
            'help': help,
            'show_help': show_help,
        }

        # Every alias refers to the same information dict.
        for flag in flags:
            self._flags[str(flag)] = flag_information

    def flags(self) -> typing.Dict[str, typing.Any]:
        """Return a shallow copy of the mapping from flag name to its information."""
        return copy(self._flags)

    async def _fail(self, ctx, error: Exception, message):
        """Raise ``error`` or, if ``raise_errors`` is false, send ``message`` as an error embed."""
        if self.raise_errors:
            raise error
        await ctx.send(embed=Embeds.error_embed(description=message))

    async def parse(self, args: str, ctx: commands.Context):
        """Split ``args`` into normal arguments and registered flags.

        :param args: raw argument string; tokenised with ``shlex.split``.
        :param ctx: invocation context; used for role checks, for sending
            error embeds and passed on to flag parsers.
        :return: a ``_Parsed`` object with one attribute per registered flag
            (flag name without its first two characters), ``normal_args`` and
            ``ctx``; ``None`` if an error was sent instead of raised or a flag
            parser returned ``False``.
        :raises FlagErrors.QuotationError: unclosed quotation mark.
        :raises FlagErrors.DoubleFlagError: a flag was given twice.
        :raises FlagErrors.RoleNotAllowedError: the user's roles forbid a flag.
        :raises FlagErrors.FlagStoreError: a value was given/missing contrary
            to the flag's store type.
        :raises FlagErrors.FlagPositionError: a normal argument follows a flag.
        """
        # A fresh result object is created on every call.
        parsed = _Parsed()

        try:
            shlex_args = shlex.split(args)
        except ValueError as error:
            if str(error) != 'No closing quotation':
                raise
            await self._fail(ctx, FlagErrors.QuotationError(self.quotation_error), self.quotation_error)
            return None

        # Pre-populate every registered flag with its default.
        for name, information in self._flags.items():
            setattr(parsed, name[2:], information['default'])

        flag_indexed = False  # becomes True once the first flag was seen
        normal_args = []
        parsed_flags = set()
        # presumably the names of the author's roles - helper is defined in .utils
        roles = role_names(ctx.author)

        for raw in shlex_args:
            token = str(raw).replace('"', '').replace("'", '')
            arg, separator, value = token.partition('=')
            if not separator:
                value = None

            if arg not in self._flags:
                if flag_indexed:
                    await self._fail(ctx, FlagErrors.FlagPositionError(message=self.flag_position_error),
                                     self.flag_position_error)
                    return None
                # Keep the complete token: a normal argument such as ``a=b``
                # used to be cut down to ``a``.
                normal_args.append(token)
                continue

            if arg in parsed_flags:
                message = self.double_flag_error.format(flag=arg)
                await self._fail(ctx, FlagErrors.DoubleFlagError(message=message, flag=arg), message)
                return None
            parsed_flags.add(arg)
            flag_indexed = True
            flag = self._flags[arg]

            # --- role checks --- #
            allowed_roles = flag['allowed_roles']
            if allowed_roles and not any(allowed in roles for allowed in allowed_roles):
                message = flag['not_allowed_role_message'].format(flag=arg)
                await self._fail(ctx, FlagErrors.RoleNotAllowedError(message=message, flag=arg), message)
                return None

            disallowed_roles = flag['disallowed_roles']
            for disallowed in disallowed_roles or ():
                if disallowed not in roles:
                    continue
                if isinstance(disallowed_roles, dict):
                    message = disallowed_roles[disallowed]
                else:
                    message = 'The role `{role}` is not allowed to use the {flag} flag'
                message = message.format(role=disallowed, flag=arg)
                await self._fail(ctx, FlagErrors.RoleNotAllowedError(message=message, flag=arg, role=disallowed),
                                 message)
                return None

            # --- value handling --- #
            store_type = flag['store_type']
            arg_without_prefix = arg[2:]
            custom_message = flag['wrong_store_type_message']

            if store_type == StoreType.store_bool:
                if value:
                    message = (custom_message or 'Flag `{flag}` must not contain a value').format(flag=arg)
                    await self._fail(
                        ctx,
                        FlagErrors.FlagStoreError(message=message, flag=arg, store_type=store_type, value=value),
                        message)
                    return None
                setattr(parsed, arg_without_prefix, _ParsedFlag(arg, True))

            elif store_type == StoreType.store_value:
                if not value:
                    # Fixed: the default message used to say "must not contain a value".
                    message = (custom_message or 'Flag `{flag}` must contain a value').format(flag=arg)
                    await self._fail(
                        ctx,
                        FlagErrors.FlagStoreError(message=message, flag=arg, store_type=store_type),
                        message)
                    return None

                parser = flag['parser']
                if not parser:
                    setattr(parsed, arg_without_prefix, _ParsedFlag(arg, value))
                    continue

                value_parsed = await parser(ctx, arg, value)
                if isinstance(value_parsed, bool):
                    if not value_parsed:
                        # A parser returning False aborts parsing without a message
                        # from here; presumably the parser reported the problem itself.
                        return None
                    setattr(parsed, arg_without_prefix, _ParsedFlag(arg, True))
                else:
                    setattr(parsed, arg_without_prefix, _ParsedFlag(arg, value_parsed))

        parsed.normal_args = tuple(normal_args)
        parsed.ctx = ctx
        return parsed
|
|
|
|
|
|
def get_flags(command: commands.Command) -> typing.Union[Flag, None]:
    """Return the flags declared for ``command`` through its ``flags`` keyword.

    The keyword is read from ``command.__original_kwargs__``. A ``Flag``
    instance is returned as is; a string is looked up as an attribute on
    ``command.cog``; anything else (including a missing keyword) gives ``None``.

    :raises AttributeError: if the string does not name an attribute of the cog.
    """
    declared = command.__original_kwargs__.get('flags', None)

    if isinstance(declared, Flag):
        return declared
    if not isinstance(declared, str):
        return None

    try:
        return command.cog.__getattribute__(declared)
    except AttributeError:
        raise AttributeError(f'The flag `{declared}` does not exist')
|