This repository has been archived on 2023-07-12. You can view files and clone it, but cannot push or open issues or pull requests.
kaizen-bot-old/kaizenbot/utils.py
2022-04-20 23:48:59 +02:00

477 lines
18 KiB
Python

import asyncio
import random
import typing
from pathlib import Path
from threading import Event
from threading import Timer as _Timer
from time import sleep

import discord
from discord.ext import menus

from . import logger
class AsyncTimer:
    """One-shot timer that awaits an async callback after a delay.

    The job is scheduled with ``asyncio.ensure_future`` as soon as the timer
    is constructed; there is no separate ``start`` step.

    :param start: delay in seconds before the callback is awaited.
    :param callback: coroutine function, awaited as ``callback(*args)``.
    :param args: positional arguments forwarded to *callback*.
    """

    def __init__(self, start: float, callback, *args):
        self._start = start
        self._args = args
        self._callback = callback
        # Schedule immediately on the current event loop.
        self._task = asyncio.ensure_future(self._job())

    async def _job(self):
        """Sleep ``start`` seconds, then await the callback once."""
        delay = self._start
        await asyncio.sleep(delay)
        target, params = self._callback, self._args
        await target(*params)

    def cancel(self):
        """Cancel the scheduled job (no-op once it has finished)."""
        self._task.cancel()
class AsyncIntervalTimer(AsyncTimer):
    """Repeating variant of :class:`AsyncTimer`.

    The callback is awaited once after ``first_start`` seconds. After that,
    each time it returns, the timer waits ``interval`` seconds and awaits it
    again, until :meth:`cancel` is called.

    :param first_start: delay in seconds before the first call.
    :param interval: pause in seconds between the end of one call and the
        start of the next.
    :param callback: coroutine function, awaited as ``callback(*args)``.
    :param args: positional arguments forwarded to *callback*.
    """

    def __init__(self, first_start: float, interval: float, callback, *args):
        # Stored before the base class schedules _job(), so the attribute
        # exists regardless of when the task first runs.
        self._interval = interval
        super().__init__(first_start, callback, *args)

    async def _job(self):
        """Run the initial delayed call, then loop forever on ``interval``."""
        await super()._job()
        pause = self._interval
        while True:
            await asyncio.sleep(pause)
            await self._callback(*self._args)

    def cancel(self):
        """Stop the repeating job."""
        self._task.cancel()
class IntervalTimer:
    """Thread-based repeating timer.

    After :meth:`start`, ``func(*args)`` is called once after ``first_start``
    seconds. It is then called again every ``interval`` seconds (measured from
    the end of the previous call) until :meth:`cancel` is called. The calls
    run on a background ``threading.Timer`` thread.

    A cancelled timer cannot be restarted.

    :param first_start: delay in seconds before the first call.
    :param interval: pause in seconds between consecutive calls.
    :param func: callable invoked as ``func(*args)``.
    :param args: positional arguments forwarded to *func*.
    """

    def __init__(self, first_start: float, interval: float, func, *args):
        self.first_start = first_start
        self.interval = interval
        self.handlerFunction = func
        self.args = args
        self.running = False
        # Set by cancel(). Waiting on it instead of time.sleep() lets a
        # cancel interrupt both the initial delay and the pause between calls.
        self._stop_event = Event()
        # Launch the worker immediately; run() applies the first_start delay
        # itself. (Using `interval` here, as before, delayed the first call
        # by interval + first_start.)
        self.timer = _Timer(0, self.run, args)

    def run(self, *args):
        """Worker loop executed on the timer thread.

        Returns as soon as the timer is cancelled, without a further call.
        """
        if self._stop_event.wait(self.first_start):
            return
        self.handlerFunction(*args)
        # Event.wait() returns True once cancel() has set the event.
        while self.running and not self._stop_event.wait(self.interval):
            self.handlerFunction(*args)

    def start(self):
        """Start the background thread."""
        self.running = True
        self.timer.start()

    def cancel(self):
        """Stop the timer.

        No further calls are made, even if the first delay has not elapsed yet.
        """
        self.running = False
        self._stop_event.set()
        # Prevents run() from starting at all if the thread has not begun yet.
        self.timer.cancel()
