botbotbot/botbotbot/text.py

236 lines
8 KiB
Python

import logging
import random
from typing import Any, Callable, Coroutine
import discord
import emoji
from botbotbot.ai import AIBot
from botbotbot.shuffle import Shuffler
from botbotbot.wordlist import Wordlist
# Module-level logger named after this module; every handler below logs through it.
logger = logging.getLogger(__name__)
class TextBot:
    """Text-channel behaviours of the bot.

    Randomly replies to, AI-answers, or reacts to messages; copies or removes
    reactions; marks edited messages; occasionally triggers the shuffler; and
    provides the ``/indu`` slash command that streams a MistralAI answer into
    an embed.
    """

    # Emojis made of a single code point, used for random reactions.
    uni_emojis = tuple(k for k in emoji.EMOJI_DATA.keys() if len(k) == 1)

    # Discord size limits (in characters) for message content and embed fields.
    _MESSAGE_LIMIT = 2000
    _EMBED_TITLE_LIMIT = 256
    _EMBED_DESCRIPTION_LIMIT = 4096

    _MISTRAL_ICON = "https://mistral.ai/images/favicon/favicon-32x32.png"
    _WEBHOOK_NAME = "BotbotbotHook"
    # Default percent weights for reply, ai_reply and react.
    _DEFAULT_RND_WEIGHTS = (10, 5, 10)

    def __init__(
        self,
        bot: discord.Bot,
        wordlist: Wordlist,
        aibot: AIBot | None = None,
        shuffler: Shuffler | None = None,
        guild_ids: list[int] = [],  # never mutated by this class
        rnd_weights: list[float] | None = None,
    ) -> None:
        """Store the collaborators used by the handlers.

        Args:
            bot: bot on which ``init_events`` registers listeners and ``/indu``.
            wordlist: source of the random phrases sent by ``reply``.
            aibot: optional AI backend; AI features do nothing when ``None``.
            shuffler: optional shuffler used by ``rando_shuffle``.
            guild_ids: guild ids passed to the ``/indu`` slash command.
            rnd_weights: percent weights for ``reply``, ``ai_reply`` and
                ``react``; the remainder up to 100 is "no action". Defaults to
                ``[10, 5, 10]``.
        """
        self.bot = bot
        self.aibot = aibot
        self.wl = wordlist
        # A fresh list per instance instead of a shared mutable default.
        self._rnd_weights = (
            list(self._DEFAULT_RND_WEIGHTS) if rnd_weights is None else rnd_weights
        )
        self.shf = shuffler
        self.guild_ids = guild_ids

    def init_events(self) -> None:
        """Register the message/reaction listeners and the ``/indu`` command."""
        self.bot.add_listener(self.on_message, "on_message")
        self.bot.add_listener(self.add_more_reaction, "on_reaction_add")
        self.bot.add_listener(self.add_more_reaction, "on_reaction_remove")
        self.bot.add_listener(self.react_message_edit, "on_message_edit")
        self.bot.add_listener(self.rando_shuffle, "on_message")
        self.bot.add_application_command(
            discord.SlashCommand(
                self.indu,
                name="indu",
                guild_ids=self.guild_ids,
                description="Poser une question à MistralAI",
            )
        )

    @property
    def rnd_weights(self) -> list[float]:
        """Weights for reply, ai_reply, react and "no action", in that order.

        The last weight is what remains of 100 after the configured ones. It is
        clamped at 0 so an over-100 configuration cannot yield a negative weight,
        which would make ``random.choices`` pick the wrong items.
        """
        return self._rnd_weights + [max(0, 100 - sum(self._rnd_weights))]

    @property
    def rnd_functions(
        self,
    ) -> list[Callable[[discord.Message], Coroutine[Any, Any, None]]]:
        """Candidate actions, aligned with the first three ``rnd_weights``."""
        return [self.reply, self.ai_reply, self.react]

    @property
    def rnd_functions_or_not(
        self,
    ) -> list[Callable[[discord.Message], Coroutine[Any, Any, None]] | None]:
        """Candidate actions plus ``None`` ("no action"), aligned with ``rnd_weights``."""
        return [*self.rnd_functions, None]

    async def on_message(self, message: discord.Message) -> None:
        """Randomly reply to, AI-answer, or react to *message*.

        A message mentioning the bot (and not sent by it) always gets one of the
        actions. Any other message gets one action or none, drawn with
        ``rnd_weights``. Ephemeral messages are ignored.
        """
        logger.debug(
            f"Received message from <{message.author}> on channel <{message.channel}>."
        )
        if message.flags.ephemeral:
            logger.debug("Ephemeral message, ignoring.")
            return
        if message.author != self.bot.user and self.bot.user in message.mentions:
            logger.info(
                f"Mention from <{message.author}> in channel <{message.channel}>."
            )
            # Drop the trailing "no action" weight: a mention always gets a response.
            action = random.choices(self.rnd_functions, weights=self.rnd_weights[:-1])[0]
            await action(message)
            return
        func = random.choices(self.rnd_functions_or_not, weights=self.rnd_weights)[0]
        if func is None:
            logger.debug("No action.")
            return
        # NOTE(review): the bot's own messages also reach this point and can be
        # answered; confirm this is intended.
        await func(message)

    async def reply(self, message: discord.Message) -> None:
        """Answer *message* with a random wordlist phrase.

        The phrase is sometimes prefixed with a mention (usually the author,
        rarely @everyone/@here). In text channels there is a 10% chance it is
        sent through a webhook impersonating the author; otherwise it is sent
        as a reply or as a plain channel message.
        """
        logger.info(f"Replying to <{message.author}> in channel <{message.channel}>.")
        mention = random.choices(
            [f"<@{message.author.id}>", "@everyone", "@here"], weights=(97, 1, 2)
        )[0]
        content = random.choice(
            (
                f"{mention}, {self.wl.random()}",
                f"{self.wl.random()}",
            )
        )
        if (
            isinstance(message.channel, discord.TextChannel)
            and random.random() < 10 / 100
        ):
            await self.send_as_webhook(
                message.channel,
                message.author.display_name,
                message.author.avatar.url if message.author.avatar else None,
                content,
            )
        else:
            fct = random.choice((message.reply, message.channel.send))
            await fct(content)

    async def ai_reply(self, message: discord.Message) -> None:
        """Reply to *message* with the AI bot's answer to its content.

        The message's clean content is the prompt. When it is empty, the first
        embed's description is used instead. Answers longer than a message
        allows are sent in an embed, truncated to the embed description limit.
        Nothing is sent when there is no AI bot, or when the answer is not a
        string or is blank.
        """
        if self.aibot is None:
            logger.debug("No AI bot, ignoring.")
            return
        logger.info(f"AI Reply to {message.author}")
        prompt = message.clean_content
        if prompt == "" and message.embeds and message.embeds[0].description:
            prompt = message.embeds[0].description
        answer = self.aibot.answer(prompt)
        if not isinstance(answer, str):
            logger.error(f"Got unexpected result from AIBot : {answer}")
            return
        if not answer.strip():
            # Discord rejects empty messages.
            logger.warning("Empty answer from AIBot, not replying.")
            return
        if len(answer) > self._MESSAGE_LIMIT:
            logger.debug("Answer too long, sending as embed.")
            embed = discord.Embed(
                description=self._truncate(answer, self._EMBED_DESCRIPTION_LIMIT),
                thumbnail=self._MISTRAL_ICON,
            )
            await message.reply(embed=embed)
        else:
            await message.reply(answer)

    async def react(self, message: discord.Message) -> None:
        """Add one random emoji reaction to *message*.

        In a guild there is a 50% chance to pick one of the guild's custom
        emojis, when it has any. Otherwise a single-code-point Unicode emoji is
        used.
        """
        logger.info(
            f"React to message from <{message.author}> in channel <{message.channel}>."
        )
        emojis: tuple[str | discord.Emoji, ...] = self.uni_emojis
        # The emptiness check avoids random.choice() raising IndexError for
        # guilds without custom emojis.
        if (
            message.guild is not None
            and random.random() < 50 / 100
            and message.guild.emojis
        ):
            emojis = message.guild.emojis
        emo: str | discord.Emoji = random.choice(emojis)
        await message.add_reaction(emo)

    async def send_as_webhook(
        self,
        channel: discord.TextChannel,
        name: str,
        avatar_url: str | None,
        content: str,
    ) -> None:
        """Send *content* in *channel* through the bot's webhook as *name*/*avatar_url*.

        Reuses the channel's webhook named ``BotbotbotHook``, creating it if missing.
        """
        webhooks = await channel.webhooks()
        webhook = discord.utils.get(webhooks, name=self._WEBHOOK_NAME)
        if webhook is None:
            webhook = await channel.create_webhook(name=self._WEBHOOK_NAME)
        await webhook.send(content=content, username=name, avatar_url=avatar_url)

    async def add_more_reaction(
        self, reaction: discord.Reaction, user: discord.Member | discord.User
    ) -> None:
        """Randomly copy or remove reactions after someone (not the bot) reacts.

        Each reaction on the message has a 50% chance of being touched. A touched
        reaction is either removed (when the bot had added it, in a guild) or
        copied by the bot, with equal odds. Finally, there is a 10% chance of one
        extra random reaction.
        """
        if user == self.bot.user:
            return
        message = reaction.message
        guild = message.guild
        # Snapshot: the awaited add/remove calls let gateway events mutate
        # message.reactions while we are iterating.
        for existing in list(message.reactions):
            if random.random() >= 50 / 100:
                continue
            if random.random() < 50 / 100:
                # Reaction.me tells whether the bot reacted, without fetching
                # every reacting user over HTTP.
                if (
                    self.bot.user is not None
                    and existing.me
                    and guild is not None
                    and isinstance(
                        member := guild.get_member(self.bot.user.id),
                        discord.Member,
                    )
                ):
                    logger.info(f"Remove reaction <{existing}>.")
                    await message.remove_reaction(existing, member)
            else:
                logger.info(f"Copy reaction <{existing}> from <{user}>.")
                await message.add_reaction(existing.emoji)
        if random.random() < 10 / 100:
            await self.react(message)

    async def react_message_edit(
        self, before: discord.Message, after: discord.Message
    ) -> None:
        """React with 👀 when someone other than the bot changes a message's content."""
        if after.author != self.bot.user and before.content != after.content:
            logger.info(f"React to edit from {after.author}.")
            await after.add_reaction("👀")

    async def rando_shuffle(self, message: discord.Message) -> None:
        """With 5% chance per non-ephemeral guild message, ask the shuffler to shuffle the guild."""
        if (
            self.shf
            and not message.flags.ephemeral
            and random.random() < 5 / 100
            and message.guild
        ):
            logger.info(f"Message shuffle after message from {message.author}")
            await self.shf.try_shuffle(message.guild)

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Return *text* cut to at most *limit* characters, ending with "…" if cut."""
        if len(text) <= limit:
            return text
        return text[: limit - 1] + "…"

    async def indu(self, ctx: discord.ApplicationContext, prompt: str) -> None:
        """Slash command ``/indu``: stream the AI bot's answer to *prompt*.

        The embed starts orange, with *prompt* as its title (truncated to
        Discord's title limit). It is edited as text chunks arrive, with the
        text capped at the embed description limit. Its colour is cleared once
        the stream ends.
        """
        if self.aibot is None:
            logger.debug("No AI bot, ignoring.")
            # Acknowledge the interaction so Discord does not report a failure.
            await ctx.respond("Aucune IA n'est configurée.", ephemeral=True)
            return
        logger.info(f"INDU {ctx.author} {prompt}")
        await ctx.defer()
        res_stream = await self.aibot.get_response_stream(prompt)
        embed = discord.Embed(
            title=self._truncate(prompt, self._EMBED_TITLE_LIMIT),
            description="",
            thumbnail=self._MISTRAL_ICON,
            color=discord.Colour.orange(),
        )
        message = await ctx.respond(embed=embed)
        answer = ""
        async for chunk in res_stream:
            if not chunk.data.choices:
                continue
            delta = chunk.data.choices[0].delta.content
            # Only non-empty text deltas are appended; other values are skipped
            # (assumes text arrives as str — TODO confirm against the AI SDK).
            if not isinstance(delta, str) or not delta:
                continue
            answer += delta
            description = self._truncate(answer, self._EMBED_DESCRIPTION_LIMIT)
            # Once the text is capped, later chunks change nothing; skip those edits.
            if description == embed.description:
                continue
            embed.description = description
            await message.edit(embed=embed)
        embed.colour = None
        await message.edit(embed=embed)
        logger.info("FIN INDU")