This repository has been archived on 2023-07-12. You can view files and clone it, but cannot push or open issues or pull requests.
kaizen-bot-old/kaizenbot/flags.py

296 lines
11 KiB
Python
Raw Normal View History

2022-04-20 23:41:25 +02:00
from copy import copy
from discord.ext import commands
from enum import Enum
import shlex
import typing
from .utils import Embeds, role_names
class AlternativeStoreType:
    """Namespace of object-based store-type descriptors.

    It bundles three descriptor classes (``Bool``, ``Int``, ``Value``) and one
    ready-made instance of each (``store_bool``, ``store_int``,
    ``store_value``).
    NOTE(review): the parser code visible in this module only checks the
    ``StoreType`` enum; confirm where these descriptors are consumed.
    """

    class _StoreType:
        """Common base class of all descriptors in this namespace."""

    class Bool(_StoreType):
        """Marker descriptor without any configuration."""

    class Int(_StoreType):
        """Descriptor carrying a ``min``/``max`` bound pair."""

        def __init__(self, min: float = float('-inf'), max: float = float('inf')):
            """Store the bounds.

            :param min: lower bound, unbounded (``-inf``) by default.
            :param max: upper bound, unbounded (``+inf``) by default.
            :raises ValueError: if ``min`` is greater than ``max``.
            """
            # The parameter names shadow the builtins but are part of the
            # public signature, so they are kept as they are.
            if min > max:
                raise ValueError('min must not be higher than max')
            self.min = min
            self.max = max

    class Value(_StoreType):
        """Marker descriptor without any configuration."""

    # Shared default instances, created while the class body is executed.
    store_bool = Bool()
    store_int = Int()
    store_value = Value()
class StoreType(Enum):
    """How a flag stores its value; passed to ``Flag.add_flag`` as ``store_type``.

    ``add_flag`` also accepts a ``(StoreType, default)`` tuple to override the
    default value of the flag.
    """

    # Switch-like flag: must be given without ``=value``. ``Flag.parse`` stores
    # ``True`` when it is present; the default is ``False``.
    store_bool = 'store_bool'
    # Value flag: must be given as ``flag=value``. The default is ``None`` and
    # the value may be passed through an optional parser coroutine.
    store_value = 'store_value'
class _Parsed:
    """Result container filled in by ``Flag.parse``.

    Besides the two attributes initialised here, ``Flag.parse`` sets one
    attribute per registered flag (the flag name without its first two
    characters).
    """

    def __init__(self):
        # Invoking context; stays ``None`` until parsing has finished.
        self.ctx: commands.Context = None
        # Positional (non-flag) arguments in the order they were given.
        self.normal_args = tuple()
class FlagErrors:
    """Namespace grouping the exceptions used by the flag parser."""

    class _FlagError(Exception):
        """Base error that remembers the flag it refers to in ``flag``."""

        def __init__(self, message: str = None, flag: str = None):
            super().__init__(message)
            self.flag = flag

    class DoubleFlagError(_FlagError):
        """Raised by ``Flag.parse`` when the same flag is given twice."""

    class FlagParseError(_FlagError):
        """Raised by ``Flag.add_flag`` when a parser is set for a ``store_bool`` flag."""

    class FlagPositionError(_FlagError):
        """Raised by ``Flag.parse`` when a normal argument follows a flag."""

        def __init__(self, message: str = None, flag: str = None):
            super().__init__(message=message, flag=flag)

    class FlagStoreError(_FlagError):
        """Raised by ``Flag.parse`` when a flag's value does not fit its store type.

        ``store_type`` holds the store type of the flag and ``value`` the
        offending value (``None`` if there was none).
        """

        def __init__(self, message: str = None, flag: str = None, store_type: StoreType = None, value=None):
            super().__init__(message=message, flag=flag)
            self.value = value
            self.store_type = store_type

    class RoleNotAllowedError(_FlagError):
        """Raised by ``Flag.parse`` when the author's roles do not permit a flag.

        ``role`` names the role involved, if any was supplied.
        """

        def __init__(self, message: str = None, flag: str = None, role: str = None):
            super().__init__(message=message, flag=flag)
            self.role = role

    class QuotationError(Exception):
        """Raised by ``Flag.parse`` when a quotation mark is left unclosed."""

        def __init__(self, message: str = None):
            # Always forward the message (even ``None``) so ``args`` keeps
            # exactly one element.
            super().__init__(message)