class Embeds:
    """Builders for the Discord embeds the bot sends."""

    @staticmethod
    async def send_kaizen_infos(channel):
        """Send an embed listing Kaizen's channels to *channel*.

        The thumbnail is uploaded from ``assets/kaizen-round.png``, resolved
        against the current working directory, as an attachment in the same
        message.
        """
        thumbnail = discord.File(Path.cwd().joinpath('assets', 'kaizen-round.png'))
        info = discord.Embed(title='**Kaizen**', description='Folge Kaizen auf den folgenden Kanälen, um nichts mehr zu verpassen!', color=discord.Color(0xff0000))
        info.set_thumbnail(url='attachment://kaizen-round.png')

        # (field name, field value) pairs, rendered in this order.
        sections = (
            ('**🎥Youtube Hauptkanal**',
             'Abonniere Kaizen auf __**[Youtube](https://www.youtube.com/c/KaizenAnime)**__ um kein Anime Video mehr zu verpassen!'),
            ('**📑Youtube Toplisten-Kanal**',
             'Abonniere Kaizen\'s __**[Toplisten-Kanal](https://www.youtube.com/channel/UCoijG8JqKb1rRZofx5b-LCw)**__ um kein Toplisten-Video mehr zu verpassen!'),
            ('**📯Youtube Stream-Clips & mehr**',
             'Abonniere Kaizen\'s __**[Youtube Kanal](https://www.youtube.com/channel/UCodeTj8SJ-5HhJgC_Elr1Dw)**__ für Stream-Clips & mehr!'),
            ('**📲Twitch**',
             'Folge Kaizen auf __**[Twitch](https://www.twitch.tv/kaizenanime)**__ und verpasse keinen Stream mehr! '
             'Subbe Kaizen um eine exklusive Rolle auf dem Discord Server zu bekommen!'),
            ('**📢Twitter**',
             'Folge Kaizen auf __**[Twitter](https://twitter.com/Kaizen_Anime)**__ um aktuelle Informationen zu bekommen und in Videos / Streams mitzuwirken!'),
            ('**📷Instagram**',
             'Folge Kaizen auf __**[Instagram](https://www.instagram.com/kaizen.animeyt/)**__!'),
        )
        for field_name, field_value in sections:
            info.add_field(name=field_name, value=field_value, inline=False)

        await channel.send(embed=info, file=thumbnail)

    @staticmethod
    def _build_embed(color: int, title: typing.Optional[str], description: typing.Optional[str]) -> discord.Embed:
        """Create an embed of *color*.

        *title* and *description* are set only when they are truthy.
        """
        result = discord.Embed(color=discord.Color(color))
        if title:
            result.title = title
        if description:
            result.description = description
        return result

    @staticmethod
    def error_embed(title: typing.Optional[str] = None, description: typing.Optional[str] = None) -> discord.Embed:
        """Return a red (0xff0000) embed with optional *title* and *description*."""
        return Embeds._build_embed(0xff0000, title, description)

    @staticmethod
    def warn_embed(title: typing.Optional[str] = None, description: typing.Optional[str] = None) -> discord.Embed:
        """Return an orange (0xff9055) embed with optional *title* and *description*."""
        return Embeds._build_embed(0xff9055, title, description)

    @staticmethod
    def success_embed(title: typing.Optional[str] = None, description: typing.Optional[str] = None) -> discord.Embed:
        """Return a green (0x00ff00) embed with optional *title* and *description*."""
        return Embeds._build_embed(0x00ff00, title, description)
class MenuListPageSource(menus.ListPageSource):
    """Page source that shows exactly one entry of *data* per page.

    Entries are expected to be ready-to-display page contents; they are
    returned unchanged by :meth:`format_page`.
    """

    def __init__(self, data):
        entries_per_page = 1
        super().__init__(data, per_page=entries_per_page)

    async def format_page(self, menu, embeds):
        """Return the current page's entry as-is; *menu* is not used."""
        page_content = embeds
        return page_content
