@@ -1,7 +1,8 @@
use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{
AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle,
@@ -12,11 +13,14 @@ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelToolChoice, RateLimiter, Role,
+ LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+ RateLimiter, Role, StopReason,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
use std::sync::Arc;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, prelude::*};
@@ -28,6 +32,13 @@ const PROVIDER_ID: &str = "deepseek";
const PROVIDER_NAME: &str = "DeepSeek";
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
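+/// A tool call assembled from streamed deltas; `arguments` accumulates JSON
+/// fragments until the stream reports that the calls are complete.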
+#[derive(Default)]
+struct RawToolCall {
+ id: String,
+ name: String,
+ arguments: String,
+}
+
#[derive(Default, Clone, Debug, PartialEq)]
pub struct DeepSeekSettings {
pub api_url: String,
@@ -280,11 +291,11 @@ impl LanguageModel for DeepSeekLanguageModel {
}
fn supports_tools(&self) -> bool {
- false
+ true
}
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
- false
+ true
}
fn supports_images(&self) -> bool {
@@ -339,35 +350,12 @@ impl LanguageModel for DeepSeekLanguageModel {
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
- let request = into_deepseek(
- request,
- self.model.id().to_string(),
- self.max_output_tokens(),
- );
+ let request = into_deepseek(request, &self.model, self.max_output_tokens());
let stream = self.stream_completion(request, cx);
async move {
- let stream = stream.await?;
- Ok(stream
- .map(|result| {
- result
- .and_then(|response| {
- response
- .choices
- .first()
- .context("Empty response")
- .map(|choice| {
- choice
- .delta
- .content
- .clone()
- .unwrap_or_default()
- .map(LanguageModelCompletionEvent::Text)
- })
- })
- .map_err(LanguageModelCompletionError::Other)
- })
- .boxed())
+ let mapper = DeepSeekEventMapper::new();
+ Ok(mapper.map_stream(stream.await?).boxed())
}
.boxed()
}
@@ -375,69 +363,67 @@ impl LanguageModel for DeepSeekLanguageModel {
pub fn into_deepseek(
request: LanguageModelRequest,
- model: String,
+ model: &deepseek::Model,
max_output_tokens: Option<u32>,
) -> deepseek::Request {
- let is_reasoner = model == "deepseek-reasoner";
-
- let len = request.messages.len();
- let merged_messages =
- request
- .messages
- .into_iter()
- .fold(Vec::with_capacity(len), |mut acc, msg| {
- let role = msg.role;
- let content = msg.string_contents();
-
- if is_reasoner {
- if let Some(last_msg) = acc.last_mut() {
- match (last_msg, role) {
- (deepseek::RequestMessage::User { content: last }, Role::User) => {
- last.push(' ');
- last.push_str(&content);
- return acc;
- }
-
- (
- deepseek::RequestMessage::Assistant {
- content: last_content,
- ..
- },
- Role::Assistant,
- ) => {
- *last_content = last_content
- .take()
- .map(|c| {
- let mut s =
- String::with_capacity(c.len() + content.len() + 1);
- s.push_str(&c);
- s.push(' ');
- s.push_str(&content);
- s
- })
- .or(Some(content));
-
- return acc;
- }
- _ => {}
- }
+ let is_reasoner = *model == deepseek::Model::Reasoner;
+
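+    // Flatten each request message into one DeepSeek message per content item;
+    // tool calls are folded into the preceding assistant message when possible.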
+ let mut messages = Vec::new();
+ for message in request.messages {
+ for content in message.content {
+ match content {
+ MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+ .push(match message.role {
+ Role::User => deepseek::RequestMessage::User { content: text },
+ Role::Assistant => deepseek::RequestMessage::Assistant {
+ content: Some(text),
+ tool_calls: Vec::new(),
+ },
+ Role::System => deepseek::RequestMessage::System { content: text },
+ }),
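+                // Redacted thinking and image content have no counterpart in the
+                // DeepSeek request types, so they are dropped.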
+ MessageContent::RedactedThinking(_) => {}
+ MessageContent::Image(_) => {}
+ MessageContent::ToolUse(tool_use) => {
+ let tool_call = deepseek::ToolCall {
+ id: tool_use.id.to_string(),
+ content: deepseek::ToolCallContent::Function {
+ function: deepseek::FunctionContent {
+ name: tool_use.name.to_string(),
+ arguments: serde_json::to_string(&tool_use.input)
+ .unwrap_or_default(),
+ },
+ },
+ };
+
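+                    // Attach the call to the preceding assistant message if one
+                    // exists; otherwise emit an assistant message carrying only
+                    // the tool call.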
+ if let Some(deepseek::RequestMessage::Assistant { tool_calls, .. }) =
+ messages.last_mut()
+ {
+ tool_calls.push(tool_call);
+ } else {
+ messages.push(deepseek::RequestMessage::Assistant {
+ content: None,
+ tool_calls: vec![tool_call],
+ });
}
}
-
- acc.push(match role {
- Role::User => deepseek::RequestMessage::User { content },
- Role::Assistant => deepseek::RequestMessage::Assistant {
- content: Some(content),
- tool_calls: Vec::new(),
- },
- Role::System => deepseek::RequestMessage::System { content },
- });
- acc
- });
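+                // Text tool results become DeepSeek `Tool` messages keyed by the
+                // originating call id; image results are dropped.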
+ MessageContent::ToolResult(tool_result) => {
+ match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ messages.push(deepseek::RequestMessage::Tool {
+ content: text.to_string(),
+ tool_call_id: tool_result.tool_use_id.to_string(),
+ });
+ }
+ LanguageModelToolResultContent::Image(_) => {}
+ };
+ }
+ }
+ }
+ }
deepseek::Request {
- model,
- messages: merged_messages,
+ model: model.id().to_string(),
+ messages,
stream: true,
max_tokens: max_output_tokens,
temperature: if is_reasoner {
@@ -460,6 +446,103 @@ pub fn into_deepseek(
}
}
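+/// Maps DeepSeek stream responses to `LanguageModelCompletionEvent`s,
+/// buffering partial tool-call deltas until a finish reason arrives.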
+pub struct DeepSeekEventMapper {
+ tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
+
+impl DeepSeekEventMapper {
+ pub fn new() -> Self {
+ Self {
+ tool_calls_by_index: HashMap::default(),
+ }
+ }
+
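+    /// Adapts a raw response stream, flattening each chunk into zero or more
+    /// completion events.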
+ pub fn map_stream(
+ mut self,
+ events: Pin<Box<dyn Send + Stream<Item = Result<deepseek::StreamResponse>>>>,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ })
+ })
+ }
+
+ pub fn map_event(
+ &mut self,
+ event: deepseek::StreamResponse,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ let Some(choice) = event.choices.first() else {
+ return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ "Response contained no choices"
+ )))];
+ };
+
+ let mut events = Vec::new();
+ if let Some(content) = choice.delta.content.clone() {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+
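+        // Tool-call deltas arrive fragmented and keyed by index; merge each
+        // fragment into its partially built call.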
+ if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+ for tool_call in tool_calls {
+ let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+ if let Some(tool_id) = tool_call.id.clone() {
+ entry.id = tool_id;
+ }
+
+ if let Some(function) = tool_call.function.as_ref() {
+ if let Some(name) = function.name.clone() {
+ entry.name = name;
+ }
+
+ if let Some(arguments) = function.arguments.clone() {
+ entry.arguments.push_str(&arguments);
+ }
+ }
+ }
+ }
+
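+        // "stop" ends the turn and "tool_calls" flushes the accumulated calls;
+        // any other finish reason is logged and treated as an end of turn.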
+ match choice.finish_reason.as_deref() {
+ Some("stop") => {
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ Some("tool_calls") => {
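+                // Parse each accumulated call's JSON arguments, surfacing a
+                // BadInputJson error for any call that fails to parse.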
+ events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+ match serde_json::Value::from_str(&tool_call.arguments) {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.clone().into(),
+ name: tool_call.name.as_str().into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_call.arguments.clone(),
+ },
+ )),
+ Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+ id: tool_call.id.into(),
+ tool_name: tool_call.name.as_str().into(),
+ raw_input: tool_call.arguments.into(),
+ json_parse_error: error.to_string(),
+ }),
+ }
+ }));
+
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+ }
+ Some(stop_reason) => {
+                log::error!("Unexpected DeepSeek stop_reason: {stop_reason:?}");
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ None => {}
+ }
+
+ events
+ }
+}
+
struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: Entity<State>,