477 lines
18 KiB
Python
477 lines
18 KiB
Python
|
import asyncio
|
||
|
import random
|
||
|
import typing
|
||
|
from pathlib import Path
|
||
|
from threading import Timer as _Timer
|
||
|
from time import sleep
|
||
|
|
||
|
import discord
|
||
|
from discord.ext import menus
|
||
|
|
||
|
from . import logger
|
||
|
|
||
|
|
||
|
class AsyncTimer:
    """One-shot asyncio timer.

    Constructing the timer immediately schedules a task (via
    ``asyncio.ensure_future``) that sleeps for ``start`` seconds and then
    awaits ``callback(*args)``.

    Parameters
    ----------
    start:
        Delay in seconds before the callback is awaited.
    callback:
        Callable returning an awaitable; it is awaited with ``*args``.
    *args:
        Positional arguments forwarded to ``callback``.
    """

    def __init__(self, start: float, callback, *args):
        self._start = start
        self._args = args
        self._callback = callback
        # Scheduled last so every attribute read by _job() already exists.
        self._task = asyncio.ensure_future(self._job())

    async def _job(self):
        """Wait for the configured delay, then await the callback once."""
        await asyncio.sleep(self._start)
        pending = self._callback(*self._args)
        await pending

    def cancel(self):
        """Cancel the scheduled task."""
        self._task.cancel()
|
||
|
|
||
|
|
||
|
class AsyncIntervalTimer(AsyncTimer):
    """Repeating asyncio timer.

    Awaits ``callback(*args)`` once after ``first_start`` seconds and then
    again after every ``interval`` seconds until :meth:`cancel` is called.

    Parameters
    ----------
    first_start:
        Delay in seconds before the first callback invocation.
    interval:
        Delay in seconds between all following invocations.
    callback:
        Callable returning an awaitable; it is awaited with ``*args``.
    *args:
        Positional arguments forwarded to ``callback``.
    """

    def __init__(self, first_start: float, interval: float, callback, *args):
        self._interval = interval
        super().__init__(first_start, callback, *args)

    async def _job(self):
        """Run the callback after the initial delay, then once per interval."""
        pause = self._start
        while True:
            await asyncio.sleep(pause)
            await self._callback(*self._args)
            # Every wait after the first one uses the repeat interval.
            pause = self._interval

    def cancel(self):
        """Cancel the scheduled task, stopping all further invocations."""
        super().cancel()
|
||
|
|
||
|
|
||
|
class IntervalTimer:
    """Thread-based repeating timer.

    After :meth:`start`, ``func(*args)`` is called once after ``first_start``
    seconds and then again after every ``interval`` seconds, on the thread of
    the underlying ``threading.Timer``, until :meth:`cancel` is called.

    Like ``threading.Timer`` itself, an instance can only be started once.

    Parameters
    ----------
    first_start:
        Delay in seconds before the first call.
    interval:
        Delay in seconds between all following calls.
    func:
        Callable invoked with ``*args``.
    *args:
        Positional arguments forwarded to ``func``.
    """

    def __init__(self, first_start: float, interval: float, func, *args):
        # Local import: the module itself only imports ``Timer`` from threading.
        from threading import Event

        self.first_start = first_start
        self.interval = interval
        self.handlerFunction = func
        self.args = args
        self.running = False
        # Set by cancel(); lets run() wake up from its wait immediately.
        self._cancelled = Event()
        # The Timer supplies the initial delay, so run() must not add
        # another one on top (previously the first call only happened after
        # ``interval + first_start`` seconds).
        self.timer = _Timer(self.first_start, self.run, args)

    def run(self, *args):
        """Call the handler now, then once per interval until cancelled.

        This is the target of the underlying timer thread and blocks until
        the timer is cancelled.
        """
        if self._cancelled.is_set():
            return
        self.handlerFunction(*args)
        # Event.wait() returns True as soon as cancel() is called, so a
        # cancelled timer neither keeps sleeping nor fires one more time.
        while self.running and not self._cancelled.wait(self.interval):
            self.handlerFunction(*args)

    def start(self):
        """Start the timer thread."""
        self.running = True
        self.timer.start()

    def cancel(self):
        """Stop the timer; no handler call is started after this returns."""
        self.running = False
        self._cancelled.set()
        # Also prevents the first call if the initial delay has not elapsed.
        self.timer.cancel()
