deepseek.rs

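//! DeepSeek language-model provider for Zed: registers the available DeepSeek
//! models, handles API-key configuration, and converts between Zed's
//! `LanguageModelRequest`/completion-event types and the DeepSeek API.
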
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
use deepseek::DEEPSEEK_API_URL;
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    RateLimiter, Role, StopReason, TokenUsage,
};
pub use settings::DeepseekAvailableModel as AvailableModel;
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::{Arc, LazyLock};

use ui::{List, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::{EnvVar, env_var};

use crate::ui::ConfiguredApiCard;
use crate::{api_key::ApiKeyState, ui::InstructionListItem};

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");

const API_KEY_ENV_VAR_NAME: &str = "DEEPSEEK_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);

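/// Accumulates the streamed fragments of a single tool call (id, name, and
/// JSON arguments) until the stream reports a `tool_calls` finish reason.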
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

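/// Provider settings: an optional custom API URL (an empty string falls back
/// to `DEEPSEEK_API_URL`) plus any extra models declared in the user's settings.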
#[derive(Default, Clone, Debug, PartialEq)]
pub struct DeepSeekSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

pub struct DeepSeekLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: Entity<State>,
}

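/// Authentication state shared by the provider, its models, and the
/// configuration view; wraps the stored or environment-provided API key.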
pub struct State {
    api_key_state: ApiKeyState,
}

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let api_url = DeepSeekLanguageModelProvider::api_url(cx);
        self.api_key_state
            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let api_url = DeepSeekLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            &API_KEY_ENV_VAR,
            |this| &mut this.api_key_state,
            cx,
        )
    }
}

impl DeepSeekLanguageModelProvider {
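    /// Builds the provider along with a `State` entity that re-resolves the
    /// API key whenever the settings (and thus the configured API URL) change.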
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let api_url = Self::api_url(cx);
                this.api_key_state.handle_url_change(
                    api_url,
                    &API_KEY_ENV_VAR,
                    |this| &mut this.api_key_state,
                    cx,
                );
                cx.notify();
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx)),
            }
        });

        Self { http_client, state }
    }

    fn create_language_model(&self, model: deepseek::Model) -> Arc<dyn LanguageModel> {
        Arc::new(DeepSeekLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    fn settings(cx: &App) -> &DeepSeekSettings {
        &crate::AllLanguageModelSettings::get_global(cx).deepseek
    }

    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            DEEPSEEK_API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }
}

impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for DeepSeekLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiDeepSeek
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(deepseek::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(deepseek::Model::default_fast()))
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        models.insert("deepseek-chat", deepseek::Model::Chat);
        models.insert("deepseek-reasoner", deepseek::Model::Reasoner);

        for available_model in &Self::settings(cx).available_models {
            models.insert(
                &available_model.name,
                deepseek::Model::Custom {
                    name: available_model.name.clone(),
                    display_name: available_model.display_name.clone(),
                    max_tokens: available_model.max_tokens,
                    max_output_tokens: available_model.max_output_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| self.create_language_model(model))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state
            .update(cx, |state, cx| state.set_api_key(None, cx))
    }
}

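/// A single DeepSeek model exposed to Zed, bundling the model metadata with
/// the shared key state, the HTTP client, and a request rate limiter.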
pub struct DeepSeekLanguageModel {
    id: LanguageModelId,
    model: deepseek::Model,
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl DeepSeekLanguageModel {
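    /// Streams raw completion chunks from the DeepSeek API: resolves the API
    /// key for the configured URL (failing with `NoApiKey` if none is set) and
    /// funnels the request through the rate limiter.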
    fn stream_completion(
        &self,
        request: deepseek::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<deepseek::StreamResponse>>>> {
        let http_client = self.http_client.clone();

        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
            let api_url = DeepSeekLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        }) else {
            return future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                });
            };
            let request =
                deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for DeepSeekLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        true
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("deepseek/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

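    // Token counts are approximated with tiktoken's gpt-4 tokenizer; DeepSeek's
    // own tokenizer may count slightly differently.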
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        cx.background_spawn(async move {
            let messages = request
                .messages
                .into_iter()
                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                    role: match message.role {
                        Role::User => "user".into(),
                        Role::Assistant => "assistant".into(),
                        Role::System => "system".into(),
                    },
                    content: Some(message.string_contents()),
                    name: None,
                    function_call: None,
                })
                .collect::<Vec<_>>();

            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
        })
        .boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_deepseek(request, &self.model, self.max_output_tokens());
        let stream = self.stream_completion(request, cx);

        async move {
            let mapper = DeepSeekEventMapper::new();
            Ok(mapper.map_stream(stream.await?).boxed())
        }
        .boxed()
    }
}

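/// Converts a Zed `LanguageModelRequest` into a DeepSeek request: text becomes
/// role-specific messages, accumulated thinking text is attached to the next
/// assistant message as `reasoning_content`, tool uses are appended to the
/// preceding assistant message (or open a new one), and tool results become
/// `Tool` messages. Images are dropped, and `temperature` is omitted for the
/// reasoner model.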
pub fn into_deepseek(
    request: LanguageModelRequest,
    model: &deepseek::Model,
    max_output_tokens: Option<u64>,
) -> deepseek::Request {
    let is_reasoner = model == &deepseek::Model::Reasoner;

    let mut messages = Vec::new();
    let mut current_reasoning: Option<String> = None;

    for message in request.messages {
        for content in message.content {
            match content {
                MessageContent::Text(text) => messages.push(match message.role {
                    Role::User => deepseek::RequestMessage::User { content: text },
                    Role::Assistant => deepseek::RequestMessage::Assistant {
                        content: Some(text),
                        tool_calls: Vec::new(),
                        reasoning_content: current_reasoning.take(),
                    },
                    Role::System => deepseek::RequestMessage::System { content: text },
                }),
                MessageContent::Thinking { text, .. } => {
                    // Accumulate reasoning content for next assistant message
                    current_reasoning.get_or_insert_default().push_str(&text);
                }
                MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(_) => {}
                MessageContent::ToolUse(tool_use) => {
                    let tool_call = deepseek::ToolCall {
                        id: tool_use.id.to_string(),
                        content: deepseek::ToolCallContent::Function {
                            function: deepseek::FunctionContent {
                                name: tool_use.name.to_string(),
                                arguments: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                            },
                        },
                    };

                    if let Some(deepseek::RequestMessage::Assistant { tool_calls, .. }) =
                        messages.last_mut()
                    {
                        tool_calls.push(tool_call);
                    } else {
                        messages.push(deepseek::RequestMessage::Assistant {
                            content: None,
                            tool_calls: vec![tool_call],
                            reasoning_content: current_reasoning.take(),
                        });
                    }
                }
                MessageContent::ToolResult(tool_result) => {
                    match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            messages.push(deepseek::RequestMessage::Tool {
                                content: text.to_string(),
                                tool_call_id: tool_result.tool_use_id.to_string(),
                            });
                        }
                        LanguageModelToolResultContent::Image(_) => {}
                    };
                }
            }
        }
    }

    deepseek::Request {
        model: model.id().to_string(),
        messages,
        stream: true,
        max_tokens: max_output_tokens,
        temperature: if is_reasoner {
            None
        } else {
            request.temperature
        },
        response_format: None,
        tools: request
            .tools
            .into_iter()
            .map(|tool| deepseek::ToolDefinition::Function {
                function: deepseek::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
    }
}

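/// Maps DeepSeek stream responses to `LanguageModelCompletionEvent`s, buffering
/// partial tool-call fragments by index until the final `tool_calls` finish
/// reason arrives.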
pub struct DeepSeekEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

impl DeepSeekEventMapper {
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<deepseek::StreamResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
            })
        })
    }

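    /// Converts one stream chunk into zero or more completion events: text and
    /// thinking deltas, usage updates, buffered tool-call fragments, and a stop
    /// event derived from `finish_reason`.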
    pub fn map_event(
        &mut self,
        event: deepseek::StreamResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let Some(choice) = event.choices.first() else {
            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                "Response contained no choices"
            )))];
        };

        let mut events = Vec::new();
        if let Some(content) = choice.delta.content.clone() {
            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
        }

        if let Some(reasoning_content) = choice.delta.reasoning_content.clone() {
            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                text: reasoning_content,
                signature: None,
            }));
        }

        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id.clone() {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function.as_ref() {
                    if let Some(name) = function.name.clone() {
                        entry.name = name;
                    }

                    if let Some(arguments) = function.arguments.clone() {
                        entry.arguments.push_str(&arguments);
                    }
                }
            }
        }

        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match serde_json::Value::from_str(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.clone().into(),
                                name: tool_call.name.as_str().into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments.clone(),
                                thought_signature: None,
                            },
                        )),
                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                            id: tool_call.id.clone().into(),
                            tool_name: tool_call.name.as_str().into(),
                            raw_input: tool_call.arguments.into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            Some(stop_reason) => {
                log::error!("Unexpected DeepSeek stop_reason: {stop_reason:?}");
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}

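/// Settings UI for the DeepSeek provider: lets the user paste an API key (or
/// shows that one is already configured, e.g. via `DEEPSEEK_API_KEY`) and
/// resets stored credentials.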
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor =
            cx.new(|cx| InputField::new(window, cx, "sk-00000000000000000000000000000000"));

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    let _ = task.await;
                }

                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        let state = self.state.clone();
        cx.spawn(async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn(async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

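// Renders one of three states: a loading placeholder while credentials are
// checked, the API-key entry form, or the configured-key card.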
impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            let api_url = DeepSeekLanguageModelProvider::api_url(cx);
            if api_url == DEEPSEEK_API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            div()
                .child(Label::new("Loading credentials..."))
                .into_any_element()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use DeepSeek in Zed, you need an API key:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Get your API key from the",
                            Some("DeepSeek console"),
                            Some("https://platform.deepseek.com/api_keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "Or set the {API_KEY_ENV_VAR_NAME} environment variable."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any_element()
        } else {
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .into_any_element()
        }
    }
}