class _ParsedFlag:
    """Thin wrapper pairing a flag name with its parsed value.

    NOTE: ``__new__`` returns the bare ``value``, so a normal call
    ``_ParsedFlag(flag, value)`` evaluates to ``value`` itself and
    ``__init__`` is not run (unless ``value`` happens to be a ``_ParsedFlag``).
    Callers therefore receive plain values. The remaining methods only matter
    for instances created explicitly, e.g. via ``object.__new__``; each one
    delegates to the wrapped value.
    """

    def __new__(cls, flag: str, value):
        # Kept on purpose: existing callers rely on getting the raw value back.
        return value

    def __init__(self, flag: str, value):
        self.flag = flag
        self._value = value

    def __add__(self, other):
        return self._value + other

    def __bool__(self):
        return bool(self._value)

    def __ge__(self, other):
        return self._value >= other

    def __getitem__(self, item):
        # Index into the wrapped value instead of ignoring ``item``.
        return self._value[item]

    def __gt__(self, other):
        # Previously compared with ``<``, which inverted the result.
        return self._value > other

    def __hash__(self):
        return hash((self.flag, self._value))

    def __le__(self, other):
        return self._value <= other

    def __len__(self):
        return len(self._value)

    def __lt__(self, other):
        return self._value < other

    def __repr__(self):
        # ``__repr__`` must return a ``str``; returning the raw value raised a
        # ``TypeError`` for every non-string value.
        return repr(self._value)
