"""Command-line style flag parsing for discord.py commands.

A :class:`Flag` instance holds flag definitions (registered with
:meth:`Flag.add_flag`) and turns a raw argument string into a
:class:`_Parsed` object whose attributes hold the flag values
(see :meth:`Flag.parse`).
"""
from copy import copy
from enum import Enum
import shlex
import typing

from discord.ext import commands

from .utils import Embeds, role_names


class AlternativeStoreType:
    """Object-based store types, an alternative to :class:`StoreType`.

    NOTE(review): :meth:`Flag.parse` only compares against :class:`StoreType`
    members. A flag registered with one of these objects is marked as seen
    during parsing, but no value is assigned to it.
    """

    class _StoreType:
        pass

    class Bool(_StoreType):
        pass

    class Int(_StoreType):
        """Integer store type with optional ``min``/``max`` bounds."""

        def __init__(self, min: typing.Union[int, float] = float('-inf'),
                     max: typing.Union[int, float] = float('inf')):
            if min > max:
                raise ValueError('min must not be higher than max')
            self.min = min
            self.max = max

    class Value(_StoreType):
        pass

    store_bool = Bool()
    store_int = Int()
    store_value = Value()


class StoreType(Enum):
    """How a flag consumes its argument.

    ``store_bool``: the flag takes no value and is set to ``True`` when given.
    ``store_value``: the flag requires a value (``--flag=value``).
    """
    store_bool = 'store_bool'
    store_value = 'store_value'


class _Parsed:
    """Result object returned by :meth:`Flag.parse`.

    Every registered flag becomes an attribute named after the flag with its
    first two characters (presumably the ``--`` prefix) removed. Non-flag
    arguments are collected in ``normal_args``.
    """

    def __init__(self):
        self.normal_args = ()
        self.ctx: commands.Context = None


class FlagErrors:
    """Namespace for the exceptions raised by :class:`Flag`."""

    class _FlagError(Exception):
        """Base class for flag errors; ``flag`` names the offending flag if known."""

        def __init__(self, message: str = None, flag: str = None):
            self.flag = flag
            super().__init__(message)

    class DoubleFlagError(_FlagError):
        """The same flag was given more than once."""

    class FlagParseError(_FlagError):
        """A flag was defined with an invalid parser configuration."""

    class FlagPositionError(_FlagError):
        """A non-flag argument appeared after the first flag."""

    class FlagStoreError(_FlagError):
        """A flag's value does not match its store type."""

        def __init__(self, message: str = None, flag: str = None, store_type: StoreType = None, value=None):
            self.store_type = store_type
            self.value = value
            super().__init__(message, flag)

    class RoleNotAllowedError(_FlagError):
        """The invoking user's roles do not permit using the flag."""

        def __init__(self, message: str = None, flag: str = None, role: str = None):
            self.role = role
            super().__init__(message, flag)

    class QuotationError(Exception):
        """The argument string contains an unbalanced quotation mark."""

        def __init__(self, message: str = None):
            super().__init__(message)


class _ParsedFlag:
    """Pairs a parsed flag name with its value.

    NOTE: ``__new__`` returns *value* itself instead of an instance, so
    ``_ParsedFlag(flag, value)`` currently evaluates to the bare value.
    ``__init__`` and the operator methods below never run. They are kept
    (and kept correct) in case wrapping is re-enabled.
    """

    def __new__(cls, flag: str, value):
        return value

    def __init__(self, flag: str, value):
        self.flag = flag
        self._value = value

    def __add__(self, other):
        return self._value + other

    def __bool__(self):
        return bool(self._value)

    def __ge__(self, other):
        return self._value >= other

    def __getitem__(self, item):
        return self._value[item]

    def __gt__(self, other):
        return self._value > other

    def __hash__(self):
        return hash((self.flag, self._value))

    def __le__(self, other):
        return self._value <= other

    def __len__(self):
        return len(self._value)

    def __lt__(self, other):
        return self._value < other

    def __repr__(self):
        # __repr__ must return a str; the wrapped value may be anything.
        return repr(self._value)


