Update .gitignore, add requests dependency, and implement CambAI TTS integration
This commit is contained in:
parent
bc9d5a8943
commit
be0567ff0f
7 changed files with 167 additions and 19 deletions
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import pickle
|
||||
import random
|
||||
|
@ -7,6 +6,7 @@ import tomllib
|
|||
import discord
|
||||
|
||||
from botbotbot.ai import AIBot
|
||||
from botbotbot.tts import CambAI
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
@ -38,6 +38,10 @@ def main() -> None:
|
|||
system_message=system_prompt,
|
||||
)
|
||||
|
||||
cambai: CambAI | None = None
|
||||
if isinstance(key := config.get("cambai_api_key"), str):
|
||||
cambai = CambAI(key)
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.members = True
|
||||
intents.message_content = True
|
||||
|
@ -235,7 +239,7 @@ def main() -> None:
|
|||
logger.info("ERRE ALEA")
|
||||
|
||||
@bot.listen("on_voice_state_update")
|
||||
async def voice_random_nicks(
|
||||
async def on_voice_state_update(
|
||||
member: discord.Member, before: discord.VoiceState, after: discord.VoiceState
|
||||
) -> None:
|
||||
if before.channel is None and random.random() < 5 / 100:
|
||||
|
@ -247,24 +251,29 @@ def main() -> None:
|
|||
logger.debug(after.channel)
|
||||
if after.channel:
|
||||
logger.debug(after.channel.members)
|
||||
if (
|
||||
before.channel is None
|
||||
and after.channel is not None
|
||||
and random.random() < 5 / 100
|
||||
and bot not in after.channel.members
|
||||
):
|
||||
logger.info(f"Voice connect from {member}")
|
||||
source = await discord.FFmpegOpusAudio.from_probe("assets/allo.ogg")
|
||||
|
||||
await asyncio.sleep(random.randrange(60))
|
||||
if (
|
||||
cambai is not None
|
||||
and before.channel is None
|
||||
and after.channel is not None
|
||||
and bot not in after.channel.members
|
||||
and bot.user
|
||||
and member.id != bot.user.id
|
||||
and random.random() < 5 / 100
|
||||
):
|
||||
logger.info("Generating tts")
|
||||
script = random.choice(
|
||||
[
|
||||
"Salut la jeunesse !",
|
||||
f"Salut {member.display_name}, ça va bien ?",
|
||||
"Allo ? À l'huile !",
|
||||
]
|
||||
)
|
||||
source = await discord.FFmpegOpusAudio.from_probe(cambai.tts(script))
|
||||
vo: discord.VoiceClient = await after.channel.connect()
|
||||
|
||||
await asyncio.sleep(random.randrange(10))
|
||||
await vo.play(source, wait_finish=True)
|
||||
|
||||
await asyncio.sleep(random.randrange(60))
|
||||
await vo.disconnect()
|
||||
logger.info("Voice disconnect")
|
||||
|
||||
@bot.slash_command(
|
||||
name="indu", guild_ids=guild_ids, description="Poser une question à MistralAI"
|
||||
|
|
89
botbotbot/tts.py
Normal file
89
botbotbot/tts.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import pathlib
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CambAI:
|
||||
base_url = "https://client.camb.ai/apis"
|
||||
cambai_root = pathlib.Path("cambai")
|
||||
|
||||
def __init__(self, apikey: str) -> None:
|
||||
self.apikey = apikey
|
||||
|
||||
if not self.cambai_root.is_dir():
|
||||
self.cambai_root.mkdir()
|
||||
|
||||
@property
|
||||
def headers(self) -> dict[str, str]:
|
||||
return {"x-api-key": self.apikey}
|
||||
|
||||
def tts(self, text: str) -> Any:
|
||||
if (path := self.get_path(text)).exists():
|
||||
return path
|
||||
|
||||
task_id = self.gen_task(text)
|
||||
run_id = self.get_runid(task_id)
|
||||
return self.get_iostream(text, run_id)
|
||||
|
||||
def gen_task(self, text: str) -> str | None:
|
||||
tts_payload = {
|
||||
"text": text,
|
||||
"voice_id": 20299,
|
||||
"language": 1,
|
||||
"age": 30,
|
||||
"gender": 1,
|
||||
}
|
||||
|
||||
res = requests.post(
|
||||
f"{self.base_url}/tts", json=tts_payload, headers=self.headers
|
||||
)
|
||||
task_id = res.json().get("task_id")
|
||||
if not isinstance(task_id, str):
|
||||
logger.error(f"Got response {res.json()}")
|
||||
return None
|
||||
|
||||
return task_id
|
||||
|
||||
def get_runid(self, task_id: str | None) -> int | None:
|
||||
if task_id is None:
|
||||
return None
|
||||
|
||||
status = "PENDING"
|
||||
while status == "PENDING":
|
||||
res = requests.get(f"{self.base_url}/tts/{task_id}", headers=self.headers)
|
||||
status = res.json()["status"]
|
||||
print(f"Polling: {status}")
|
||||
time.sleep(1.5)
|
||||
|
||||
run_id = res.json().get("run_id")
|
||||
if not isinstance(run_id, int):
|
||||
return None
|
||||
|
||||
return run_id
|
||||
|
||||
def get_iostream(self, text: str, run_id: int | None) -> pathlib.Path | None:
|
||||
if run_id is None:
|
||||
return None
|
||||
|
||||
res = requests.get(
|
||||
f"{self.base_url}/tts-result/{run_id}", headers=self.headers, stream=True
|
||||
)
|
||||
|
||||
path = self.get_path(text)
|
||||
with open(path, "wb") as f:
|
||||
for chunk in res.iter_content(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
|
||||
return path
|
||||
|
||||
def get_name(self, text: str) -> str:
|
||||
return hashlib.sha256(text.encode()).hexdigest()
|
||||
|
||||
def get_path(self, text: str) -> pathlib.Path:
|
||||
return self.cambai_root.joinpath(f"{self.get_name(text)}.wav")
|
Loading…
Add table
Add a link
Reference in a new issue