class Flag:
    """Registry and parser for ``--flag`` / ``--flag=value`` command arguments.

    Flags are registered with :meth:`add_flag` and extracted from a raw
    argument string with :meth:`parse`. Depending on ``raise_errors`` a parse
    problem is either raised as one of the ``FlagErrors`` exceptions or sent
    to the invoking context as an error embed.
    """

    def __init__(self, raise_errors=True,
                 double_flag_error='The flag `{flag}` was already given',
                 flag_position_error='Flags must be called at the end of a command',
                 quotation_error='At least one quotation mark is missing (`"` or `\'`)'):
        """Configure error behaviour and messages.

        :param raise_errors: raise parse errors if true, otherwise send them
            through ``ctx.send`` and make :meth:`parse` return ``None``.
        :param double_flag_error: message template for a repeated flag; it is
            formatted with ``flag``.
        :param flag_position_error: message for a normal argument that follows
            a flag.
        :param quotation_error: message for an unclosed quotation mark.
        """
        self.double_flag_error = double_flag_error
        self.raise_errors = raise_errors
        self.flag_position_error = flag_position_error
        self.quotation_error = quotation_error
        # Maps every registered flag name (aliases included) to its settings.
        self._flags = {}

    def add_flag(self, *flags,
                 store_type: typing.Union[StoreType, typing.Tuple[StoreType, typing.Any]],
                 parser: typing.Callable = None,
                 allowed_roles: typing.Union[typing.List[str], typing.Tuple[str], None] = None,
                 disallowed_roles: typing.Union[typing.List[str], typing.Tuple[str], typing.Dict[str, str], None] = None,
                 not_allowed_role_message='User has no allowed role for flag {flag}',
                 wrong_store_type_message=None,
                 help: str = '', show_help=True):
        """Register one flag under one or more names.

        :param flags: the names (aliases) of the flag. :meth:`parse` derives
            the result attribute by dropping the first two characters of a
            name, so names are expected to carry a two-character prefix such
            as ``--``.
        :param store_type: a ``StoreType`` or a ``(StoreType, default)`` tuple
            overriding the default value (``False`` for ``store_bool``,
            ``None`` otherwise).
        :param parser: optional coroutine function awaited as
            ``parser(ctx, flag, value)`` for ``store_value`` flags. A ``False``
            result aborts parsing; any other result is stored.
        :param allowed_roles: if non-empty, the author needs at least one of
            these roles.
        :param disallowed_roles: roles that must not use the flag; a dict maps
            each role to its own message template (formatted with ``role`` and
            ``flag``).
        :param not_allowed_role_message: template (formatted with ``flag``)
            used when ``allowed_roles`` is not satisfied.
        :param wrong_store_type_message: optional template (formatted with
            ``flag``) replacing the built-in store-type error messages.
        :param help: help text stored with the flag.
        :param show_help: stored with the flag; not evaluated by this class.
        :raises FlagErrors.FlagParseError: if a parser is combined with
            ``store_bool``.
        :raises ValueError: if a role is both allowed and disallowed.
        """
        # ``None`` stands in for "no roles" so that no mutable default object
        # is shared between registrations.
        allowed_roles = [] if allowed_roles is None else allowed_roles
        disallowed_roles = [] if disallowed_roles is None else disallowed_roles

        # Unpack the tuple form first so the validation below also covers it.
        if isinstance(store_type, tuple):
            default_value = store_type[1]
            store_type = store_type[0]
        elif store_type == StoreType.store_bool:
            default_value = False
        else:
            default_value = None

        if store_type == StoreType.store_bool and parser is not None:
            raise FlagErrors.FlagParseError('The flag parser cannot be set if the store type is \'store_bool\'')

        for allowed in allowed_roles:
            if allowed in disallowed_roles:
                raise ValueError(f'Role `{allowed}` cannot be allowed and disallowed at the same time')

        flag_information = {
            'store_type': store_type,
            'default': default_value,
            'parser': parser,
            'allowed_roles': allowed_roles,
            'disallowed_roles': disallowed_roles,
            'not_allowed_role_message': not_allowed_role_message,
            'wrong_store_type_message': wrong_store_type_message,
            'help': help,
            'show_help': show_help,
        }
        # All aliases share the very same settings dict.
        for flag in flags:
            self._flags[str(flag)] = flag_information

    def flags(self) -> typing.Dict[str, typing.Any]:
        """Return a shallow copy of the registered flags (name -> settings)."""
        return copy(self._flags)

    async def _fail(self, ctx, error, description):
        """Raise *error* if ``raise_errors`` is set, else send *description* as an error embed."""
        if self.raise_errors:
            raise error
        await ctx.send(embed=Embeds.error_embed(description=description))

    def _role_error(self, name, information, roles):
        """Return the ``RoleNotAllowedError`` that blocks flag *name* for *roles*, or ``None``."""
        allowed_roles = information['allowed_roles']
        if allowed_roles and not any(allowed in roles for allowed in allowed_roles):
            message = information['not_allowed_role_message'].format(flag=name)
            return FlagErrors.RoleNotAllowedError(message=message, flag=name)

        disallowed_roles = information['disallowed_roles']
        for disallowed in disallowed_roles:
            if disallowed in roles:
                if isinstance(disallowed_roles, dict):
                    template = disallowed_roles[disallowed]
                else:
                    template = 'The role `{role}` is not allowed to use the {flag} flag'
                message = template.format(role=disallowed, flag=name)
                return FlagErrors.RoleNotAllowedError(message=message, flag=name, role=disallowed)
        return None

    async def parse(self, args: str, ctx: commands.Context):
        """Split *args* into normal arguments and registered flags.

        :param args: the raw argument string; it is tokenised with
            ``shlex.split``.
        :param ctx: the invoking context; its author's roles are checked and
            errors are sent to it when ``raise_errors`` is false.
        :return: a ``_Parsed`` object with ``normal_args``, ``ctx`` and one
            attribute per registered flag. ``None`` is returned if an error was
            sent to ``ctx`` instead of being raised, or if a flag parser
            returned ``False``.
        :raises FlagErrors.QuotationError: unclosed quotation mark.
        :raises FlagErrors.DoubleFlagError: a flag was given twice.
        :raises FlagErrors.RoleNotAllowedError: the author may not use a flag.
        :raises FlagErrors.FlagStoreError: value given to a ``store_bool``
            flag or missing for a ``store_value`` flag.
        :raises FlagErrors.FlagPositionError: normal argument after a flag.
        :raises ValueError: any other ``shlex`` tokenising error.
        """
        # A fresh result object is built on every call.
        parsed = _Parsed()
        try:
            shlex_args = shlex.split(args)
        except ValueError as error:
            if str(error) != 'No closing quotation':
                raise
            await self._fail(ctx, FlagErrors.QuotationError(self.quotation_error), self.quotation_error)
            return None

        # Every registered flag starts out with its default value.
        for name, information in self._flags.items():
            setattr(parsed, name[2:], information['default'])

        flag_seen = False
        normal_args = []
        seen_flags = set()
        roles = role_names(ctx.author)

        for raw in shlex_args:
            # NOTE(review): this also drops quote characters that are part of
            # a value (e.g. apostrophes); kept to preserve existing behaviour.
            token = raw.replace('"', '').replace("'", '')
            name, _, value = token.partition('=')

            if name not in self._flags:
                if flag_seen:
                    await self._fail(ctx, FlagErrors.FlagPositionError(message=self.flag_position_error),
                                     self.flag_position_error)
                    return None
                # Keep the complete token: only registered flags are split at
                # '=', so a normal argument such as ``a=b`` stays intact.
                normal_args.append(token)
                continue

            if name in seen_flags:
                message = self.double_flag_error.format(flag=name)
                await self._fail(ctx, FlagErrors.DoubleFlagError(message, flag=name), message)
                return None
            seen_flags.add(name)
            flag_seen = True

            information = self._flags[name]
            role_error = self._role_error(name, information, roles)
            if role_error is not None:
                await self._fail(ctx, role_error, role_error.args[0])
                return None

            store_type = information['store_type']
            attribute = name[2:]
            custom_message = information['wrong_store_type_message']

            if store_type == StoreType.store_bool:
                if value:
                    message = (custom_message or 'Flag `{flag}` must not contain a value').format(flag=name)
                    await self._fail(
                        ctx,
                        FlagErrors.FlagStoreError(message=message, flag=name, store_type=store_type, value=value),
                        message)
                    return None
                setattr(parsed, attribute, _ParsedFlag(name, True))
            elif store_type == StoreType.store_value:
                if not value:
                    message = (custom_message or 'Flag `{flag}` must contain a value').format(flag=name)
                    await self._fail(
                        ctx,
                        FlagErrors.FlagStoreError(message=message, flag=name, store_type=store_type),
                        message)
                    return None
                if parser := information['parser']:
                    value_parsed = await parser(ctx, name, value)
                    # A parser signals rejection by returning ``False``;
                    # parsing is then aborted without a further message.
                    if value_parsed is False:
                        return None
                    setattr(parsed, attribute, _ParsedFlag(name, value_parsed))
                else:
                    setattr(parsed, attribute, _ParsedFlag(name, value))

        parsed.normal_args = tuple(normal_args)
        parsed.ctx = ctx
        return parsed
def get_flags(command: commands.Command) -> typing.Union[Flag, None]:
    """Resolve the ``flags`` keyword argument a command was created with.

    The value is read from ``command.__original_kwargs__``:

    * a ``Flag`` instance is returned unchanged,
    * a string is looked up as an attribute of ``command.cog``,
    * anything else (including a missing entry) yields ``None``.

    :raises AttributeError: if the named attribute cannot be found on the cog.
    """
    declared = command.__original_kwargs__.get('flags', None)
    if isinstance(declared, Flag):
        return declared
    if not isinstance(declared, str):
        return None
    # String form: the name of an attribute on the command's cog.
    try:
        return command.cog.__getattribute__(declared)
    except AttributeError:
        raise AttributeError(f'The flag `{declared}` does not exist')