deepseek.rs

  1use anyhow::{Result, anyhow};
  2use futures::{
  3    AsyncBufReadExt, AsyncReadExt,
  4    io::BufReader,
  5    stream::{BoxStream, StreamExt},
  6};
  7use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  8use serde::{Deserialize, Serialize};
  9use serde_json::Value;
 10use std::convert::TryFrom;
 11
 12pub const DEEPSEEK_API_URL: &str = "https://api.deepseek.com";
 13
 14#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 15#[serde(rename_all = "lowercase")]
 16pub enum Role {
 17    User,
 18    Assistant,
 19    System,
 20    Tool,
 21}
 22
 23impl TryFrom<String> for Role {
 24    type Error = anyhow::Error;
 25
 26    fn try_from(value: String) -> Result<Self> {
 27        match value.as_str() {
 28            "user" => Ok(Self::User),
 29            "assistant" => Ok(Self::Assistant),
 30            "system" => Ok(Self::System),
 31            "tool" => Ok(Self::Tool),
 32            _ => Err(anyhow!("invalid role '{value}'")),
 33        }
 34    }
 35}
 36
 37impl From<Role> for String {
 38    fn from(val: Role) -> Self {
 39        match val {
 40            Role::User => "user".to_owned(),
 41            Role::Assistant => "assistant".to_owned(),
 42            Role::System => "system".to_owned(),
 43            Role::Tool => "tool".to_owned(),
 44        }
 45    }
 46}
 47
 48#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 49#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 50pub enum Model {
 51    #[serde(rename = "deepseek-chat")]
 52    #[default]
 53    Chat,
 54    #[serde(rename = "deepseek-reasoner")]
 55    Reasoner,
 56    #[serde(rename = "custom")]
 57    Custom {
 58        name: String,
 59        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 60        display_name: Option<String>,
 61        max_tokens: usize,
 62        max_output_tokens: Option<u32>,
 63    },
 64}
 65
 66impl Model {
 67    pub fn from_id(id: &str) -> Result<Self> {
 68        match id {
 69            "deepseek-chat" => Ok(Self::Chat),
 70            "deepseek-reasoner" => Ok(Self::Reasoner),
 71            _ => Err(anyhow!("invalid model id")),
 72        }
 73    }
 74
 75    pub fn id(&self) -> &str {
 76        match self {
 77            Self::Chat => "deepseek-chat",
 78            Self::Reasoner => "deepseek-reasoner",
 79            Self::Custom { name, .. } => name,
 80        }
 81    }
 82
 83    pub fn display_name(&self) -> &str {
 84        match self {
 85            Self::Chat => "DeepSeek Chat",
 86            Self::Reasoner => "DeepSeek Reasoner",
 87            Self::Custom {
 88                name, display_name, ..
 89            } => display_name.as_ref().unwrap_or(name).as_str(),
 90        }
 91    }
 92
 93    pub fn max_token_count(&self) -> usize {
 94        match self {
 95            Self::Chat | Self::Reasoner => 64_000,
 96            Self::Custom { max_tokens, .. } => *max_tokens,
 97        }
 98    }
 99
100    pub fn max_output_tokens(&self) -> Option<u32> {
101        match self {
102            Self::Chat => Some(8_192),
103            Self::Reasoner => Some(8_192),
104            Self::Custom {
105                max_output_tokens, ..
106            } => *max_output_tokens,
107        }
108    }
109}
110
111#[derive(Debug, Serialize, Deserialize)]
112pub struct Request {
113    pub model: String,
114    pub messages: Vec<RequestMessage>,
115    pub stream: bool,
116    #[serde(default, skip_serializing_if = "Option::is_none")]
117    pub max_tokens: Option<u32>,
118    #[serde(default, skip_serializing_if = "Option::is_none")]
119    pub temperature: Option<f32>,
120    #[serde(default, skip_serializing_if = "Option::is_none")]
121    pub response_format: Option<ResponseFormat>,
122    #[serde(default, skip_serializing_if = "Vec::is_empty")]
123    pub tools: Vec<ToolDefinition>,
124}
125
126#[derive(Debug, Serialize, Deserialize)]
127#[serde(rename_all = "snake_case")]
128pub enum ResponseFormat {
129    Text,
130    #[serde(rename = "json_object")]
131    JsonObject,
132}
133
134#[derive(Debug, Serialize, Deserialize)]
135#[serde(tag = "type", rename_all = "snake_case")]
136pub enum ToolDefinition {
137    Function { function: FunctionDefinition },
138}
139
140#[derive(Debug, Serialize, Deserialize)]
141pub struct FunctionDefinition {
142    pub name: String,
143    pub description: Option<String>,
144    pub parameters: Option<Value>,
145}
146
147#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
148#[serde(tag = "role", rename_all = "lowercase")]
149pub enum RequestMessage {
150    Assistant {
151        content: Option<String>,
152        #[serde(default, skip_serializing_if = "Vec::is_empty")]
153        tool_calls: Vec<ToolCall>,
154    },
155    User {
156        content: String,
157    },
158    System {
159        content: String,
160    },
161    Tool {
162        content: String,
163        tool_call_id: String,
164    },
165}
166
167#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
168pub struct ToolCall {
169    pub id: String,
170    #[serde(flatten)]
171    pub content: ToolCallContent,
172}
173
174#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
175#[serde(tag = "type", rename_all = "lowercase")]
176pub enum ToolCallContent {
177    Function { function: FunctionContent },
178}
179
180#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
181pub struct FunctionContent {
182    pub name: String,
183    pub arguments: String,
184}
185
186#[derive(Serialize, Deserialize, Debug)]
187pub struct Response {
188    pub id: String,
189    pub object: String,
190    pub created: u64,
191    pub model: String,
192    pub choices: Vec<Choice>,
193    pub usage: Usage,
194    #[serde(default, skip_serializing_if = "Option::is_none")]
195    pub reasoning_content: Option<String>,
196}
197
198#[derive(Serialize, Deserialize, Debug)]
199pub struct Usage {
200    pub prompt_tokens: u32,
201    pub completion_tokens: u32,
202    pub total_tokens: u32,
203    #[serde(default)]
204    pub prompt_cache_hit_tokens: u32,
205    #[serde(default)]
206    pub prompt_cache_miss_tokens: u32,
207}
208
209#[derive(Serialize, Deserialize, Debug)]
210pub struct Choice {
211    pub index: u32,
212    pub message: RequestMessage,
213    pub finish_reason: Option<String>,
214}
215
216#[derive(Serialize, Deserialize, Debug)]
217pub struct StreamResponse {
218    pub id: String,
219    pub object: String,
220    pub created: u64,
221    pub model: String,
222    pub choices: Vec<StreamChoice>,
223}
224
225#[derive(Serialize, Deserialize, Debug)]
226pub struct StreamChoice {
227    pub index: u32,
228    pub delta: StreamDelta,
229    pub finish_reason: Option<String>,
230}
231
232#[derive(Serialize, Deserialize, Debug)]
233pub struct StreamDelta {
234    pub role: Option<Role>,
235    pub content: Option<String>,
236    #[serde(default, skip_serializing_if = "Option::is_none")]
237    pub tool_calls: Option<Vec<ToolCallChunk>>,
238    #[serde(default, skip_serializing_if = "Option::is_none")]
239    pub reasoning_content: Option<String>,
240}
241
242#[derive(Serialize, Deserialize, Debug)]
243pub struct ToolCallChunk {
244    pub index: usize,
245    pub id: Option<String>,
246    pub function: Option<FunctionChunk>,
247}
248
249#[derive(Serialize, Deserialize, Debug)]
250pub struct FunctionChunk {
251    pub name: Option<String>,
252    pub arguments: Option<String>,
253}
254
255pub async fn stream_completion(
256    client: &dyn HttpClient,
257    api_url: &str,
258    api_key: &str,
259    request: Request,
260) -> Result<BoxStream<'static, Result<StreamResponse>>> {
261    let uri = format!("{api_url}/v1/chat/completions");
262    let request_builder = HttpRequest::builder()
263        .method(Method::POST)
264        .uri(uri)
265        .header("Content-Type", "application/json")
266        .header("Authorization", format!("Bearer {}", api_key));
267
268    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
269    let mut response = client.send(request).await?;
270
271    if response.status().is_success() {
272        let reader = BufReader::new(response.into_body());
273        Ok(reader
274            .lines()
275            .filter_map(|line| async move {
276                match line {
277                    Ok(line) => {
278                        let line = line.strip_prefix("data: ")?;
279                        if line == "[DONE]" {
280                            None
281                        } else {
282                            match serde_json::from_str(line) {
283                                Ok(response) => Some(Ok(response)),
284                                Err(error) => Some(Err(anyhow!(error))),
285                            }
286                        }
287                    }
288                    Err(error) => Some(Err(anyhow!(error))),
289                }
290            })
291            .boxed())
292    } else {
293        let mut body = String::new();
294        response.body_mut().read_to_string(&mut body).await?;
295        Err(anyhow!(
296            "Failed to connect to DeepSeek API: {} {}",
297            response.status(),
298            body,
299        ))
300    }
301}