diff --git a/Cargo.lock b/Cargo.lock
index 2ad617854d70fd43578e0ad8c0784eaf8c638923..55abe28d70609213ad175705faa2adec7bab08e2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -559,6 +559,7 @@ version = "0.1.0"
 dependencies = [
  "anthropic",
  "anyhow",
+ "deepseek",
  "feature_flags",
  "fs",
  "gpui",
@@ -3684,6 +3685,18 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "deepseek"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.31",
+ "http_client",
+ "schemars",
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "deflate64"
 version = "0.1.9"
@@ -6809,6 +6822,7 @@ dependencies = [
  "anyhow",
  "base64 0.22.1",
  "collections",
+ "deepseek",
  "futures 0.3.31",
  "google_ai",
  "gpui",
@@ -6852,6 +6866,7 @@ dependencies = [
  "client",
  "collections",
  "copilot",
+ "deepseek",
  "editor",
  "feature_flags",
  "fs",
diff --git a/Cargo.toml b/Cargo.toml
index e7481808e22e99c7a74483b3803e7a4e3c7a8918..9f25df9a0f723c06c61fd0649ae2febecfc2d1d1 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ members = [
     "crates/copilot",
     "crates/db",
     "crates/diagnostics",
+    "crates/deepseek",
     "crates/docs_preprocessor",
     "crates/editor",
     "crates/evals",
@@ -229,6 +230,7 @@ context_server = { path = "crates/context_server" }
 context_server_settings = { path = "crates/context_server_settings" }
 copilot = { path = "crates/copilot" }
 db = { path = "crates/db" }
+deepseek = { path = "crates/deepseek" }
 diagnostics = { path = "crates/diagnostics" }
 editor = { path = "crates/editor" }
 extension = { path = "crates/extension" }
diff --git a/assets/icons/ai_deep_seek.svg b/assets/icons/ai_deep_seek.svg
new file mode 100644
index 0000000000000000000000000000000000000000..cf480c834c9f01d914c6fe37885903cdb79ff27f
--- /dev/null
+++ b/assets/icons/ai_deep_seek.svg
@@ -0,0 +1 @@
+DeepSeek
diff --git a/assets/settings/default.json b/assets/settings/default.json
index 04b9bdc29e0bee786d881cc327edd6b0c9f65dcc..ad982a7179c664bbf048f875634edee7234baee3 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -1166,6 +1166,9 @@
     },
     "lmstudio": {
       "api_url": "http://localhost:1234/api/v0"
+    },
+    "deepseek": {
+      "api_url": "https://api.deepseek.com"
     }
   },
   // Zed's Prettier integration settings.
diff --git a/crates/assistant_settings/Cargo.toml b/crates/assistant_settings/Cargo.toml
index 32ebb6a9593fdc5d67bcc4201536306b1f23e722..ee7aa5f5ba8fd4ff1c586d3bc371202923d3b7c4 100644
--- a/crates/assistant_settings/Cargo.toml
+++ b/crates/assistant_settings/Cargo.toml
@@ -14,6 +14,7 @@ path = "src/assistant_settings.rs"
 [dependencies]
 anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
+deepseek = { workspace = true, features = ["schemars"] }
 feature_flags.workspace = true
 gpui.workspace = true
 language_model.workspace = true
diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs
index c98182b24d1cd62e7ff0d20266ec3cbe07619d94..62fc95ead2d69a6a7317f5c9881793f95383c0b0 100644
--- a/crates/assistant_settings/src/assistant_settings.rs
+++ b/crates/assistant_settings/src/assistant_settings.rs
@@ -2,6 +2,7 @@ use std::sync::Arc;
 
 use ::open_ai::Model as OpenAiModel;
 use anthropic::Model as AnthropicModel;
+use deepseek::Model as DeepseekModel;
 use feature_flags::FeatureFlagAppExt;
 use gpui::{AppContext, Pixels};
 use language_model::{CloudModel, LanguageModel};
@@ -46,6 +47,11 @@ pub enum AssistantProviderContentV1 {
         default_model: Option<LmStudioModel>,
         api_url: Option<String>,
     },
+    #[serde(rename = "deepseek")]
+    DeepSeek {
+        default_model: Option<DeepseekModel>,
+        api_url: Option<String>,
+    },
 }
 
 #[derive(Debug, Default)]
@@ -149,6 +155,12 @@ impl AssistantSettingsContent {
                                 model: model.id().to_string(),
                             })
                         }
+                        AssistantProviderContentV1::DeepSeek { default_model, .. } => {
+                            default_model.map(|model| LanguageModelSelection {
+                                provider: "deepseek".to_string(),
+                                model: model.id().to_string(),
+                            })
+                        }
                     }),
                     inline_alternatives: None,
                     enable_experimental_live_diffs: None,
@@ -253,6 +265,18 @@ impl AssistantSettingsContent {
                             available_models,
                         });
                     }
+                    "deepseek" => {
+                        let api_url = match &settings.provider {
+                            Some(AssistantProviderContentV1::DeepSeek { api_url, .. }) => {
+                                api_url.clone()
+                            }
+                            _ => None,
+                        };
+                        settings.provider = Some(AssistantProviderContentV1::DeepSeek {
+                            default_model: DeepseekModel::from_id(&model).ok(),
+                            api_url,
+                        });
+                    }
                     _ => {}
                 },
                 VersionedAssistantSettingsContent::V2(settings) => {
@@ -341,6 +365,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
             "openai".into(),
             "zed.dev".into(),
             "copilot_chat".into(),
+            "deepseek".into(),
         ]),
         ..Default::default()
     }
@@ -380,7 +405,7 @@ pub struct AssistantSettingsContentV1 {
     default_height: Option<f32>,
     /// The provider of the assistant service.
     ///
-    /// This can be "openai", "anthropic", "ollama", "lmstudio", "zed.dev"
+    /// This can be "openai", "anthropic", "ollama", "lmstudio", "deepseek", "zed.dev"
     /// each with their respective default models and configurations.
     provider: Option<AssistantProviderContentV1>,
 }
diff --git a/crates/deepseek/Cargo.toml b/crates/deepseek/Cargo.toml
new file mode 100644
index 0000000000000000000000000000000000000000..0993285fb7d67533fa4c145fea6616d2d8fa8f24
--- /dev/null
+++ b/crates/deepseek/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "deepseek"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/deepseek.rs"
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+http_client.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
diff --git a/crates/deepseek/LICENSE-GPL b/crates/deepseek/LICENSE-GPL
new file mode 120000
index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4
--- /dev/null
+++ b/crates/deepseek/LICENSE-GPL
@@ -0,0 +1 @@
+../../LICENSE-GPL
\ No newline at end of file
diff --git a/crates/deepseek/src/deepseek.rs b/crates/deepseek/src/deepseek.rs
new file mode 100644
index 0000000000000000000000000000000000000000..777cf696d8e047eef9685ff410ae396c10f4cff5
--- /dev/null
+++ b/crates/deepseek/src/deepseek.rs
@@ -0,0 +1,301 @@
+use anyhow::{anyhow, Result};
+use futures::{
+    io::BufReader,
+    stream::{BoxStream, StreamExt},
+    AsyncBufReadExt, AsyncReadExt,
+};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use std::convert::TryFrom;
+
+pub const DEEPSEEK_API_URL: &str = "https://api.deepseek.com";
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+    Tool,
+}
+
+impl TryFrom<String> for Role {
+    type Error = anyhow::Error;
+
+    fn try_from(value: String) -> Result<Self> {
+        match value.as_str() {
+            "user" => Ok(Self::User),
+            "assistant" => Ok(Self::Assistant),
+            "system" => Ok(Self::System),
+            "tool" => Ok(Self::Tool),
+            _ => Err(anyhow!("invalid role '{value}'")),
+        }
+    }
+}
+
+impl From<Role> for String {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => "user".to_owned(),
+            Role::Assistant => "assistant".to_owned(),
+            Role::System => "system".to_owned(),
+            Role::Tool => "tool".to_owned(),
+        }
+    }
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub enum Model {
+    #[serde(rename = "deepseek-chat")]
+    #[default]
+    Chat,
+    #[serde(rename = "deepseek-reasoner")]
+    Reasoner,
+    #[serde(rename = "custom")]
+    Custom {
+        name: String,
+        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
+        display_name: Option<String>,
+        max_tokens: usize,
+        max_output_tokens: Option<u32>,
+    },
+}
+
+impl Model {
+    pub fn from_id(id: &str) -> Result<Self> {
+        match id {
+            "deepseek-chat" => Ok(Self::Chat),
+            "deepseek-reasoner" => Ok(Self::Reasoner),
+            _ => Err(anyhow!("invalid model id")),
+        }
+    }
+
+    pub fn id(&self) -> &str {
+        match self {
+            Self::Chat => "deepseek-chat",
+            Self::Reasoner => "deepseek-reasoner",
+            Self::Custom { name, .. } => name,
+        }
+    }
+
+    pub fn display_name(&self) -> &str {
+        match self {
+            Self::Chat => "DeepSeek Chat",
+            Self::Reasoner => "DeepSeek Reasoner",
+            Self::Custom {
+                name, display_name, ..
+            } => display_name.as_ref().unwrap_or(name).as_str(),
+        }
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        match self {
+            Self::Chat | Self::Reasoner => 64_000,
+            Self::Custom { max_tokens, .. } => *max_tokens,
+        }
+    }
+
+    pub fn max_output_tokens(&self) -> Option<u32> {
+        match self {
+            Self::Chat => Some(8_192),
+            Self::Reasoner => Some(8_192),
+            Self::Custom {
+                max_output_tokens, ..
+            } => *max_output_tokens,
+        }
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Request {
+    pub model: String,
+    pub messages: Vec<RequestMessage>,
+    pub stream: bool,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub response_format: Option<ResponseFormat>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<ToolDefinition>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum ResponseFormat {
+    Text,
+    #[serde(rename = "json_object")]
+    JsonObject,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ToolDefinition {
+    Function { function: FunctionDefinition },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FunctionDefinition {
+    pub name: String,
+    pub description: Option<String>,
+    pub parameters: Option<Value>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum RequestMessage {
+    Assistant {
+        content: Option<String>,
+        #[serde(default, skip_serializing_if = "Vec::is_empty")]
+        tool_calls: Vec<ToolCall>,
+    },
+    User {
+        content: String,
+    },
+    System {
+        content: String,
+    },
+    Tool {
+        content: String,
+        tool_call_id: String,
+    },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCall {
+    pub id: String,
+    #[serde(flatten)]
+    pub content: ToolCallContent,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolCallContent {
+    Function { function: FunctionContent },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionContent {
+    pub name: String,
+    pub arguments: String,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Response {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<Choice>,
+    pub usage: Usage,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub reasoning_content: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+    #[serde(default)]
+    pub prompt_cache_hit_tokens: u32,
+    #[serde(default)]
+    pub prompt_cache_miss_tokens: u32,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Choice {
+    pub index: u32,
+    pub message: RequestMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct StreamResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<StreamChoice>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct StreamChoice {
+    pub index: u32,
+    pub delta: StreamDelta,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct StreamDelta {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCallChunk>>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub reasoning_content: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct ToolCallChunk {
+    pub index: usize,
+    pub id: Option<String>,
+    pub function: Option<FunctionChunk>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct FunctionChunk {
+    pub name: Option<String>,
+    pub arguments: Option<String>,
+}
+
+pub async fn stream_completion(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: Request,
+) -> Result<BoxStream<'static, Result<StreamResponse>>> {
+    let uri = format!("{api_url}/v1/chat/completions");
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key));
+
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let mut response = client.send(request).await?;
+
+    if response.status().is_success() {
+        let reader = BufReader::new(response.into_body());
+        Ok(reader
+            .lines()
+            .filter_map(|line| async move {
+                match line {
+                    Ok(line) => {
+                        let line = line.strip_prefix("data: ")?;
+                        if line == "[DONE]" {
+                            None
+                        } else {
+                            match serde_json::from_str(line) {
+                                Ok(response) => Some(Ok(response)),
+                                Err(error) => Some(Err(anyhow!(error))),
+                            }
+                        }
+                    }
+                    Err(error) => Some(Err(anyhow!(error))),
+                }
+            })
+            .boxed())
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+        Err(anyhow!(
+            "Failed to connect to DeepSeek API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index 0842e18752b36ceed5b9ab70bcbecd6dd623bbe0..74505b1780fd545fd3abde9d823bf2c278241212 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -29,6 +29,7 @@ log.workspace = true
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 lmstudio = { workspace = true, features = ["schemars"] }
+deepseek = { workspace = true, features = ["schemars"] }
 parking_lot.workspace = true
 proto.workspace = true
 schemars.workspace = true
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index 513f51f2868180fa08da5c03c845d2ea5cd72872..c4ad4ba264f41b9805d0ed2408bf91f7e832df93 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -410,6 +410,84 @@ impl LanguageModelRequest {
             top_p: None,
         }
     }
+
+    pub fn into_deepseek(self, model: String, max_output_tokens: Option<u32>) -> deepseek::Request {
+        let is_reasoner = model == "deepseek-reasoner";
+
+        let len = self.messages.len();
+        let merged_messages =
+            self.messages
+                .into_iter()
+                .fold(Vec::with_capacity(len), |mut acc, msg| {
+                    let role = msg.role;
+                    let content = msg.string_contents();
+
+                    if is_reasoner {
+                        if let Some(last_msg) = acc.last_mut() {
+                            match (last_msg, role) {
+                                (deepseek::RequestMessage::User { content: last }, Role::User) => {
+                                    last.push(' ');
+                                    last.push_str(&content);
+                                    return acc;
+                                }
+
+                                (
+                                    deepseek::RequestMessage::Assistant {
+                                        content: last_content,
+                                        ..
+                                    },
+                                    Role::Assistant,
+                                ) => {
+                                    *last_content = last_content
+                                        .take()
+                                        .map(|c| {
+                                            let mut s =
+                                                String::with_capacity(c.len() + content.len() + 1);
+                                            s.push_str(&c);
+                                            s.push(' ');
+                                            s.push_str(&content);
+                                            s
+                                        })
+                                        .or(Some(content));
+
+                                    return acc;
+                                }
+                                _ => {}
+                            }
+                        }
+                    }
+
+                    acc.push(match role {
+                        Role::User => deepseek::RequestMessage::User { content },
+                        Role::Assistant => deepseek::RequestMessage::Assistant {
+                            content: Some(content),
+                            tool_calls: Vec::new(),
+                        },
+                        Role::System => deepseek::RequestMessage::System { content },
+                    });
+                    acc
+                });
+
+        deepseek::Request {
+            model,
+            messages: merged_messages,
+            stream: true,
+            max_tokens: max_output_tokens,
+            temperature: if is_reasoner { None } else { self.temperature },
+            response_format: None,
+            tools: self
+                .tools
+                .into_iter()
+                .map(|tool| deepseek::ToolDefinition::Function {
+                    function: deepseek::FunctionDefinition {
+                        name: tool.name,
+                        description: Some(tool.description),
+                        parameters: Some(tool.input_schema),
+                    },
+                })
+                .collect(),
+        }
+    }
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
diff --git a/crates/language_model/src/role.rs b/crates/language_model/src/role.rs
index 17366fa3ec594fccf59ee768236d383cc61e0949..fa56a2a88ba71c33705663e29dbe9796b63c6962 100644
--- a/crates/language_model/src/role.rs
+++ b/crates/language_model/src/role.rs
@@ -66,6 +66,16 @@ impl From<Role> for open_ai::Role {
     }
 }
 
+impl From<Role> for deepseek::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => deepseek::Role::User,
+            Role::Assistant => deepseek::Role::Assistant,
+            Role::System => deepseek::Role::System,
+        }
+    }
+}
+
 impl From<Role> for lmstudio::Role {
     fn from(val: Role) -> Self {
         match val {
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index b66447124bdb4977e863dfb3d4c31f3ed44b1bcb..4d7590e40e031cdf6f45ad19d54c8a735ab73034 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -29,6 +29,7 @@ menu.workspace = true
 ollama = { workspace = true, features = ["schemars"] }
 lmstudio = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
+deepseek = { workspace = true, features = ["schemars"] }
 project.workspace = true
 proto.workspace = true
 schemars.workspace = true
diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs
index 604c1a7ce4e675e22b10e136dd3a8f97a0615a10..17addf6a5f29c48847c89bc4fac281e1c8c41a84 100644
--- a/crates/language_models/src/language_models.rs
+++ b/crates/language_models/src/language_models.rs
@@ -4,6 +4,7 @@ use client::{Client, UserStore};
 use fs::Fs;
 use gpui::{AppContext, Model, ModelContext};
 use language_model::{LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
+use provider::deepseek::DeepSeekLanguageModelProvider;
 
 mod logging;
 pub mod provider;
@@ -60,6 +61,10 @@ fn register_language_model_providers(
         LmStudioLanguageModelProvider::new(client.http_client(), cx),
         cx,
     );
+    registry.register_provider(
+        DeepSeekLanguageModelProvider::new(client.http_client(), cx),
+        cx,
+    );
     registry.register_provider(
         GoogleLanguageModelProvider::new(client.http_client(), cx),
         cx,
diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs
index 09fb975fc673b04ee6de95cc0d9806fb7bf7ef52..a7738563e709a8a49b57ad8cf656122895c157e9 100644
--- a/crates/language_models/src/provider.rs
+++ b/crates/language_models/src/provider.rs
@@ -1,6 +1,7 @@
 pub mod anthropic;
 pub mod cloud;
 pub mod copilot_chat;
+pub mod deepseek;
 pub mod google;
 pub mod lmstudio;
 pub mod ollama;
diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs
new file mode 100644
index 0000000000000000000000000000000000000000..8e2dbe940ce6d8e204894cdbf347b8437878c4d1
--- /dev/null
+++ b/crates/language_models/src/provider/deepseek.rs
@@ -0,0 +1,558 @@
+use anyhow::{anyhow, Result};
+use collections::BTreeMap;
+use editor::{Editor, EditorElement, EditorStyle};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::{
+    AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
+    View, WhiteSpace,
+};
+use http_client::HttpClient;
+use language_model::{
+    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::{Settings, SettingsStore};
+use std::sync::Arc;
+use theme::ThemeSettings;
+use ui::{prelude::*, Icon, IconName};
+use util::ResultExt;
+
+use crate::AllLanguageModelSettings;
+
+const PROVIDER_ID: &str = "deepseek";
+const PROVIDER_NAME: &str = "DeepSeek";
+const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct DeepSeekSettings {
+    pub api_url: String,
+    pub available_models: Vec<AvailableModel>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+    pub name: String,
+    pub display_name: Option<String>,
+    pub max_tokens: usize,
+    pub max_output_tokens: Option<u32>,
+}
+
+pub struct DeepSeekLanguageModelProvider {
+    http_client: Arc<dyn HttpClient>,
+    state: gpui::Model<State>,
+}
+
+pub struct State {
+    api_key: Option<String>,
+    api_key_from_env: bool,
+    _subscription: Subscription,
+}
+
+impl State {
+    fn is_authenticated(&self) -> bool {
+        self.api_key.is_some()
+    }
+
+    fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        let settings = &AllLanguageModelSettings::get_global(cx).deepseek;
+        let delete_credentials = cx.delete_credentials(&settings.api_url);
+        cx.spawn(|this, mut cx| async move {
+            delete_credentials.await.log_err();
+            this.update(&mut cx, |this, cx| {
+                this.api_key = None;
+                this.api_key_from_env = false;
+                cx.notify();
+            })
+        })
+    }
+
+    fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        let settings = &AllLanguageModelSettings::get_global(cx).deepseek;
+        let write_credentials =
+            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
+
+        cx.spawn(|this, mut cx| async move {
+            write_credentials.await?;
+            this.update(&mut cx, |this, cx| {
+                this.api_key = Some(api_key);
+                cx.notify();
+            })
+        })
+    }
+
+    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        if self.is_authenticated() {
+            Task::ready(Ok(()))
+        } else {
+            let api_url = AllLanguageModelSettings::get_global(cx)
+                .deepseek
+                .api_url
+                .clone();
+
+            cx.spawn(|this, mut cx| async move {
+                let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) {
+                    (api_key, true)
+                } else {
+                    let (_, api_key) = cx
+                        .update(|cx| cx.read_credentials(&api_url))?
+                        .await?
+                        .ok_or_else(|| anyhow!("credentials not found"))?;
+                    (String::from_utf8(api_key)?, false)
+                };
+
+                this.update(&mut cx, |this, cx| {
+                    this.api_key = Some(api_key);
+                    this.api_key_from_env = from_env;
+                    cx.notify();
+                })
+            })
+        }
+    }
+}
+
+impl DeepSeekLanguageModelProvider {
+    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
+        let state = cx.new_model(|cx| State {
+            api_key: None,
+            api_key_from_env: false,
+            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
+                cx.notify();
+            }),
+        });
+
+        Self { http_client, state }
+    }
+}
+
+impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
+    }
+}
+
+impl LanguageModelProvider for DeepSeekLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
+    fn name(&self) -> LanguageModelProviderName {
+        LanguageModelProviderName(PROVIDER_NAME.into())
+    }
+
+    fn icon(&self) -> IconName {
+        IconName::AiDeepSeek
+    }
+
+    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
+        let mut models = BTreeMap::default();
+
+        models.insert("deepseek-chat", deepseek::Model::Chat);
+        models.insert("deepseek-reasoner", deepseek::Model::Reasoner);
+
+        for available_model in AllLanguageModelSettings::get_global(cx)
+            .deepseek
+            .available_models
+            .iter()
+        {
+            models.insert(
+                &available_model.name,
+                deepseek::Model::Custom {
+                    name: available_model.name.clone(),
+                    display_name: available_model.display_name.clone(),
+                    max_tokens: available_model.max_tokens,
+                    max_output_tokens: available_model.max_output_tokens,
+                },
+            );
+        }
+
+        models
+            .into_values()
+            .map(|model| {
+                Arc::new(DeepSeekLanguageModel {
+                    id: LanguageModelId::from(model.id().to_string()),
+                    model,
+                    state: self.state.clone(),
+                    http_client: self.http_client.clone(),
+                    request_limiter: RateLimiter::new(4),
+                }) as Arc<dyn LanguageModel>
+            })
+            .collect()
+    }
+
+    fn is_authenticated(&self, cx: &AppContext) -> bool {
+        self.state.read(cx).is_authenticated()
+    }
+
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
+        self.state.update(cx, |state, cx| state.authenticate(cx))
+    }
+
+    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
+        cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx))
+            .into()
+    }
+
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
+        self.state.update(cx, |state, cx| state.reset_api_key(cx))
+    }
+}
+
+pub struct DeepSeekLanguageModel {
+    id: LanguageModelId,
+    model: deepseek::Model,
+    state: gpui::Model<State>,
+    http_client: Arc<dyn HttpClient>,
+    request_limiter: RateLimiter,
+}
+
+impl DeepSeekLanguageModel {
+    fn stream_completion(
+        &self,
+        request: deepseek::Request,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<deepseek::StreamResponse>>>> {
+        let http_client = self.http_client.clone();
+        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).deepseek;
+            (state.api_key.clone(), settings.api_url.clone())
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        let future = self.request_limiter.stream(async move {
+            let api_key = api_key.ok_or_else(|| anyhow!("Missing DeepSeek API Key"))?;
+            let request =
+                deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let response = request.await?;
+            Ok(response)
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
+    }
+}
+
+impl LanguageModel for DeepSeekLanguageModel {
+    fn id(&self) -> LanguageModelId {
+        self.id.clone()
+    }
+
+    fn name(&self) -> LanguageModelName {
+        LanguageModelName::from(self.model.display_name().to_string())
+    }
+
+    fn provider_id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
+    fn provider_name(&self) -> LanguageModelProviderName {
+        LanguageModelProviderName(PROVIDER_NAME.into())
+    }
+
+    fn telemetry_id(&self) -> String {
+        format!("deepseek/{}", self.model.id())
+    }
+
+    fn max_token_count(&self) -> usize {
+        self.model.max_token_count()
+    }
+
+    fn max_output_tokens(&self) -> Option<u32> {
+        self.model.max_output_tokens()
+    }
+
+    fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        cx.background_executor()
+            .spawn(async move {
+                let messages = request
+                    .messages
+                    .into_iter()
+                    .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+                        role: match message.role {
+                            Role::User => "user".into(),
+                            Role::Assistant => "assistant".into(),
+                            Role::System => "system".into(),
+                        },
+                        content: Some(message.string_contents()),
+                        name: None,
+                        function_call: None,
+                    })
+                    .collect::<Vec<_>>();
+
+                tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
+            })
+            .boxed()
+    }
+
+    fn stream_completion(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+        let request = request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
+        let stream = self.stream_completion(request, cx);
+
+        async move {
+            let stream = stream.await?;
+            Ok(stream
+                .map(|result| {
+                    result.and_then(|response| {
+                        response
+                            .choices
+                            .first()
+                            .ok_or_else(|| anyhow!("Empty response"))
+                            .map(|choice| {
+                                LanguageModelCompletionEvent::Text(
+                                    choice.delta.content.clone().unwrap_or_default(),
+                                )
+                            })
+                    })
+                })
+                .boxed())
+        }
+        .boxed()
+    }
+
+    fn use_any_tool(
+        &self,
+        request: LanguageModelRequest,
+        name: String,
+        description: String,
+        schema: serde_json::Value,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let mut deepseek_request =
+            request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
+
+        deepseek_request.tools = vec![deepseek::ToolDefinition::Function {
+            function: deepseek::FunctionDefinition {
+                name: name.clone(),
+                description: Some(description),
+                parameters: Some(schema),
+            },
+        }];
+
+        let response_stream = self.stream_completion(deepseek_request, cx);
+
+        self.request_limiter
+            .run(async move {
+                let stream = response_stream.await?;
+
+                let tool_args_stream = stream
+                    .filter_map(move |response| async move {
+                        match response {
+                            Ok(response) => {
+                                for choice in response.choices {
+                                    if let Some(tool_calls) = choice.delta.tool_calls {
+                                        for tool_call in tool_calls {
+                                            if let Some(function) = tool_call.function {
+                                                if let Some(args) = function.arguments {
+                                                    return Some(Ok(args));
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                                None
+                            }
+                            Err(e) => Some(Err(e)),
+                        }
+                    })
+                    .boxed();
+
+                Ok(tool_args_stream)
+            })
+            .boxed()
+    }
+}
+
+struct ConfigurationView {
+    api_key_editor: View<Editor>,
+    state: gpui::Model<State>,
+    load_credentials_task: Option<Task<()>>,
+}
+
+impl ConfigurationView {
+    fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
+        let api_key_editor = cx.new_view(|cx| {
+            let mut editor = Editor::single_line(cx);
+            editor.set_placeholder_text("sk-00000000000000000000000000000000", cx);
+            editor
+        });
+
+        cx.observe(&state, |_, _, cx| {
+            cx.notify();
+        })
+        .detach();
+
+        let load_credentials_task = Some(cx.spawn({
+            let state = state.clone();
+            |this, mut cx| async move {
+                if let Some(task) = state
+                    .update(&mut cx, |state, cx| state.authenticate(cx))
+                    .log_err()
+                {
+                    let _ = task.await;
+                }
+
+                this.update(&mut cx, |this, cx| {
+                    this.load_credentials_task = None;
+                    cx.notify();
+                })
+                .log_err();
+            }
+        }));
+
+        Self {
+            api_key_editor,
+            state,
+            load_credentials_task,
+        }
+    }
+
+    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+        let api_key = self.api_key_editor.read(cx).text(cx);
+        if api_key.is_empty() {
+            return;
+        }
+
+        let state = self.state.clone();
+        cx.spawn(|_, mut cx| async move {
+            state
+                .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
+                .await
+        })
+        .detach_and_log_err(cx);
+
+        cx.notify();
+    }
+
+    fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
+        self.api_key_editor
+            .update(cx, |editor, cx| editor.set_text("", cx));
+
+        let state = self.state.clone();
+        cx.spawn(|_, mut cx| async move {
+            state
+                .update(&mut cx, |state, cx| state.reset_api_key(cx))?
+                .await
+        })
+        .detach_and_log_err(cx);
+
+        cx.notify();
+    }
+
+    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        let settings = ThemeSettings::get_global(cx);
+        let text_style = TextStyle {
+            color: cx.theme().colors().text,
+            font_family: settings.ui_font.family.clone(),
+            font_features: settings.ui_font.features.clone(),
+            font_fallbacks: settings.ui_font.fallbacks.clone(),
+            font_size: rems(0.875).into(),
+            font_weight: settings.ui_font.weight,
+            font_style: FontStyle::Normal,
+            line_height: relative(1.3),
+            background_color: None,
+            underline: None,
+            strikethrough: None,
+            white_space: WhiteSpace::Normal,
+            truncate: None,
+        };
+        EditorElement::new(
+            &self.api_key_editor,
+            EditorStyle {
+                background: cx.theme().colors().editor_background,
+                local_player: cx.theme().players().local(),
+                text: text_style,
+                ..Default::default()
+            },
+        )
+    }
+
+    fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
+        !self.state.read(cx).is_authenticated()
+    }
+}
+
+impl Render for ConfigurationView {
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        const DEEPSEEK_CONSOLE_URL: &str = "https://platform.deepseek.com/api_keys";
+        const INSTRUCTIONS: [&str; 3] = [
+            "To use DeepSeek in Zed, you need an API key:",
+            "- Get your API key from:",
+            "- Paste it below and press enter:",
+        ];
+
+        let env_var_set = self.state.read(cx).api_key_from_env;
+
+        if self.load_credentials_task.is_some() {
+            div().child(Label::new("Loading credentials...")).into_any()
+        } else if self.should_render_editor(cx) {
+            v_flex()
+                .size_full()
+                .on_action(cx.listener(Self::save_api_key))
+                .child(Label::new(INSTRUCTIONS[0]))
+                .child(
+                    h_flex().child(Label::new(INSTRUCTIONS[1])).child(
+                        Button::new("deepseek_console", DEEPSEEK_CONSOLE_URL)
+                            .style(ButtonStyle::Subtle)
+                            .icon(IconName::ExternalLink)
+                            .icon_size(IconSize::XSmall)
+                            .icon_color(Color::Muted)
+                            .on_click(move |_, cx| cx.open_url(DEEPSEEK_CONSOLE_URL)),
+                    ),
+                )
+                .child(Label::new(INSTRUCTIONS[2]))
+                .child(
+                    h_flex()
+                        .w_full()
+                        .my_2()
+                        .px_2()
+                        .py_1()
+                        .bg(cx.theme().colors().editor_background)
+                        .rounded_md()
+                        .child(self.render_api_key_editor(cx)),
+                )
+                .child(
+                    Label::new(format!(
+                        "Or set {} environment variable",
+                        DEEPSEEK_API_KEY_VAR
+                    ))
+                    .size(LabelSize::Small),
+                )
+                .into_any()
+        } else {
+            h_flex()
+                .size_full()
+                .justify_between()
+                .child(
+                    h_flex()
+                        .gap_1()
+                        .child(Icon::new(IconName::Check).color(Color::Success))
+                        .child(Label::new(if env_var_set {
+                            format!("API key set in {}", DEEPSEEK_API_KEY_VAR)
+                        } else {
+                            "API key configured".to_string()
+                        })),
+                )
+                .child(
+                    Button::new("reset-key", "Reset")
+                        .icon(IconName::Trash)
+                        .disabled(env_var_set)
+                        .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
+                )
+                .into_any()
+        }
+    }
+}
diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs
index 2c44a3983bdc3d47848db7cfa5d6d62b35b818c1..55b1f6beb3fcb99c9ba13bd9923e2f102ba922d8 100644
--- a/crates/language_models/src/settings.rs
+++ b/crates/language_models/src/settings.rs
@@ -13,6 +13,7 @@ use crate::provider::{
     anthropic::AnthropicSettings,
     cloud::{self, ZedDotDevSettings},
     copilot_chat::CopilotChatSettings,
+    deepseek::DeepSeekSettings,
     google::GoogleSettings,
     lmstudio::LmStudioSettings,
     ollama::OllamaSettings,
@@ -61,6 +62,7 @@ pub struct AllLanguageModelSettings {
     pub google: GoogleSettings,
     pub copilot_chat: CopilotChatSettings,
     pub lmstudio: LmStudioSettings,
+    pub deepseek: DeepSeekSettings,
 }
 
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@@ -72,6 +74,7 @@ pub struct AllLanguageModelSettingsContent {
     #[serde(rename = "zed.dev")]
     pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
     pub google: Option<GoogleSettingsContent>,
+    pub deepseek: Option<DeepseekSettingsContent>,
     pub copilot_chat: Option<CopilotChatSettingsContent>,
 }
 
@@ -162,6 +165,12 @@ pub struct LmStudioSettingsContent {
     pub available_models: Option<Vec<provider::lmstudio::AvailableModel>>,
 }
 
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct DeepseekSettingsContent {
+    pub api_url: Option<String>,
+    pub available_models: Option<Vec<provider::deepseek::AvailableModel>>,
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 #[serde(untagged)]
 pub enum OpenAiSettingsContent {
@@ -299,6 +308,18 @@ impl settings::Settings for AllLanguageModelSettings {
             lmstudio.as_ref().and_then(|s| s.available_models.clone()),
         );
 
+        // DeepSeek
+        let deepseek = value.deepseek.clone();
+
+        merge(
+            &mut settings.deepseek.api_url,
+            value.deepseek.as_ref().and_then(|s| s.api_url.clone()),
+        );
+        merge(
+            &mut settings.deepseek.available_models,
+            deepseek.as_ref().and_then(|s| s.available_models.clone()),
+        );
+
         // OpenAI
         let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {
             Some((content, upgraded)) => (Some(content), upgraded),
diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs
index 9ab196d7c19660ba91f1263800a5024f3035e162..1c14366c5f0d37ca9e645e7efd51d37e699b9003 100644
--- a/crates/ui/src/components/icon.rs
+++ b/crates/ui/src/components/icon.rs
@@ -128,6 +128,7 @@ pub enum IconName {
     Ai,
     AiAnthropic,
     AiAnthropicHosted,
+    AiDeepSeek,
     AiGoogle,
     AiLmStudio,
     AiOllama,
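
---

Usage note (not part of the patch): given the `DeepseekSettingsContent` and `AvailableModel` types added above, a user's `settings.json` entry for this provider would plausibly look like the sketch below. The `language_models` top-level key mirrors the other providers patched in `assets/settings/default.json`; the field names come directly from `AvailableModel` (`name`, `display_name`, `max_tokens`, `max_output_tokens`), while the concrete values shown are illustrative assumptions only.

    {
      "language_models": {
        "deepseek": {
          // Matches the default added to assets/settings/default.json.
          "api_url": "https://api.deepseek.com",
          "available_models": [
            {
              // Hypothetical custom entry; deepseek-chat and deepseek-reasoner
              // are already built in via deepseek::Model.
              "name": "deepseek-chat",
              "display_name": "DeepSeek Chat (custom)",
              "max_tokens": 64000,
              "max_output_tokens": 8192
            }
          ]
        }
      }
    }

The API key itself is not stored in settings; it is read from the `DEEPSEEK_API_KEY` environment variable or the credential store, as implemented in `State::authenticate`.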