From a331d90b5b30d44e90ab7e52feb4847f0c0e0471 Mon Sep 17 00:00:00 2001
From: "Edgar P. Burkhart"
Date: Mon, 10 Jun 2024 18:12:23 +0200
Subject: [PATCH] Update AI with system prompt

---
 botbotbot/__main__.py | 25 +++++++++++++++++++++----
 botbotbot/ai.py       | 22 ++++++++++++++++++++--
 2 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/botbotbot/__main__.py b/botbotbot/__main__.py
index 88910f6..ec8bb33 100644
--- a/botbotbot/__main__.py
+++ b/botbotbot/__main__.py
@@ -16,7 +16,14 @@ with open("wordlist.pickle", "rb") as word_file:
 guild_ids = config.get("guild_ids")
 delay = config.get("delay", 60)
 
-aibot = AIBot(config.get("mistral_api_key"), model="open-mixtral-8x7b")
+system_prompt = """Tu es une intelligence artificelle qui répond en français.
+Ta réponse doit être très courte.
+Ta réponse doit être longue d'une phrase."""
+aibot = AIBot(
+    config.get("mistral_api_key"),
+    model="open-mixtral-8x7b",
+    system_message=system_prompt,
+)
 
 intents = discord.Intents.default()
 intents.members = True
@@ -163,13 +170,23 @@
 )
 async def indu(ctx, prompt):
     await ctx.defer()
-    answer = aibot.answer(prompt)
+    res_stream = aibot.get_response_stream(prompt)
+
     embed = discord.Embed(
         title=prompt,
-        description=answer,
+        description="",
         thumbnail="https://mistral.ai/images/favicon/favicon-32x32.png",
+        color=discord.Colour.orange(),
     )
-    await ctx.respond(embed=embed)
+    message = await ctx.respond(embed=embed)
+
+    async for chunk in res_stream:
+        if chunk.choices[0].delta.content is not None:
+            embed.description += chunk.choices[0].delta.content
+            await message.edit(embed=embed)
+
+    embed.color = None
+    await message.edit(embed=embed)
 
 
 @bot.slash_command(
diff --git a/botbotbot/ai.py b/botbotbot/ai.py
index e1de820..764596b 100644
--- a/botbotbot/ai.py
+++ b/botbotbot/ai.py
@@ -1,19 +1,37 @@
+from mistralai.async_client import MistralAsyncClient
 from mistralai.client import MistralClient
 from mistralai.models.chat_completion import ChatMessage
 
 
 class AIBot:
-    def __init__(self, api_key, model="open-mistral-7b", max_tokens=None):
+    def __init__(
+        self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None
+    ):
         self.client = MistralClient(api_key=api_key)
+        self.async_client = MistralAsyncClient(api_key=api_key)
         self.model = model
         self.max_tokens = max_tokens
+        self.system_message = system_message
 
     def get_responses(self, message):
         return self.client.chat(
             model=self.model,
-            messages=[ChatMessage(role="user", content=message)],
+            messages=self.base_message + [ChatMessage(role="user", content=message)],
             max_tokens=self.max_tokens,
         )
 
     def answer(self, message):
         return self.get_responses(message).choices[0].message.content
+
+    def get_response_stream(self, message):
+        return self.async_client.chat_stream(
+            model=self.model,
+            messages=self.base_message + [ChatMessage(role="user", content=message)],
+            max_tokens=self.max_tokens,
+        )
+
+    @property
+    def base_message(self):
+        if self.system_message:
+            return [ChatMessage(role="system", content=self.system_message)]
+        return []