Update .gitignore, add requests dependency, and implement CambAI TTS integration
This commit is contained in:
parent
bc9d5a8943
commit
be0567ff0f
7 changed files with 167 additions and 19 deletions
89
botbotbot/tts.py
Normal file
89
botbotbot/tts.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import pathlib
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CambAI:
|
||||
base_url = "https://client.camb.ai/apis"
|
||||
cambai_root = pathlib.Path("cambai")
|
||||
|
||||
def __init__(self, apikey: str) -> None:
|
||||
self.apikey = apikey
|
||||
|
||||
if not self.cambai_root.is_dir():
|
||||
self.cambai_root.mkdir()
|
||||
|
||||
@property
|
||||
def headers(self) -> dict[str, str]:
|
||||
return {"x-api-key": self.apikey}
|
||||
|
||||
def tts(self, text: str) -> Any:
|
||||
if (path := self.get_path(text)).exists():
|
||||
return path
|
||||
|
||||
task_id = self.gen_task(text)
|
||||
run_id = self.get_runid(task_id)
|
||||
return self.get_iostream(text, run_id)
|
||||
|
||||
def gen_task(self, text: str) -> str | None:
|
||||
tts_payload = {
|
||||
"text": text,
|
||||
"voice_id": 20299,
|
||||
"language": 1,
|
||||
"age": 30,
|
||||
"gender": 1,
|
||||
}
|
||||
|
||||
res = requests.post(
|
||||
f"{self.base_url}/tts", json=tts_payload, headers=self.headers
|
||||
)
|
||||
task_id = res.json().get("task_id")
|
||||
if not isinstance(task_id, str):
|
||||
logger.error(f"Got response {res.json()}")
|
||||
return None
|
||||
|
||||
return task_id
|
||||
|
||||
def get_runid(self, task_id: str | None) -> int | None:
|
||||
if task_id is None:
|
||||
return None
|
||||
|
||||
status = "PENDING"
|
||||
while status == "PENDING":
|
||||
res = requests.get(f"{self.base_url}/tts/{task_id}", headers=self.headers)
|
||||
status = res.json()["status"]
|
||||
print(f"Polling: {status}")
|
||||
time.sleep(1.5)
|
||||
|
||||
run_id = res.json().get("run_id")
|
||||
if not isinstance(run_id, int):
|
||||
return None
|
||||
|
||||
return run_id
|
||||
|
||||
def get_iostream(self, text: str, run_id: int | None) -> pathlib.Path | None:
|
||||
if run_id is None:
|
||||
return None
|
||||
|
||||
res = requests.get(
|
||||
f"{self.base_url}/tts-result/{run_id}", headers=self.headers, stream=True
|
||||
)
|
||||
|
||||
path = self.get_path(text)
|
||||
with open(path, "wb") as f:
|
||||
for chunk in res.iter_content(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
|
||||
return path
|
||||
|
||||
def get_name(self, text: str) -> str:
|
||||
return hashlib.sha256(text.encode()).hexdigest()
|
||||
|
||||
def get_path(self, text: str) -> pathlib.Path:
|
||||
return self.cambai_root.joinpath(f"{self.get_name(text)}.wav")
|
||||
Loading…
Add table
Add a link
Reference in a new issue