deepseek.rs

  1use anyhow::{Result, anyhow};
  2use futures::{
  3    AsyncBufReadExt, AsyncReadExt,
  4    io::BufReader,
  5    stream::{BoxStream, StreamExt},
  6};
  7use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  8use serde::{Deserialize, Serialize};
  9use serde_json::Value;
 10use std::convert::TryFrom;
 11
/// Base URL of the hosted DeepSeek API (v1). `stream_completion` appends
/// `/chat/completions` to whatever base URL it is given.
pub const DEEPSEEK_API_URL: &str = "https://api.deepseek.com/v1";
 13
 14#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 15#[serde(rename_all = "lowercase")]
 16pub enum Role {
 17    User,
 18    Assistant,
 19    System,
 20    Tool,
 21}
 22
 23impl TryFrom<String> for Role {
 24    type Error = anyhow::Error;
 25
 26    fn try_from(value: String) -> Result<Self> {
 27        match value.as_str() {
 28            "user" => Ok(Self::User),
 29            "assistant" => Ok(Self::Assistant),
 30            "system" => Ok(Self::System),
 31            "tool" => Ok(Self::Tool),
 32            _ => anyhow::bail!("invalid role '{value}'"),
 33        }
 34    }
 35}
 36
 37impl From<Role> for String {
 38    fn from(val: Role) -> Self {
 39        match val {
 40            Role::User => "user".to_owned(),
 41            Role::Assistant => "assistant".to_owned(),
 42            Role::System => "system".to_owned(),
 43            Role::Tool => "tool".to_owned(),
 44        }
 45    }
 46}
 47
 48#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 49#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 50pub enum Model {
 51    #[serde(rename = "deepseek-chat")]
 52    #[default]
 53    Chat,
 54    #[serde(rename = "deepseek-reasoner")]
 55    Reasoner,
 56    #[serde(rename = "custom")]
 57    Custom {
 58        name: String,
 59        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 60        display_name: Option<String>,
 61        max_tokens: u64,
 62        max_output_tokens: Option<u64>,
 63    },
 64}
 65
 66impl Model {
 67    pub fn default_fast() -> Self {
 68        Model::Chat
 69    }
 70
 71    pub fn from_id(id: &str) -> Result<Self> {
 72        match id {
 73            "deepseek-chat" => Ok(Self::Chat),
 74            "deepseek-reasoner" => Ok(Self::Reasoner),
 75            _ => anyhow::bail!("invalid model id {id}"),
 76        }
 77    }
 78
 79    pub fn id(&self) -> &str {
 80        match self {
 81            Self::Chat => "deepseek-chat",
 82            Self::Reasoner => "deepseek-reasoner",
 83            Self::Custom { name, .. } => name,
 84        }
 85    }
 86
 87    pub fn display_name(&self) -> &str {
 88        match self {
 89            Self::Chat => "DeepSeek Chat",
 90            Self::Reasoner => "DeepSeek Reasoner",
 91            Self::Custom {
 92                name, display_name, ..
 93            } => display_name.as_ref().unwrap_or(name).as_str(),
 94        }
 95    }
 96
 97    pub fn max_token_count(&self) -> u64 {
 98        match self {
 99            Self::Chat | Self::Reasoner => 128_000,
100            Self::Custom { max_tokens, .. } => *max_tokens,
101        }
102    }
103
104    pub fn max_output_tokens(&self) -> Option<u64> {
105        match self {
106            // Their API treats this max against the context window, which means we hit the limit a lot
107            // Using the default value of None in the API instead
108            Self::Chat | Self::Reasoner => None,
109            Self::Custom {
110                max_output_tokens, ..
111            } => *max_output_tokens,
112        }
113    }
114}
115
116#[derive(Debug, Serialize, Deserialize)]
117pub struct Request {
118    pub model: String,
119    pub messages: Vec<RequestMessage>,
120    pub stream: bool,
121    #[serde(default, skip_serializing_if = "Option::is_none")]
122    pub max_tokens: Option<u64>,
123    #[serde(default, skip_serializing_if = "Option::is_none")]
124    pub temperature: Option<f32>,
125    #[serde(default, skip_serializing_if = "Option::is_none")]
126    pub response_format: Option<ResponseFormat>,
127    #[serde(default, skip_serializing_if = "Vec::is_empty")]
128    pub tools: Vec<ToolDefinition>,
129}
130
131#[derive(Debug, Serialize, Deserialize)]
132#[serde(rename_all = "snake_case")]
133pub enum ResponseFormat {
134    Text,
135    #[serde(rename = "json_object")]
136    JsonObject,
137}
138
139#[derive(Debug, Serialize, Deserialize)]
140#[serde(tag = "type", rename_all = "snake_case")]
141pub enum ToolDefinition {
142    Function { function: FunctionDefinition },
143}
144
145#[derive(Debug, Serialize, Deserialize)]
146pub struct FunctionDefinition {
147    pub name: String,
148    pub description: Option<String>,
149    pub parameters: Option<Value>,
150}
151
152#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
153#[serde(tag = "role", rename_all = "lowercase")]
154pub enum RequestMessage {
155    Assistant {
156        content: Option<String>,
157        #[serde(default, skip_serializing_if = "Vec::is_empty")]
158        tool_calls: Vec<ToolCall>,
159        #[serde(default, skip_serializing_if = "Option::is_none")]
160        reasoning_content: Option<String>,
161    },
162    User {
163        content: String,
164    },
165    System {
166        content: String,
167    },
168    Tool {
169        content: String,
170        tool_call_id: String,
171    },
172}
173
174#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
175pub struct ToolCall {
176    pub id: String,
177    #[serde(flatten)]
178    pub content: ToolCallContent,
179}
180
181#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
182#[serde(tag = "type", rename_all = "lowercase")]
183pub enum ToolCallContent {
184    Function { function: FunctionContent },
185}
186
187#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
188pub struct FunctionContent {
189    pub name: String,
190    pub arguments: String,
191}
192
193#[derive(Serialize, Deserialize, Debug)]
194pub struct Response {
195    pub id: String,
196    pub object: String,
197    pub created: u64,
198    pub model: String,
199    pub choices: Vec<Choice>,
200    pub usage: Usage,
201    #[serde(default, skip_serializing_if = "Option::is_none")]
202    pub reasoning_content: Option<String>,
203}
204
205#[derive(Serialize, Deserialize, Debug)]
206pub struct Usage {
207    pub prompt_tokens: u64,
208    pub completion_tokens: u64,
209    pub total_tokens: u64,
210    #[serde(default)]
211    pub prompt_cache_hit_tokens: u64,
212    #[serde(default)]
213    pub prompt_cache_miss_tokens: u64,
214}
215
216#[derive(Serialize, Deserialize, Debug)]
217pub struct Choice {
218    pub index: u32,
219    pub message: RequestMessage,
220    pub finish_reason: Option<String>,
221}
222
223#[derive(Serialize, Deserialize, Debug)]
224pub struct StreamResponse {
225    pub id: String,
226    pub object: String,
227    pub created: u64,
228    pub model: String,
229    pub choices: Vec<StreamChoice>,
230    pub usage: Option<Usage>,
231}
232
233#[derive(Serialize, Deserialize, Debug)]
234pub struct StreamChoice {
235    pub index: u32,
236    pub delta: StreamDelta,
237    pub finish_reason: Option<String>,
238}
239
240#[derive(Serialize, Deserialize, Debug)]
241pub struct StreamDelta {
242    pub role: Option<Role>,
243    pub content: Option<String>,
244    #[serde(default, skip_serializing_if = "Option::is_none")]
245    pub tool_calls: Option<Vec<ToolCallChunk>>,
246    #[serde(default, skip_serializing_if = "Option::is_none")]
247    pub reasoning_content: Option<String>,
248}
249
250#[derive(Serialize, Deserialize, Debug)]
251pub struct ToolCallChunk {
252    pub index: usize,
253    pub id: Option<String>,
254    pub function: Option<FunctionChunk>,
255}
256
257#[derive(Serialize, Deserialize, Debug)]
258pub struct FunctionChunk {
259    pub name: Option<String>,
260    pub arguments: Option<String>,
261}
262
263pub async fn stream_completion(
264    client: &dyn HttpClient,
265    api_url: &str,
266    api_key: &str,
267    request: Request,
268) -> Result<BoxStream<'static, Result<StreamResponse>>> {
269    let uri = format!("{api_url}/chat/completions");
270    let request_builder = HttpRequest::builder()
271        .method(Method::POST)
272        .uri(uri)
273        .header("Content-Type", "application/json")
274        .header("Authorization", format!("Bearer {}", api_key.trim()));
275
276    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
277    let mut response = client.send(request).await?;
278
279    if response.status().is_success() {
280        let reader = BufReader::new(response.into_body());
281        Ok(reader
282            .lines()
283            .filter_map(|line| async move {
284                match line {
285                    Ok(line) => {
286                        let line = line.strip_prefix("data: ")?;
287                        if line == "[DONE]" {
288                            None
289                        } else {
290                            match serde_json::from_str(line) {
291                                Ok(response) => Some(Ok(response)),
292                                Err(error) => Some(Err(anyhow!(error))),
293                            }
294                        }
295                    }
296                    Err(error) => Some(Err(anyhow!(error))),
297                }
298            })
299            .boxed())
300    } else {
301        let mut body = String::new();
302        response.body_mut().read_to_string(&mut body).await?;
303        anyhow::bail!(
304            "Failed to connect to DeepSeek API: {} {}",
305            response.status(),
306            body,
307        );
308    }
309}