Add tool support for DeepSeek (#30223)

THELOSTSOUL and Bennet Bo Fenner created

[deepseek function call
api](https://api-docs.deepseek.com/guides/function_calling)
has been released and it is same as openai.

Release Notes:

- Added tool calling support for Deepseek Models

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>

Change summary

crates/language_models/src/provider/deepseek.rs | 257 ++++++++++++------
1 file changed, 170 insertions(+), 87 deletions(-)

Detailed changes

crates/language_models/src/provider/deepseek.rs 🔗

@@ -1,7 +1,8 @@
 use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{
     AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle,
@@ -12,11 +13,14 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, RateLimiter, Role,
+    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+    RateLimiter, Role, StopReason,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
 use std::sync::Arc;
 use theme::ThemeSettings;
 use ui::{Icon, IconName, List, prelude::*};
@@ -28,6 +32,13 @@ const PROVIDER_ID: &str = "deepseek";
 const PROVIDER_NAME: &str = "DeepSeek";
 const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
 
+#[derive(Default)]
+struct RawToolCall {
+    id: String,
+    name: String,
+    arguments: String,
+}
+
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct DeepSeekSettings {
     pub api_url: String,
@@ -280,11 +291,11 @@ impl LanguageModel for DeepSeekLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        true
     }
 
     fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
-        false
+        true
     }
 
     fn supports_images(&self) -> bool {
@@ -339,35 +350,12 @@ impl LanguageModel for DeepSeekLanguageModel {
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
         >,
     > {
-        let request = into_deepseek(
-            request,
-            self.model.id().to_string(),
-            self.max_output_tokens(),
-        );
+        let request = into_deepseek(request, &self.model, self.max_output_tokens());
         let stream = self.stream_completion(request, cx);
 
         async move {
-            let stream = stream.await?;
-            Ok(stream
-                .map(|result| {
-                    result
-                        .and_then(|response| {
-                            response
-                                .choices
-                                .first()
-                                .context("Empty response")
-                                .map(|choice| {
-                                    choice
-                                        .delta
-                                        .content
-                                        .clone()
-                                        .unwrap_or_default()
-                                        .map(LanguageModelCompletionEvent::Text)
-                                })
-                        })
-                        .map_err(LanguageModelCompletionError::Other)
-                })
-                .boxed())
+            let mapper = DeepSeekEventMapper::new();
+            Ok(mapper.map_stream(stream.await?).boxed())
         }
         .boxed()
     }
@@ -375,69 +363,67 @@ impl LanguageModel for DeepSeekLanguageModel {
 
 pub fn into_deepseek(
     request: LanguageModelRequest,
-    model: String,
+    model: &deepseek::Model,
     max_output_tokens: Option<u32>,
 ) -> deepseek::Request {
-    let is_reasoner = model == "deepseek-reasoner";
-
-    let len = request.messages.len();
-    let merged_messages =
-        request
-            .messages
-            .into_iter()
-            .fold(Vec::with_capacity(len), |mut acc, msg| {
-                let role = msg.role;
-                let content = msg.string_contents();
-
-                if is_reasoner {
-                    if let Some(last_msg) = acc.last_mut() {
-                        match (last_msg, role) {
-                            (deepseek::RequestMessage::User { content: last }, Role::User) => {
-                                last.push(' ');
-                                last.push_str(&content);
-                                return acc;
-                            }
-
-                            (
-                                deepseek::RequestMessage::Assistant {
-                                    content: last_content,
-                                    ..
-                                },
-                                Role::Assistant,
-                            ) => {
-                                *last_content = last_content
-                                    .take()
-                                    .map(|c| {
-                                        let mut s =
-                                            String::with_capacity(c.len() + content.len() + 1);
-                                        s.push_str(&c);
-                                        s.push(' ');
-                                        s.push_str(&content);
-                                        s
-                                    })
-                                    .or(Some(content));
-
-                                return acc;
-                            }
-                            _ => {}
-                        }
+    let is_reasoner = *model == deepseek::Model::Reasoner;
+
+    let mut messages = Vec::new();
+    for message in request.messages {
+        for content in message.content {
+            match content {
+                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+                    .push(match message.role {
+                        Role::User => deepseek::RequestMessage::User { content: text },
+                        Role::Assistant => deepseek::RequestMessage::Assistant {
+                            content: Some(text),
+                            tool_calls: Vec::new(),
+                        },
+                        Role::System => deepseek::RequestMessage::System { content: text },
+                    }),
+                MessageContent::RedactedThinking(_) => {}
+                MessageContent::Image(_) => {}
+                MessageContent::ToolUse(tool_use) => {
+                    let tool_call = deepseek::ToolCall {
+                        id: tool_use.id.to_string(),
+                        content: deepseek::ToolCallContent::Function {
+                            function: deepseek::FunctionContent {
+                                name: tool_use.name.to_string(),
+                                arguments: serde_json::to_string(&tool_use.input)
+                                    .unwrap_or_default(),
+                            },
+                        },
+                    };
+
+                    if let Some(deepseek::RequestMessage::Assistant { tool_calls, .. }) =
+                        messages.last_mut()
+                    {
+                        tool_calls.push(tool_call);
+                    } else {
+                        messages.push(deepseek::RequestMessage::Assistant {
+                            content: None,
+                            tool_calls: vec![tool_call],
+                        });
                     }
                 }
-
-                acc.push(match role {
-                    Role::User => deepseek::RequestMessage::User { content },
-                    Role::Assistant => deepseek::RequestMessage::Assistant {
-                        content: Some(content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => deepseek::RequestMessage::System { content },
-                });
-                acc
-            });
+                MessageContent::ToolResult(tool_result) => {
+                    match &tool_result.content {
+                        LanguageModelToolResultContent::Text(text) => {
+                            messages.push(deepseek::RequestMessage::Tool {
+                                content: text.to_string(),
+                                tool_call_id: tool_result.tool_use_id.to_string(),
+                            });
+                        }
+                        LanguageModelToolResultContent::Image(_) => {}
+                    };
+                }
+            }
+        }
+    }
 
     deepseek::Request {
-        model,
-        messages: merged_messages,
+        model: model.id().to_string(),
+        messages,
         stream: true,
         max_tokens: max_output_tokens,
         temperature: if is_reasoner {
@@ -460,6 +446,103 @@ pub fn into_deepseek(
     }
 }
 
+pub struct DeepSeekEventMapper {
+    tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
+
+impl DeepSeekEventMapper {
+    pub fn new() -> Self {
+        Self {
+            tool_calls_by_index: HashMap::default(),
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<deepseek::StreamResponse>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: deepseek::StreamResponse,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let Some(choice) = event.choices.first() else {
+            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                "Response contained no choices"
+            )))];
+        };
+
+        let mut events = Vec::new();
+        if let Some(content) = choice.delta.content.clone() {
+            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+        }
+
+        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+            for tool_call in tool_calls {
+                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+                if let Some(tool_id) = tool_call.id.clone() {
+                    entry.id = tool_id;
+                }
+
+                if let Some(function) = tool_call.function.as_ref() {
+                    if let Some(name) = function.name.clone() {
+                        entry.name = name;
+                    }
+
+                    if let Some(arguments) = function.arguments.clone() {
+                        entry.arguments.push_str(&arguments);
+                    }
+                }
+            }
+        }
+
+        match choice.finish_reason.as_deref() {
+            Some("stop") => {
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            Some("tool_calls") => {
+                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+                    match serde_json::Value::from_str(&tool_call.arguments) {
+                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: tool_call.id.clone().into(),
+                                name: tool_call.name.as_str().into(),
+                                is_input_complete: true,
+                                input,
+                                raw_input: tool_call.arguments.clone(),
+                            },
+                        )),
+                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+                            id: tool_call.id.into(),
+                            tool_name: tool_call.name.as_str().into(),
+                            raw_input: tool_call.arguments.into(),
+                            json_parse_error: error.to_string(),
+                        }),
+                    }
+                }));
+
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+            }
+            Some(stop_reason) => {
+                log::error!("Unexpected DeepSeek stop_reason: {stop_reason:?}",);
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            None => {}
+        }
+
+        events
+    }
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: Entity<State>,