1use anyhow::{Result, anyhow};
2use futures::{
3 AsyncBufReadExt, AsyncReadExt,
4 io::BufReader,
5 stream::{BoxStream, StreamExt},
6};
7use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::convert::TryFrom;
11
/// Default base URL of the DeepSeek HTTP API; `stream_completion` appends
/// `/chat/completions` to whichever base URL it is given.
pub const DEEPSEEK_API_URL: &str = "https://api.deepseek.com/v1";
13
14#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
15#[serde(rename_all = "lowercase")]
16pub enum Role {
17 User,
18 Assistant,
19 System,
20 Tool,
21}
22
23impl TryFrom<String> for Role {
24 type Error = anyhow::Error;
25
26 fn try_from(value: String) -> Result<Self> {
27 match value.as_str() {
28 "user" => Ok(Self::User),
29 "assistant" => Ok(Self::Assistant),
30 "system" => Ok(Self::System),
31 "tool" => Ok(Self::Tool),
32 _ => anyhow::bail!("invalid role '{value}'"),
33 }
34 }
35}
36
37impl From<Role> for String {
38 fn from(val: Role) -> Self {
39 match val {
40 Role::User => "user".to_owned(),
41 Role::Assistant => "assistant".to_owned(),
42 Role::System => "system".to_owned(),
43 Role::Tool => "tool".to_owned(),
44 }
45 }
46}
47
48#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
49#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
50pub enum Model {
51 #[serde(rename = "deepseek-chat")]
52 #[default]
53 Chat,
54 #[serde(rename = "deepseek-reasoner")]
55 Reasoner,
56 #[serde(rename = "custom")]
57 Custom {
58 name: String,
59 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
60 display_name: Option<String>,
61 max_tokens: u64,
62 max_output_tokens: Option<u64>,
63 },
64}
65
66impl Model {
67 pub fn default_fast() -> Self {
68 Model::Chat
69 }
70
71 pub fn from_id(id: &str) -> Result<Self> {
72 match id {
73 "deepseek-chat" => Ok(Self::Chat),
74 "deepseek-reasoner" => Ok(Self::Reasoner),
75 _ => anyhow::bail!("invalid model id {id}"),
76 }
77 }
78
79 pub fn id(&self) -> &str {
80 match self {
81 Self::Chat => "deepseek-chat",
82 Self::Reasoner => "deepseek-reasoner",
83 Self::Custom { name, .. } => name,
84 }
85 }
86
87 pub fn display_name(&self) -> &str {
88 match self {
89 Self::Chat => "DeepSeek Chat",
90 Self::Reasoner => "DeepSeek Reasoner",
91 Self::Custom {
92 name, display_name, ..
93 } => display_name.as_ref().unwrap_or(name).as_str(),
94 }
95 }
96
97 pub fn max_token_count(&self) -> u64 {
98 match self {
99 Self::Chat | Self::Reasoner => 128_000,
100 Self::Custom { max_tokens, .. } => *max_tokens,
101 }
102 }
103
104 pub fn max_output_tokens(&self) -> Option<u64> {
105 match self {
106 // Their API treats this max against the context window, which means we hit the limit a lot
107 // Using the default value of None in the API instead
108 Self::Chat | Self::Reasoner => None,
109 Self::Custom {
110 max_output_tokens, ..
111 } => *max_output_tokens,
112 }
113 }
114}
115
116#[derive(Debug, Serialize, Deserialize)]
117pub struct Request {
118 pub model: String,
119 pub messages: Vec<RequestMessage>,
120 pub stream: bool,
121 #[serde(default, skip_serializing_if = "Option::is_none")]
122 pub max_tokens: Option<u64>,
123 #[serde(default, skip_serializing_if = "Option::is_none")]
124 pub temperature: Option<f32>,
125 #[serde(default, skip_serializing_if = "Option::is_none")]
126 pub response_format: Option<ResponseFormat>,
127 #[serde(default, skip_serializing_if = "Vec::is_empty")]
128 pub tools: Vec<ToolDefinition>,
129}
130
131#[derive(Debug, Serialize, Deserialize)]
132#[serde(rename_all = "snake_case")]
133pub enum ResponseFormat {
134 Text,
135 #[serde(rename = "json_object")]
136 JsonObject,
137}
138
139#[derive(Debug, Serialize, Deserialize)]
140#[serde(tag = "type", rename_all = "snake_case")]
141pub enum ToolDefinition {
142 Function { function: FunctionDefinition },
143}
144
145#[derive(Debug, Serialize, Deserialize)]
146pub struct FunctionDefinition {
147 pub name: String,
148 pub description: Option<String>,
149 pub parameters: Option<Value>,
150}
151
152#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
153#[serde(tag = "role", rename_all = "lowercase")]
154pub enum RequestMessage {
155 Assistant {
156 content: Option<String>,
157 #[serde(default, skip_serializing_if = "Vec::is_empty")]
158 tool_calls: Vec<ToolCall>,
159 #[serde(default, skip_serializing_if = "Option::is_none")]
160 reasoning_content: Option<String>,
161 },
162 User {
163 content: String,
164 },
165 System {
166 content: String,
167 },
168 Tool {
169 content: String,
170 tool_call_id: String,
171 },
172}
173
174#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
175pub struct ToolCall {
176 pub id: String,
177 #[serde(flatten)]
178 pub content: ToolCallContent,
179}
180
181#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
182#[serde(tag = "type", rename_all = "lowercase")]
183pub enum ToolCallContent {
184 Function { function: FunctionContent },
185}
186
187#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
188pub struct FunctionContent {
189 pub name: String,
190 pub arguments: String,
191}
192
193#[derive(Serialize, Deserialize, Debug)]
194pub struct Response {
195 pub id: String,
196 pub object: String,
197 pub created: u64,
198 pub model: String,
199 pub choices: Vec<Choice>,
200 pub usage: Usage,
201 #[serde(default, skip_serializing_if = "Option::is_none")]
202 pub reasoning_content: Option<String>,
203}
204
205#[derive(Serialize, Deserialize, Debug)]
206pub struct Usage {
207 pub prompt_tokens: u64,
208 pub completion_tokens: u64,
209 pub total_tokens: u64,
210 #[serde(default)]
211 pub prompt_cache_hit_tokens: u64,
212 #[serde(default)]
213 pub prompt_cache_miss_tokens: u64,
214}
215
216#[derive(Serialize, Deserialize, Debug)]
217pub struct Choice {
218 pub index: u32,
219 pub message: RequestMessage,
220 pub finish_reason: Option<String>,
221}
222
223#[derive(Serialize, Deserialize, Debug)]
224pub struct StreamResponse {
225 pub id: String,
226 pub object: String,
227 pub created: u64,
228 pub model: String,
229 pub choices: Vec<StreamChoice>,
230 pub usage: Option<Usage>,
231}
232
233#[derive(Serialize, Deserialize, Debug)]
234pub struct StreamChoice {
235 pub index: u32,
236 pub delta: StreamDelta,
237 pub finish_reason: Option<String>,
238}
239
240#[derive(Serialize, Deserialize, Debug)]
241pub struct StreamDelta {
242 pub role: Option<Role>,
243 pub content: Option<String>,
244 #[serde(default, skip_serializing_if = "Option::is_none")]
245 pub tool_calls: Option<Vec<ToolCallChunk>>,
246 #[serde(default, skip_serializing_if = "Option::is_none")]
247 pub reasoning_content: Option<String>,
248}
249
250#[derive(Serialize, Deserialize, Debug)]
251pub struct ToolCallChunk {
252 pub index: usize,
253 pub id: Option<String>,
254 pub function: Option<FunctionChunk>,
255}
256
257#[derive(Serialize, Deserialize, Debug)]
258pub struct FunctionChunk {
259 pub name: Option<String>,
260 pub arguments: Option<String>,
261}
262
263pub async fn stream_completion(
264 client: &dyn HttpClient,
265 api_url: &str,
266 api_key: &str,
267 request: Request,
268) -> Result<BoxStream<'static, Result<StreamResponse>>> {
269 let uri = format!("{api_url}/chat/completions");
270 let request_builder = HttpRequest::builder()
271 .method(Method::POST)
272 .uri(uri)
273 .header("Content-Type", "application/json")
274 .header("Authorization", format!("Bearer {}", api_key.trim()));
275
276 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
277 let mut response = client.send(request).await?;
278
279 if response.status().is_success() {
280 let reader = BufReader::new(response.into_body());
281 Ok(reader
282 .lines()
283 .filter_map(|line| async move {
284 match line {
285 Ok(line) => {
286 let line = line.strip_prefix("data: ")?;
287 if line == "[DONE]" {
288 None
289 } else {
290 match serde_json::from_str(line) {
291 Ok(response) => Some(Ok(response)),
292 Err(error) => Some(Err(anyhow!(error))),
293 }
294 }
295 }
296 Err(error) => Some(Err(anyhow!(error))),
297 }
298 })
299 .boxed())
300 } else {
301 let mut body = String::new();
302 response.body_mut().read_to_string(&mut body).await?;
303 anyhow::bail!(
304 "Failed to connect to DeepSeek API: {} {}",
305 response.status(),
306 body,
307 );
308 }
309}