class Flag:
    """A set of flag definitions plus the logic to parse them.

    :param raise_errors: if true, parse errors are raised as ``FlagErrors``
        exceptions; otherwise an error embed is sent to ``ctx`` and
        :meth:`parse` returns ``None``.
    :param double_flag_error: message template for a repeated flag (``{flag}``).
    :param flag_position_error: message for a non-flag argument after a flag.
    :param quotation_error: message for an unbalanced quotation mark.
    """

    def __init__(self, raise_errors=True, double_flag_error='The flag `{flag}` was already given',
                 flag_position_error='Flags must be called at the end of a command',
                 quotation_error='At least one quotation mark is missing (`"` or `\'`)'):
        self.double_flag_error = double_flag_error
        self.raise_errors = raise_errors
        self.flag_position_error = flag_position_error
        self.quotation_error = quotation_error
        self._flags = {}

    def add_flag(self, *flags,
                 store_type: typing.Union[StoreType, typing.Tuple[StoreType, typing.Any]],
                 parser: typing.Callable = None,
                 allowed_roles: typing.Union[typing.List[str], typing.Tuple[str, ...], None] = None,
                 disallowed_roles: typing.Union[typing.List[str], typing.Tuple[str, ...],
                                                typing.Dict[str, str], None] = None,
                 not_allowed_role_message='User has no allowed role for flag {flag}',
                 wrong_store_type_message=None,
                 help: str = '', show_help=True):
        """Register one flag under every name in *flags* (aliases share one definition).

        :param store_type: a :class:`StoreType`, or a ``(StoreType, default)``
            tuple that supplies an explicit default value.
        :param parser: async callable ``parser(ctx, flag, value)`` applied to
            ``store_value`` flags; not allowed for ``store_bool``.
        :param allowed_roles: if non-empty, the user needs at least one of these roles.
        :param disallowed_roles: roles that may not use the flag; a dict maps
            each role to a custom error template (``{role}``, ``{flag}``).
        :raises FlagErrors.FlagParseError: a parser was given for a ``store_bool`` flag.
        :raises ValueError: a role is both allowed and disallowed.
        """
        # Fresh containers per call instead of shared mutable defaults.
        if allowed_roles is None:
            allowed_roles = []
        if disallowed_roles is None:
            disallowed_roles = []

        # Normalize the store type first, so the tuple form is validated too.
        if isinstance(store_type, tuple):
            base_store_type, default_value = store_type[0], store_type[1]
        else:
            base_store_type = store_type
            default_value = False if store_type == StoreType.store_bool else None

        if base_store_type == StoreType.store_bool and parser is not None:
            raise FlagErrors.FlagParseError('The flag parser cannot be set if the store type is \'store_bool\'')

        for allowed in allowed_roles:
            if allowed in disallowed_roles:
                raise ValueError(f'Role `{allowed}` cannot be allowed and disallowed at the same time')

        flag_information = {'store_type': base_store_type,
                            'default': default_value,
                            'parser': parser,
                            'allowed_roles': allowed_roles,
                            'disallowed_roles': disallowed_roles,
                            'not_allowed_role_message': not_allowed_role_message,
                            'wrong_store_type_message': wrong_store_type_message,
                            'help': help,
                            'show_help': show_help}
        for flag in flags:
            self._flags[str(flag)] = flag_information

    def flags(self) -> typing.Dict[str, typing.Any]:
        """Return a shallow copy of the registered flag definitions."""
        return copy(self._flags)

    async def _handle_error(self, ctx: commands.Context, error: Exception) -> None:
        """Raise *error* if ``raise_errors`` is set, else send its message as an error embed."""
        if self.raise_errors:
            raise error
        await ctx.send(embed=Embeds.error_embed(description=str(error)))

    async def parse(self, args: str, ctx: commands.Context):
        """Parse *args* into a :class:`_Parsed` object.

        Arguments are split with :func:`shlex.split`. Flags use the form
        ``--flag`` or ``--flag=value`` and must come after all normal
        arguments.

        :returns: the parsed result. Returns ``None`` when an error was
            reported to ``ctx`` (``raise_errors`` is false), or when a flag
            parser returned ``False``.
        :raises FlagErrors.QuotationError, FlagErrors.DoubleFlagError,
            FlagErrors.RoleNotAllowedError, FlagErrors.FlagStoreError,
            FlagErrors.FlagPositionError: on invalid input, if ``raise_errors``
            is true.
        """
        # A fresh result object per call, so no state leaks between parses.
        parsed = _Parsed()
        try:
            # Guard None: before Python 3.12, shlex.split(None) reads stdin.
            tokens = shlex.split(args or '')
        except ValueError as error:
            if str(error) != 'No closing quotation':
                raise
            await self._handle_error(ctx, FlagErrors.QuotationError(self.quotation_error))
            return None

        for name, information in self._flags.items():
            setattr(parsed, name[2:], information['default'])

        flags_started = False
        normal_args = []
        seen_flags = set()
        roles = role_names(ctx.author)

        for raw_token in tokens:
            token = str(raw_token).replace('"', '').replace("'", '')
            arg, separator, raw_value = token.partition('=')
            value = raw_value if separator else None

            if arg not in self._flags:
                if flags_started:
                    await self._handle_error(ctx, FlagErrors.FlagPositionError(message=self.flag_position_error))
                    return None
                # Keep the whole token: a normal argument may legitimately contain '='.
                normal_args.append(token)
                continue

            if arg in seen_flags:
                message = self.double_flag_error.format(flag=arg)
                await self._handle_error(ctx, FlagErrors.DoubleFlagError(message, flag=arg))
                return None
            seen_flags.add(arg)
            flags_started = True
            flag = self._flags[arg]

            allowed_roles = flag['allowed_roles']
            if allowed_roles and not any(allowed in roles for allowed in allowed_roles):
                message = flag['not_allowed_role_message'].format(flag=arg)
                await self._handle_error(ctx, FlagErrors.RoleNotAllowedError(message, flag=arg))
                return None

            disallowed_roles = flag['disallowed_roles']
            for disallowed in disallowed_roles:
                if disallowed in roles:
                    template = (disallowed_roles[disallowed] if isinstance(disallowed_roles, dict)
                                else 'The role `{role}` is not allowed to use the {flag} flag')
                    message = template.format(role=disallowed, flag=arg)
                    await self._handle_error(
                        ctx, FlagErrors.RoleNotAllowedError(message=message, flag=arg, role=disallowed))
                    return None

            store_type = flag['store_type']
            attribute = arg[2:]
            custom_message = flag['wrong_store_type_message']

            if store_type == StoreType.store_bool:
                if value:
                    message = (custom_message or 'Flag `{flag}` must not contain a value').format(flag=arg)
                    await self._handle_error(ctx, FlagErrors.FlagStoreError(
                        message=message, flag=arg, store_type=store_type, value=value))
                    return None
                setattr(parsed, attribute, _ParsedFlag(arg, True))
            elif store_type == StoreType.store_value:
                if not value:
                    message = (custom_message or 'Flag `{flag}` must contain a value').format(flag=arg)
                    await self._handle_error(ctx, FlagErrors.FlagStoreError(
                        message=message, flag=arg, store_type=store_type))
                    return None
                parser = flag['parser']
                if parser:
                    value_parsed = await parser(ctx, arg, value)
                    if isinstance(value_parsed, bool):
                        if not value_parsed:
                            # The parser rejected the value; presumably it
                            # reported the problem itself.
                            return None
                        value_parsed = True
                    setattr(parsed, attribute, _ParsedFlag(arg, value_parsed))
                else:
                    setattr(parsed, attribute, _ParsedFlag(arg, value))

        parsed.normal_args = tuple(normal_args)
        parsed.ctx = ctx
        return parsed


def get_flags(command: commands.Command) -> typing.Union[Flag, None]:
    """Return the :class:`Flag` attached to *command* via its ``flags`` kwarg.

    ``flags`` may be a :class:`Flag` instance, or the name of an attribute on
    the command's cog that holds one.

    :raises AttributeError: ``flags`` names an attribute the cog does not have.
    """
    flags = command.__original_kwargs__.get('flags', None)
    if isinstance(flags, str):
        try:
            return getattr(command.cog, flags)
        except AttributeError as err:
            raise AttributeError(f'The flag `{flags}` does not exist') from err
    if isinstance(flags, Flag):
        return flags
    return None