1use anyhow::{Result, anyhow};
2use futures::{
3 AsyncBufReadExt, AsyncReadExt,
4 io::BufReader,
5 stream::{BoxStream, StreamExt},
6};
7use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::convert::TryFrom;
11
/// Base URL of the hosted DeepSeek API.
///
/// Suitable as the `api_url` argument of [`stream_completion`], which appends
/// `/v1/chat/completions` to it.
pub const DEEPSEEK_API_URL: &str = "https://api.deepseek.com";
13
/// The author of a chat message.
///
/// Serialized and deserialized as the lowercase variant name
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
22
23impl TryFrom<String> for Role {
24 type Error = anyhow::Error;
25
26 fn try_from(value: String) -> Result<Self> {
27 match value.as_str() {
28 "user" => Ok(Self::User),
29 "assistant" => Ok(Self::Assistant),
30 "system" => Ok(Self::System),
31 "tool" => Ok(Self::Tool),
32 _ => Err(anyhow!("invalid role '{value}'")),
33 }
34 }
35}
36
37impl From<Role> for String {
38 fn from(val: Role) -> Self {
39 match val {
40 Role::User => "user".to_owned(),
41 Role::Assistant => "assistant".to_owned(),
42 Role::System => "system".to_owned(),
43 Role::Tool => "tool".to_owned(),
44 }
45 }
46}
47
/// A DeepSeek model that can be targeted by a chat-completion request.
///
/// The built-in variants (de)serialize as their API model ids; `Custom`
/// describes any other model by id together with caller-supplied limits.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
    /// `deepseek-chat`; this is the `Default` variant.
    #[serde(rename = "deepseek-chat")]
    #[default]
    Chat,
    /// `deepseek-reasoner`.
    #[serde(rename = "deepseek-reasoner")]
    Reasoner,
    /// A model not listed above, identified by `name`.
    #[serde(rename = "custom")]
    Custom {
        /// Model id sent to the API (returned by `Model::id`).
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Token limit reported by `Model::max_token_count`.
        max_tokens: usize,
        /// Output token limit reported by `Model::max_output_tokens`, if known.
        max_output_tokens: Option<u32>,
    },
}
65
66impl Model {
67 pub fn from_id(id: &str) -> Result<Self> {
68 match id {
69 "deepseek-chat" => Ok(Self::Chat),
70 "deepseek-reasoner" => Ok(Self::Reasoner),
71 _ => Err(anyhow!("invalid model id")),
72 }
73 }
74
75 pub fn id(&self) -> &str {
76 match self {
77 Self::Chat => "deepseek-chat",
78 Self::Reasoner => "deepseek-reasoner",
79 Self::Custom { name, .. } => name,
80 }
81 }
82
83 pub fn display_name(&self) -> &str {
84 match self {
85 Self::Chat => "DeepSeek Chat",
86 Self::Reasoner => "DeepSeek Reasoner",
87 Self::Custom {
88 name, display_name, ..
89 } => display_name.as_ref().unwrap_or(name).as_str(),
90 }
91 }
92
93 pub fn max_token_count(&self) -> usize {
94 match self {
95 Self::Chat | Self::Reasoner => 64_000,
96 Self::Custom { max_tokens, .. } => *max_tokens,
97 }
98 }
99
100 pub fn max_output_tokens(&self) -> Option<u32> {
101 match self {
102 Self::Chat => Some(8_192),
103 Self::Reasoner => Some(8_192),
104 Self::Custom {
105 max_output_tokens, ..
106 } => *max_output_tokens,
107 }
108 }
109}
110
/// JSON body of a chat-completion request, as POSTed by [`stream_completion`].
///
/// Unset optional fields are omitted from the serialized JSON, as is `tools`
/// when it is empty.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id, e.g. the value returned by `Model::id`.
    pub model: String,
    /// The conversation to complete, in order.
    pub messages: Vec<RequestMessage>,
    /// Requests a streamed response. NOTE(review): `stream_completion` parses
    /// the reply as server-sent events but does not force this to `true`.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Tools the model may call; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