|
||
|
|
||
|
|
||
|
class Embeds:
    """Factory helpers for the Discord embeds used by the bot."""

    @staticmethod
    async def send_kaizen_infos(channel):
        """Send the embed listing Kaizen's social media channels to *channel*.

        The thumbnail is attached from ``assets/kaizen-round.png`` relative to
        the current working directory.
        """
        thumbnail = discord.File(Path.cwd().joinpath('assets', 'kaizen-round.png'))

        # (field name, field value) pairs in display order.
        entries = (
            (
                '**🎥Youtube Hauptkanal**',
                'Abonniere Kaizen auf __**[Youtube](https://www.youtube.com/c/KaizenAnime)**__ um kein Anime Video mehr zu verpassen!',
            ),
            (
                '**📑Youtube Toplisten-Kanal**',
                'Abonniere Kaizen\'s __**[Toplisten-Kanal](https://www.youtube.com/channel/UCoijG8JqKb1rRZofx5b-LCw)**__ um kein Toplisten-Video mehr zu verpassen!',
            ),
            (
                '**📯Youtube Stream-Clips & mehr**',
                'Abonniere Kaizen\'s __**[Youtube Kanal](https://www.youtube.com/channel/UCodeTj8SJ-5HhJgC_Elr1Dw)**__ für Stream-Clips & mehr!',
            ),
            (
                '**📲Twitch**',
                'Folge Kaizen auf __**[Twitch](https://www.twitch.tv/kaizenanime)**__ und verpasse keinen Stream mehr! '
                'Subbe Kaizen um eine exklusive Rolle auf dem Discord Server zu bekommen!',
            ),
            (
                '**📢Twitter**',
                'Folge Kaizen auf __**[Twitter](https://twitter.com/Kaizen_Anime)**__ um aktuelle Informationen zu bekommen und in Videos / Streams mitzuwirken!',
            ),
            (
                '**📷Instagram**',
                'Folge Kaizen auf __**[Instagram](https://www.instagram.com/kaizen.animeyt/)**__!',
            ),
        )

        card = discord.Embed(
            title='**Kaizen**',
            description='Folge Kaizen auf den folgenden Kanälen, um nichts mehr zu verpassen!',
            color=discord.Color(0xff0000),
        )
        card.set_thumbnail(url='attachment://kaizen-round.png')
        for heading, text in entries:
            card.add_field(name=heading, value=text, inline=False)

        await channel.send(embed=card, file=thumbnail)

    @staticmethod
    def _tinted_embed(rgb: int, title, description) -> discord.Embed:
        """Build an embed coloured *rgb*; falsy title/description are left unset."""
        card = discord.Embed(color=discord.Color(rgb))
        if title:
            card.title = title
        if description:
            card.description = description
        return card

    @staticmethod
    def error_embed(title: typing.Union[str, None] = None, description: typing.Union[str, None] = None) -> discord.Embed:
        """Return an embed with colour ``0xff0000`` and the optional title/description."""
        return Embeds._tinted_embed(0xff0000, title, description)

    @staticmethod
    def warn_embed(title: typing.Union[str, None] = None, description: typing.Union[str, None] = None) -> discord.Embed:
        """Return an embed with colour ``0xff9055`` and the optional title/description."""
        return Embeds._tinted_embed(0xff9055, title, description)

    @staticmethod
    def success_embed(title: typing.Union[str, None] = None, description: typing.Union[str, None] = None) -> discord.Embed:
        """Return an embed with colour ``0x00ff00`` and the optional title/description."""
        return Embeds._tinted_embed(0x00ff00, title, description)
|
||
|
|
||
|
|
||
|
class MenuListPageSource(menus.ListPageSource):
    """List page source for ``discord.ext.menus`` showing one entry per page."""

    def __init__(self, data):
        # per_page=1: every element of ``data`` becomes its own page.
        super().__init__(data, per_page=1)

    async def format_page(self, menu, embeds):
        # The page's entry is returned unchanged. Presumably each entry is a
        # ready-made discord.Embed (suggested by the parameter name) —
        # TODO confirm against the callers.
        return embeds
