173 lines
5.8 KiB
Python
173 lines
5.8 KiB
Python
import json
|
|
from typing import Optional, AsyncIterator
|
|
import httpx
|
|
from tg_bot.config.settings import settings
|
|
|
|
|
|
class DeepSeekAPIError(Exception):
    """Error raised by ``DeepSeekClient`` when a request cannot be completed.

    Used for a missing API key, HTTP error statuses, connection problems
    and malformed responses; the message is available via ``str(exc)``.
    """
|
|
|
|
|
|
class DeepSeekClient:
|
|
|
|
def __init__(self, api_key: str | None = None, api_url: str | None = None):
|
|
self.api_key = api_key or settings.DEEPSEEK_API_KEY
|
|
self.api_url = api_url or settings.DEEPSEEK_API_URL
|
|
self.timeout = 60.0
|
|
|
|
def _get_headers(self) -> dict[str, str]:
|
|
if not self.api_key:
|
|
raise DeepSeekAPIError("API key not set")
|
|
|
|
return {
|
|
"Content-Type": "application/json",
|
|
"Authorization": f"Bearer {self.api_key}"
|
|
}
|
|
|
|
async def chat_completion(
|
|
self,
|
|
messages: list[dict[str, str]],
|
|
model: str = "deepseek-chat",
|
|
temperature: float = 0.7,
|
|
max_tokens: Optional[int] = None,
|
|
stream: bool = False
|
|
) -> dict:
|
|
if not self.api_key:
|
|
return {
|
|
"content": "API key not configured",
|
|
"usage": {
|
|
"prompt_tokens": 0,
|
|
"completion_tokens": 0,
|
|
"total_tokens": 0
|
|
}
|
|
}
|
|
|
|
payload = {
|
|
"model": model,
|
|
"messages": messages,
|
|
"temperature": temperature,
|
|
"stream": stream
|
|
}
|
|
|
|
if max_tokens is not None:
|
|
payload["max_tokens"] = max_tokens
|
|
|
|
try:
|
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
response = await client.post(
|
|
self.api_url,
|
|
headers=self._get_headers(),
|
|
json=payload
|
|
)
|
|
response.raise_for_status()
|
|
|
|
data = response.json()
|
|
|
|
if "choices" in data and len(data["choices"]) > 0:
|
|
content = data["choices"][0]["message"]["content"]
|
|
else:
|
|
raise DeepSeekAPIError("Invalid response format")
|
|
|
|
usage = data.get("usage", {})
|
|
|
|
return {
|
|
"content": content,
|
|
"usage": {
|
|
"prompt_tokens": usage.get("prompt_tokens", 0),
|
|
"completion_tokens": usage.get("completion_tokens", 0),
|
|
"total_tokens": usage.get("total_tokens", 0)
|
|
}
|
|
}
|
|
|
|
except httpx.HTTPStatusError as e:
|
|
error_msg = f"API error: {e.response.status_code}"
|
|
try:
|
|
error_data = e.response.json()
|
|
if "error" in error_data:
|
|
error_msg = error_data['error'].get('message', error_msg)
|
|
except:
|
|
pass
|
|
raise DeepSeekAPIError(error_msg) from e
|
|
except httpx.RequestError as e:
|
|
raise DeepSeekAPIError(f"Connection error: {str(e)}") from e
|
|
except Exception as e:
|
|
raise DeepSeekAPIError(str(e)) from e
|
|
|
|
async def stream_chat_completion(
|
|
self,
|
|
messages: list[dict[str, str]],
|
|
model: str = "deepseek-chat",
|
|
temperature: float = 0.7,
|
|
max_tokens: Optional[int] = None
|
|
) -> AsyncIterator[str]:
|
|
if not self.api_key:
|
|
yield "API key not configured"
|
|
return
|
|
|
|
payload = {
|
|
"model": model,
|
|
"messages": messages,
|
|
"temperature": temperature,
|
|
"stream": True
|
|
}
|
|
|
|
if max_tokens is not None:
|
|
payload["max_tokens"] = max_tokens
|
|
|
|
try:
|
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
async with client.stream(
|
|
"POST",
|
|
self.api_url,
|
|
headers=self._get_headers(),
|
|
json=payload
|
|
) as response:
|
|
response.raise_for_status()
|
|
|
|
async for line in response.aiter_lines():
|
|
if not line.strip():
|
|
continue
|
|
|
|
if line.startswith("data: "):
|
|
line = line[6:]
|
|
|
|
if line.strip() == "[DONE]":
|
|
break
|
|
|
|
try:
|
|
data = json.loads(line)
|
|
|
|
if "choices" in data and len(data["choices"]) > 0:
|
|
delta = data["choices"][0].get("delta", {})
|
|
content = delta.get("content", "")
|
|
if content:
|
|
yield content
|
|
except json.JSONDecodeError:
|
|
continue
|
|
|
|
except httpx.HTTPStatusError as e:
|
|
error_msg = f"API error: {e.response.status_code}"
|
|
try:
|
|
error_data = e.response.json()
|
|
if "error" in error_data:
|
|
error_msg = error_data['error'].get('message', error_msg)
|
|
except:
|
|
pass
|
|
raise DeepSeekAPIError(error_msg) from e
|
|
except httpx.RequestError as e:
|
|
raise DeepSeekAPIError(f"Connection error: {str(e)}") from e
|
|
except Exception as e:
|
|
raise DeepSeekAPIError(str(e)) from e
|
|
|
|
async def health_check(self) -> bool:
|
|
if not self.api_key:
|
|
return False
|
|
|
|
try:
|
|
test_messages = [{"role": "user", "content": "test"}]
|
|
await self.chat_completion(test_messages, max_tokens=1)
|
|
return True
|
|
except Exception:
|
|
return False
|
|
|