89 lines
2.4 KiB
Python
89 lines
2.4 KiB
Python
import hashlib
|
|
import logging
|
|
import pathlib
|
|
import time
|
|
from typing import Any
|
|
|
|
import requests
|
|
|
|
# Per-module logger: named after this module so callers can tune its level
# and handlers independently of the rest of the application.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class CambAI:
|
|
base_url = "https://client.camb.ai/apis"
|
|
cambai_root = pathlib.Path("cambai")
|
|
|
|
def __init__(self, apikey: str) -> None:
|
|
self.apikey = apikey
|
|
|
|
if not self.cambai_root.is_dir():
|
|
self.cambai_root.mkdir()
|
|
|
|
@property
|
|
def headers(self) -> dict[str, str]:
|
|
return {"x-api-key": self.apikey}
|
|
|
|
def tts(self, text: str) -> Any:
|
|
if (path := self.get_path(text)).exists():
|
|
return path
|
|
|
|
task_id = self.gen_task(text)
|
|
run_id = self.get_runid(task_id)
|
|
return self.get_iostream(text, run_id)
|
|
|
|
def gen_task(self, text: str) -> str | None:
|
|
tts_payload = {
|
|
"text": text,
|
|
"voice_id": 20299,
|
|
"language": 1,
|
|
"age": 30,
|
|
"gender": 1,
|
|
}
|
|
|
|
res = requests.post(
|
|
f"{self.base_url}/tts", json=tts_payload, headers=self.headers
|
|
)
|
|
task_id = res.json().get("task_id")
|
|
if not isinstance(task_id, str):
|
|
logger.error(f"Got response {res.json()}")
|
|
return None
|
|
|
|
return task_id
|
|
|
|
def get_runid(self, task_id: str | None) -> int | None:
|
|
if task_id is None:
|
|
return None
|
|
|
|
status = "PENDING"
|
|
while status == "PENDING":
|
|
res = requests.get(f"{self.base_url}/tts/{task_id}", headers=self.headers)
|
|
status = res.json()["status"]
|
|
print(f"Polling: {status}")
|
|
time.sleep(1.5)
|
|
|
|
run_id = res.json().get("run_id")
|
|
if not isinstance(run_id, int):
|
|
return None
|
|
|
|
return run_id
|
|
|
|
def get_iostream(self, text: str, run_id: int | None) -> pathlib.Path | None:
|
|
if run_id is None:
|
|
return None
|
|
|
|
res = requests.get(
|
|
f"{self.base_url}/tts-result/{run_id}", headers=self.headers, stream=True
|
|
)
|
|
|
|
path = self.get_path(text)
|
|
with open(path, "wb") as f:
|
|
for chunk in res.iter_content(chunk_size=1024):
|
|
f.write(chunk)
|
|
|
|
return path
|
|
|
|
def get_name(self, text: str) -> str:
|
|
return hashlib.sha256(text.encode()).hexdigest()
|
|
|
|
def get_path(self, text: str) -> pathlib.Path:
|
|
return self.cambai_root.joinpath(f"{self.get_name(text)}.wav")
|