125
126#[derive(Debug, Serialize, Deserialize)]
127#[serde(rename_all = "snake_case")]
128pub enum ResponseFormat {
129 Text,
130 #[serde(rename = "json_object")]
131 JsonObject,
132}
133
/// A tool offered to the model.
///
/// Serialized as `{"type": "function", "function": { ... }}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    Function { function: FunctionDefinition },
}
139
/// Declaration of a function the model may call.
///
/// `None` optional fields are serialized as `null` (there is no
/// `skip_serializing_if` here).
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name the model refers to when calling it.
    pub name: String,
    pub description: Option<String>,
    /// Raw JSON describing the parameters; presumably a JSON Schema object
    /// (TODO confirm against callers).
    pub parameters: Option<Value>,
}
146
/// A chat message, internally tagged by `role`
/// (`"assistant"`, `"user"`, `"system"` or `"tool"`).
///
/// Also used as the `message` of a non-streaming [`Choice`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Serialized as `null` when `None`.
        content: Option<String>,
        /// Omitted when empty; defaults to empty when absent on deserialize.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        /// Presumably the `ToolCall::id` this message answers — confirm with callers.
        tool_call_id: String,
    },
}
166
/// A tool invocation made by the assistant.
///
/// `content` is flattened, so the JSON form is
/// `{"id": ..., "type": "function", "function": { ... }}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
173
/// The payload of a [`ToolCall`], tagged by `type` (currently only `"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
179
/// The function name and arguments of a completed tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments kept as the raw string sent by the API; presumably JSON-encoded.
    pub arguments: String,
}
185
/// A complete (non-streaming) chat-completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
    /// NOTE(review): DeepSeek documents `reasoning_content` on the message
    /// rather than at the top level. If so, this stays `None`, and
    /// `RequestMessage` has no field to receive it — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}
197
/// Token accounting reported with a response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    /// Defaults to 0 when the field is absent.
    #[serde(default)]
    pub prompt_cache_hit_tokens: u32,
    /// Defaults to 0 when the field is absent.
    #[serde(default)]
    pub prompt_cache_miss_tokens: u32,
}
208
/// One completion candidate in a non-streaming [`Response`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    /// The generated message, decoded with the same role-tagged type used for requests.
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
215
/// One parsed `data:` event of a streamed completion, as yielded by
/// [`stream_completion`].
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}
224
/// The per-choice portion of a [`StreamResponse`] event.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: StreamDelta,
    pub finish_reason: Option<String>,
}
231
/// Incremental message content carried by one stream event.
///
/// Every field is optional. Presumably each fragment is appended to the
/// fragments that came before it (TODO confirm against the consumer).
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}
241
/// A fragment of a tool call inside a [`StreamDelta`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ToolCallChunk {
    /// Presumably identifies which tool call this fragment belongs to, so that
    /// fragments with the same index can be merged — confirm.
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
248
/// Partial function name and/or arguments inside a [`ToolCallChunk`].
#[derive(Serialize, Deserialize, Debug)]
pub struct FunctionChunk {
    pub name: Option<String>,
    /// A piece of the raw argument string; presumably concatenated across chunks.
    pub arguments: Option<String>,
}
254
255pub async fn stream_completion(
256 client: &dyn HttpClient,
257 api_url: &str,
258 api_key: &str,
259 request: Request,
260) -> Result<BoxStream<'static, Result<StreamResponse>>> {
261 let uri = format!("{api_url}/v1/chat/completions");
262 let request_builder = HttpRequest::builder()
263 .method(Method::POST)
264 .uri(uri)
265 .header("Content-Type", "application/json")
266 .header("Authorization", format!("Bearer {}", api_key));
267
268 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
269 let mut response = client.send(request).await?;
270
271 if response.status().is_success() {
272 let reader = BufReader::new(response.into_body());
273 Ok(reader
274 .lines()
275 .filter_map(|line| async move {
276 match line {
277 Ok(line) => {
278 let line = line.strip_prefix("data: ")?;
279 if line == "[DONE]" {
280 None
281 } else {
282 match serde_json::from_str(line) {
283 Ok(response) => Some(Ok(response)),
284 Err(error) => Some(Err(anyhow!(error))),
285 }
286 }
287 }
288 Err(error) => Some(Err(anyhow!(error))),
289 }
290 })
291 .boxed())
292 } else {
293 let mut body = String::new();
294 response.body_mut().read_to_string(&mut body).await?;
295 Err(anyhow!(
296 "Failed to connect to DeepSeek API: {} {}",
297 response.status(),
298 body,
299 ))
300 }
301}