def random_sequence_not_in_string(string: str):
    """Return a short marker string that does not occur in *string*.

    Starts from ``'+'`` and, while the candidate still occurs in *string*,
    wraps it on both sides with the same randomly chosen character from
    ``'+*~-:%&'``. The result is therefore an odd-length palindrome with
    ``'+'`` in the middle. The loop always ends because the candidate
    eventually grows longer than *string*.
    """
    candidate = '+'
    while candidate in string:
        edge = random.choice('+*~-:%&')
        candidate = f'{edge}{candidate}{edge}'
    return candidate
def role_names(member: discord.Member) -> typing.List[str]:
    """Return the ``name`` of every role in ``member.roles``, in the same order."""
    return [r.name for r in member.roles]
# ADDED AFTERWARDS: The following code was copied from a tweepy (https://github.com/tweepy/tweepy) PR or gist (I no longer know which exactly),
# written at a time when tweepy did not yet support async operations.
# Tweepy
# Copyright 2009-2021 Joshua Roesslein
# See LICENSE for details.
import json
from math import inf
from platform import python_version
import aiohttp
from oauthlib.oauth1 import Client as OAuthClient
from yarl import URL
import tweepy
from tweepy.error import TweepError
from tweepy.models import Status
class AsyncStream:
    """Stream realtime Tweets asynchronously

    Parameters
    ----------
    consumer_key: :class:`str`
        Consumer key
    consumer_secret: :class:`str`
        Consumer secret
    access_token: :class:`str`
        Access token
    access_token_secret: :class:`str`
        Access token secret
    max_retries: Optional[:class:`int`]
        Number of times to attempt to (re)connect the stream.
        Defaults to infinite.
    proxy: Optional[:class:`str`]
        Proxy URL
    """

    def __init__(self, consumer_key, consumer_secret, access_token,
                 access_token_secret, max_retries=inf, proxy=None):
        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.access_token = access_token
        self.access_token_secret = access_token_secret
        self.max_retries = max_retries
        self.proxy = proxy

        # aiohttp session; (re)created by _connect() and closed when it exits.
        self.session = None
        # Task returned by filter()/sample(); cancelled by disconnect().
        self.task = None
        self.user_agent = (
            f"Python/{python_version()} "
            f"aiohttp/{aiohttp.__version__} "
            f"Tweepy/{tweepy.__version__}"
        )

    async def _connect(self, method, endpoint, params=None, headers=None,
                       body=None):
        """This method is a coroutine.

        Connect to a streaming endpoint and dispatch what it sends.

        Each non-empty line received is passed to :meth:`on_data`, and each
        empty line to :meth:`on_keep_alive`. After a non-200 response or a
        network error, the method waits with backoff and reconnects. It stops
        when:

        * more than ``max_retries`` consecutive errors occur,
        * the task is cancelled, or
        * another exception occurs (reported via :meth:`on_exception`).

        The HTTP session is always closed and :meth:`on_disconnect` called
        on exit.

        Parameters
        ----------
        method: :class:`str`
            HTTP method, e.g. ``"GET"`` or ``"POST"``.
        endpoint: :class:`str`
            Path below ``https://stream.twitter.com/1.1/``, without ``.json``.
        params: Optional[:class:`dict`]
            Query-string parameters. Defaults to no parameters.
        headers: Optional[:class:`dict`]
            Request headers to sign and send.
        body: Optional[:class:`dict`]
            Request body to sign and send.
        """
        # Fresh dict per call instead of a shared mutable default argument.
        if params is None:
            params = {}

        error_count = 0
        # https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/guides/connecting
        stall_timeout = 90
        network_error_wait = network_error_wait_step = 0.25
        network_error_wait_max = 16
        http_error_wait = http_error_wait_start = 5
        http_error_wait_max = 320
        http_420_error_wait_start = 60

        oauth_client = OAuthClient(self.consumer_key, self.consumer_secret,
                                   self.access_token, self.access_token_secret)

        if self.session is None or self.session.closed:
            # sock_read: give up on a read that receives no data for
            # stall_timeout seconds, so a stalled stream is reconnected.
            self.session = aiohttp.ClientSession(
                headers={"User-Agent": self.user_agent},
                timeout=aiohttp.ClientTimeout(sock_read=stall_timeout)
            )

        url = f"https://stream.twitter.com/1.1/{endpoint}.json"
        url = str(URL(url).with_query(sorted(params.items())))

        try:
            while error_count <= self.max_retries:
                # The request is signed anew for every connection attempt.
                request_url, request_headers, request_body = oauth_client.sign(
                    url, method, body, headers
                )
                try:
                    async with self.session.request(
                        method, request_url, headers=request_headers,
                        data=request_body, proxy=self.proxy
                    ) as resp:
                        if resp.status == 200:
                            # Connected: reset all backoff state.
                            error_count = 0
                            http_error_wait = http_error_wait_start
                            network_error_wait = network_error_wait_step

                            await self.on_connect()

                            async for line in resp.content:
                                line = line.strip()
                                if line:
                                    await self.on_data(line)
                                else:
                                    await self.on_keep_alive()

                            await self.on_closed(resp)
                        else:
                            await self.on_request_error(resp.status)

                            error_count += 1

                            # 420 responses back off from a higher floor.
                            if resp.status == 420:
                                if http_error_wait < http_420_error_wait_start:
                                    http_error_wait = http_420_error_wait_start

                            await asyncio.sleep(http_error_wait)

                            # Exponential backoff; the cap applies only to
                            # non-420 errors.
                            http_error_wait *= 2
                            if resp.status != 420:
                                if http_error_wait > http_error_wait_max:
                                    http_error_wait = http_error_wait_max
                except (aiohttp.ClientConnectionError,
                        aiohttp.ClientPayloadError):
                    await self.on_connection_error()

                    await asyncio.sleep(network_error_wait)

                    # Linear, capped backoff for network errors.
                    network_error_wait += network_error_wait_step
                    if network_error_wait > network_error_wait_max:
                        network_error_wait = network_error_wait_max
        except asyncio.CancelledError:
            # Task cancellation (e.g. via disconnect()) ends the stream quietly.
            return
        except Exception as e:
            await self.on_exception(e)
        finally:
            await self.session.close()
            await self.on_disconnect()

    async def filter(self, follow=None, track=None, locations=None,
                     stall_warnings=False):
        """This method is a coroutine.

        Filter realtime Tweets
        https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/api-reference/post-statuses-filter

        Parameters
        ----------
        follow: Optional[List[Union[:class:`int`, :class:`str`]]]
            A list of user IDs, indicating the users to return statuses for in
            the stream. See https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/guides/basic-stream-parameters
            for more information.
        track: Optional[List[:class:`str`]]
            Keywords to track. Phrases of keywords are specified by a list. See
            https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
            for more information.
        locations: Optional[List[:class:`float`]]
            Specifies a set of bounding boxes to track. See
            https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
            for more information.
        stall_warnings: Optional[:class:`bool`]
            Specifies whether stall warnings should be delivered. See
            https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
            for more information. Defaults to False.

        Raises
        ------
        TweepError
            If the stream is already connected, or if the number of
            ``locations`` coordinates is not a multiple of 4.

        Returns :class:`asyncio.Task`
        """
        if self.task is not None and not self.task.done():
            raise TweepError("Stream is already connected")

        endpoint = "statuses/filter"
        headers = {"Content-Type": "application/x-www-form-urlencoded"}

        body = {}
        if follow is not None:
            body["follow"] = ','.join(map(str, follow))
        if track is not None:
            body["track"] = ','.join(map(str, track))
        if locations is not None:
            if len(locations) % 4:
                raise TweepError(
                    "Number of location coordinates should be a multiple of 4"
                )
            body["locations"] = ','.join(
                f"{location:.4f}" for location in locations
            )
        if stall_warnings:
            body["stall_warnings"] = "true"

        self.task = asyncio.create_task(
            self._connect("POST", endpoint, headers=headers, body=body or None)
        )
        return self.task

    async def sample(self, stall_warnings=False):
        """This method is a coroutine.

        Sample realtime Tweets
        https://developer.twitter.com/en/docs/twitter-api/v1/tweets/sample-realtime/api-reference/get-statuses-sample

        Parameters
        ----------
        stall_warnings: Optional[:class:`bool`]
            Specifies whether stall warnings should be delivered. See
            https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
            for more information. Defaults to False.

        Raises
        ------
        TweepError
            If the stream is already connected.

        Returns :class:`asyncio.Task`
        """
        if self.task is not None and not self.task.done():
            raise TweepError("Stream is already connected")

        endpoint = "statuses/sample"

        params = {}
        if stall_warnings:
            params["stall_warnings"] = "true"

        self.task = asyncio.create_task(
            self._connect("GET", endpoint, params=params)
        )
        return self.task

    def disconnect(self):
        """Disconnect the stream by cancelling its connection task, if any."""
        if self.task is not None:
            self.task.cancel()

    async def on_closed(self, resp):
        """This method is a coroutine.

        This is called when the stream has been closed by Twitter.
        """
        logger.error("Stream connection closed by Twitter")

    async def on_connect(self):
        """This method is a coroutine.

        This is called after successfully connecting to the streaming API.
        """
        # logger.info("Stream connected")

    async def on_connection_error(self):
        """This method is a coroutine.

        This is called when the stream connection errors or times out.
        """
        # logger.error("Stream connection has errored or timed out")

    async def on_disconnect(self):
        """This method is a coroutine.

        This is called when the stream has disconnected.
        """
        # logger.info("Stream disconnected")

    async def on_exception(self, exception):
        """This method is a coroutine.

        This is called when an unhandled exception occurs.
        """
        logger.exception("Stream encountered an exception")

    async def on_keep_alive(self):
        """This method is a coroutine.

        This is called when a keep-alive message is received.
        """
        # logger.debug("Received keep-alive message")

    async def on_request_error(self, status_code):
        """This method is a coroutine.

        This is called when a non-200 HTTP status code is encountered.
        """
        # logger.error("Stream encountered HTTP Error: %d", status_code)

    async def on_data(self, raw_data):
        """This method is a coroutine.

        This is called when raw data is received from the stream.
        This method handles sending the data to other methods, depending on the
        message type.

        https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/guides/streaming-message-types
        """
        data = json.loads(raw_data)

        if "in_reply_to_status_id" in data:
            status = Status.parse(None, data)
            return await self.on_status(status)
        if "delete" in data:
            delete = data["delete"]["status"]
            return await self.on_delete(delete["id"], delete["user_id"])
        if "disconnect" in data:
            return await self.on_disconnect_message(data["disconnect"])
        if "limit" in data:
            return await self.on_limit(data["limit"]["track"])
        if "scrub_geo" in data:
            return await self.on_scrub_geo(data["scrub_geo"])
        if "status_withheld" in data:
            return await self.on_status_withheld(data["status_withheld"])
        if "user_withheld" in data:
            return await self.on_user_withheld(data["user_withheld"])
        if "warning" in data:
            return await self.on_warning(data["warning"])

        logger.warning("Received unknown message type: %s", raw_data)

    async def on_status(self, status):
        """This method is a coroutine.

        This is called when a status is received.
        """
        # logger.debug("Received status: %d", status.id)

    async def on_delete(self, status_id, user_id):
        """This method is a coroutine.

        This is called when a status deletion notice is received.
        """
        # logger.debug("Received status deletion notice: %d", status_id)

    async def on_disconnect_message(self, message):
        """This method is a coroutine.

        This is called when a disconnect message is received.
        """
        # logger.warning("Received disconnect message: %s", message)

    async def on_limit(self, track):
        """This method is a coroutine.

        This is called when a limit notice is received.
        """
        # logger.debug("Received limit notice: %d", track)

    async def on_scrub_geo(self, notice):
        """This method is a coroutine.

        This is called when a location deletion notice is received.
        """
        # logger.debug("Received location deletion notice: %s", notice)

    async def on_status_withheld(self, notice):
        """This method is a coroutine.

        This is called when a status withheld content notice is received.
        """
        # logger.debug("Received status withheld content notice: %s", notice)

    async def on_user_withheld(self, notice):
        """This method is a coroutine.

        This is called when a user withheld content notice is received.
        """
        # logger.debug("Received user withheld content notice: %s", notice)

    async def on_warning(self, notice):
        """This method is a coroutine.

        This is called when a stall warning message is received.
        """
        # logger.warning("Received stall warning: %s", notice)