|
||
|
|
||
|
|
||
|
def random_sequence_not_in_string(string: str):
    """Return a marker sequence that does not occur anywhere in *string*.

    The candidate starts as ``'+'``. While it is still contained in
    *string*, it is wrapped on both sides with the same randomly chosen
    character from ``'+*~-:%&'`` and checked again.
    """
    candidate = '+'
    while True:
        if candidate not in string:
            return candidate
        wrapper = random.choice('+*~-:%&')
        candidate = f'{wrapper}{candidate}{wrapper}'
|
||
|
|
||
|
|
||
|
def role_names(member: discord.Member) -> typing.List[str]:
    """Return the ``name`` of every role in ``member.roles``, in the same order."""
    names = [entry.name for entry in member.roles]
    return names
|
||
|
|
||
|
|
||
|
# ADDED AFTERWARDS: The following code was copied from a tweepy (https://github.com/tweepy/tweepy) pull request or gist
# (the exact origin is no longer known) at a time when tweepy did not yet support asynchronous operation.
|
||
|
|
||
|
# Tweepy
|
||
|
# Copyright 2009-2021 Joshua Roesslein
|
||
|
# See LICENSE for details.
|
||
|
|
||
|
import json
|
||
|
from math import inf
|
||
|
from platform import python_version
|
||
|
|
||
|
import aiohttp
|
||
|
from oauthlib.oauth1 import Client as OAuthClient
|
||
|
from yarl import URL
|
||
|
|
||
|
import tweepy
|
||
|
from tweepy.error import TweepError
|
||
|
from tweepy.models import Status
|
||
|
|
||
|
|
||
|
class AsyncStream:
    """Stream realtime Tweets asynchronously

    Subclass this and override the ``on_*`` coroutines to react to stream
    events.

    Parameters
    ----------
    consumer_key: :class:`str`
        Consumer key
    consumer_secret: :class:`str`
        Consumer secret
    access_token: :class:`str`
        Access token
    access_token_secret: :class:`str`
        Access token secret
    max_retries: Optional[:class:`int`]
        Number of times to attempt to (re)connect the stream.
        Defaults to infinite.
    proxy: Optional[:class:`str`]
        Proxy URL
    """

    def __init__(self, consumer_key, consumer_secret, access_token,
                 access_token_secret, max_retries=inf, proxy=None):
        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.access_token = access_token
        self.access_token_secret = access_token_secret
        self.max_retries = max_retries
        self.proxy = proxy

        # Created lazily by _connect(); closed again when _connect() exits.
        self.session = None
        # Task running _connect(); set by filter() / sample().
        self.task = None
        self.user_agent = (
            f"Python/{python_version()} "
            f"aiohttp/{aiohttp.__version__} "
            f"Tweepy/{tweepy.__version__}"
        )

    async def _connect(self, method, endpoint, params=None, headers=None,
                       body=None):
        """This method is a coroutine.

        Connect to a streaming endpoint and dispatch received lines until the
        retry budget is exhausted, the task is cancelled or an unexpected
        exception occurs.

        Parameters
        ----------
        method: :class:`str`
            HTTP method of the request
        endpoint: :class:`str`
            Endpoint path, inserted into
            ``https://stream.twitter.com/1.1/{endpoint}.json``
        params: Optional[:class:`dict`]
            Query parameters; added to the URL sorted by key
        headers: Optional[:class:`dict`]
            Request headers passed to the OAuth signer
        body: Optional[:class:`dict`]
            Request body passed to the OAuth signer
        """
        if params is None:
            params = {}

        error_count = 0
        # https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/guides/connecting
        stall_timeout = 90
        network_error_wait = network_error_wait_step = 0.25
        network_error_wait_max = 16
        http_error_wait = http_error_wait_start = 5
        http_error_wait_max = 320
        http_420_error_wait_start = 60

        oauth_client = OAuthClient(self.consumer_key, self.consumer_secret,
                                   self.access_token, self.access_token_secret)

        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession(
                headers={"User-Agent": self.user_agent},
                timeout=aiohttp.ClientTimeout(sock_read=stall_timeout)
            )

        url = f"https://stream.twitter.com/1.1/{endpoint}.json"
        url = str(URL(url).with_query(sorted(params.items())))

        try:
            while error_count <= self.max_retries:
                # Signed anew for every attempt.
                request_url, request_headers, request_body = oauth_client.sign(
                    url, method, body, headers
                )
                try:
                    async with self.session.request(
                        method, request_url, headers=request_headers,
                        data=request_body, proxy=self.proxy
                    ) as resp:
                        if resp.status == 200:
                            # Successful connection: reset all backoff state.
                            error_count = 0
                            http_error_wait = http_error_wait_start
                            network_error_wait = network_error_wait_step

                            await self.on_connect()

                            async for line in resp.content:
                                line = line.strip()
                                if line:
                                    await self.on_data(line)
                                else:
                                    # Blank lines are keep-alive signals.
                                    await self.on_keep_alive()

                            await self.on_closed(resp)
                        else:
                            await self.on_request_error(resp.status)

                            error_count += 1

                            # HTTP 420 waits at least 60 seconds.
                            if resp.status == 420:
                                if http_error_wait < http_420_error_wait_start:
                                    http_error_wait = http_420_error_wait_start

                            await asyncio.sleep(http_error_wait)

                            # Exponential backoff; only capped for non-420
                            # errors.
                            http_error_wait *= 2
                            if resp.status != 420:
                                if http_error_wait > http_error_wait_max:
                                    http_error_wait = http_error_wait_max
                except (aiohttp.ClientConnectionError,
                        aiohttp.ClientPayloadError):
                    await self.on_connection_error()

                    await asyncio.sleep(network_error_wait)

                    # Linear backoff, capped at network_error_wait_max.
                    network_error_wait += network_error_wait_step
                    if network_error_wait > network_error_wait_max:
                        network_error_wait = network_error_wait_max
        except asyncio.CancelledError:
            # Raised when disconnect() cancels the task; end quietly.
            return
        except Exception as e:
            await self.on_exception(e)
        finally:
            await self.session.close()
            await self.on_disconnect()

    async def filter(self, follow=None, track=None, locations=None,
                     stall_warnings=False):
        """This method is a coroutine.

        Filter realtime Tweets
        https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/api-reference/post-statuses-filter

        Parameters
        ----------
        follow: Optional[List[Union[:class:`int`, :class:`str`]]]
            A list of user IDs, indicating the users to return statuses for in
            the stream. See https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/guides/basic-stream-parameters
            for more information.
        track: Optional[List[:class:`str`]]
            Keywords to track. Phrases of keywords are specified by a list. See
            https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
            for more information.
        locations: Optional[List[:class:`float`]]
            Specifies a set of bounding boxes to track. See
            https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
            for more information.
        stall_warnings: Optional[:class:`bool`]
            Specifies whether stall warnings should be delivered. See
            https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
            for more information. Defaults to False.

        Raises
        ------
        TweepError
            If the stream is already connected, or if the number of location
            coordinates is not a multiple of 4.

        Returns :class:`asyncio.Task`
        """
        if self.task is not None and not self.task.done():
            raise TweepError("Stream is already connected")

        endpoint = "statuses/filter"
        headers = {"Content-Type": "application/x-www-form-urlencoded"}

        body = {}
        if follow is not None:
            body["follow"] = ','.join(map(str, follow))
        if track is not None:
            body["track"] = ','.join(map(str, track))
        if locations is not None:
            if len(locations) % 4:
                raise TweepError(
                    "Number of location coordinates should be a multiple of 4"
                )
            body["locations"] = ','.join(
                f"{location:.4f}" for location in locations
            )
        if stall_warnings:
            body["stall_warnings"] = "true"

        self.task = asyncio.create_task(
            self._connect("POST", endpoint, headers=headers, body=body or None)
        )
        return self.task

    async def sample(self, stall_warnings=False):
        """This method is a coroutine.

        Sample realtime Tweets
        https://developer.twitter.com/en/docs/twitter-api/v1/tweets/sample-realtime/api-reference/get-statuses-sample

        Parameters
        ----------
        stall_warnings: Optional[:class:`bool`]
            Specifies whether stall warnings should be delivered. See
            https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
            for more information. Defaults to False.

        Raises
        ------
        TweepError
            If the stream is already connected.

        Returns :class:`asyncio.Task`
        """
        if self.task is not None and not self.task.done():
            raise TweepError("Stream is already connected")

        endpoint = "statuses/sample"

        params = {}
        if stall_warnings:
            params["stall_warnings"] = "true"

        self.task = asyncio.create_task(
            self._connect("GET", endpoint, params=params)
        )
        return self.task

    def disconnect(self):
        """Disconnect the stream"""
        if self.task is not None:
            self.task.cancel()

    async def on_closed(self, resp):
        """This method is a coroutine.

        This is called when the stream has been closed by Twitter.
        """
        logger.error("Stream connection closed by Twitter")

    async def on_connect(self):
        """This method is a coroutine.

        This is called after successfully connecting to the streaming API.
        """
        # logger.info("Stream connected")

    async def on_connection_error(self):
        """This method is a coroutine.

        This is called when the stream connection errors or times out.
        """
        # logger.error("Stream connection has errored or timed out")

    async def on_disconnect(self):
        """This method is a coroutine.

        This is called when the stream has disconnected.
        """
        # logger.info("Stream disconnected")

    async def on_exception(self, exception):
        """This method is a coroutine.

        This is called when an unhandled exception occurs.
        """
        logger.exception("Stream encountered an exception")

    async def on_keep_alive(self):
        """This method is a coroutine.

        This is called when a keep-alive message is received.
        """
        # logger.debug("Received keep-alive message")

    async def on_request_error(self, status_code):
        """This method is a coroutine.

        This is called when a non-200 HTTP status code is encountered.
        """
        # logger.error("Stream encountered HTTP Error: %d", status_code)

    async def on_data(self, raw_data):
        """This method is a coroutine.

        This is called when raw data is received from the stream.
        This method handles sending the data to other methods, depending on the
        message type.
        https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/guides/streaming-message-types
        """
        data = json.loads(raw_data)

        if "in_reply_to_status_id" in data:
            status = Status.parse(None, data)
            return await self.on_status(status)
        if "delete" in data:
            delete = data["delete"]["status"]
            return await self.on_delete(delete["id"], delete["user_id"])
        if "disconnect" in data:
            return await self.on_disconnect_message(data["disconnect"])
        if "limit" in data:
            return await self.on_limit(data["limit"]["track"])
        if "scrub_geo" in data:
            return await self.on_scrub_geo(data["scrub_geo"])
        if "status_withheld" in data:
            return await self.on_status_withheld(data["status_withheld"])
        if "user_withheld" in data:
            return await self.on_user_withheld(data["user_withheld"])
        if "warning" in data:
            return await self.on_warning(data["warning"])

        logger.warning("Received unknown message type: %s", raw_data)

    async def on_status(self, status):
        """This method is a coroutine.

        This is called when a status is received.
        """
        # logger.debug("Received status: %d", status.id)

    async def on_delete(self, status_id, user_id):
        """This method is a coroutine.

        This is called when a status deletion notice is received.
        """
        # logger.debug("Received status deletion notice: %d", status_id)

    async def on_disconnect_message(self, message):
        """This method is a coroutine.

        This is called when a disconnect message is received.
        """
        # logger.warning("Received disconnect message: %s", message)

    async def on_limit(self, track):
        """This method is a coroutine.

        This is called when a limit notice is received.
        """
        # logger.debug("Received limit notice: %d", track)

    async def on_scrub_geo(self, notice):
        """This method is a coroutine.

        This is called when a location deletion notice is received.
        """
        # logger.debug("Received location deletion notice: %s", notice)

    async def on_status_withheld(self, notice):
        """This method is a coroutine.

        This is called when a status withheld content notice is received.
        """
        # logger.debug("Received status withheld content notice: %s", notice)

    async def on_user_withheld(self, notice):
        """This method is a coroutine.

        This is called when a user withheld content notice is received.
        """
        # logger.debug("Received user withheld content notice: %s", notice)

    async def on_warning(self, notice):
        """This method is a coroutine.

        This is called when a stall warning message is received.
        """
        # logger.warning("Received stall warning: %s", notice)