deepseek.rs

  1use anyhow::{Result, anyhow};
  2use collections::{BTreeMap, HashMap};
  3use deepseek::DEEPSEEK_API_URL;
  4
  5use futures::Stream;
  6use futures::{FutureExt, StreamExt, future, future::BoxFuture, stream::BoxStream};
  7use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
  8use http_client::HttpClient;
  9use language_model::{
 10    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
 11    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 12    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 13    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
 14    RateLimiter, Role, StopReason, TokenUsage,
 15};
 16pub use settings::DeepseekAvailableModel as AvailableModel;
 17use settings::{Settings, SettingsStore};
 18use std::pin::Pin;
 19use std::str::FromStr;
 20use std::sync::{Arc, LazyLock};
 21
 22use ui::{List, prelude::*};
 23use ui_input::InputField;
 24use util::ResultExt;
 25use zed_env_vars::{EnvVar, env_var};
 26
 27use crate::ui::ConfiguredApiCard;
 28use crate::{api_key::ApiKeyState, ui::InstructionListItem};
 29
 30const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
 31const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
 32
 33const API_KEY_ENV_VAR_NAME: &str = "DEEPSEEK_API_KEY";
 34static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 35
 36#[derive(Default)]
 37struct RawToolCall {
 38    id: String,
 39    name: String,
 40    arguments: String,
 41}
 42
 43#[derive(Default, Clone, Debug, PartialEq)]
 44pub struct DeepSeekSettings {
 45    pub api_url: String,
 46    pub available_models: Vec<AvailableModel>,
 47}
 48pub struct DeepSeekLanguageModelProvider {
 49    http_client: Arc<dyn HttpClient>,
 50    state: Entity<State>,
 51}
 52
 53pub struct State {
 54    api_key_state: ApiKeyState,
 55}
 56
 57impl State {
 58    fn is_authenticated(&self) -> bool {
 59        self.api_key_state.has_key()
 60    }
 61
 62    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
 63        let api_url = DeepSeekLanguageModelProvider::api_url(cx);
 64        self.api_key_state
 65            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
 66    }
 67
 68    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 69        let api_url = DeepSeekLanguageModelProvider::api_url(cx);
 70        self.api_key_state.load_if_needed(
 71            api_url,
 72            &API_KEY_ENV_VAR,
 73            |this| &mut this.api_key_state,
 74            cx,
 75        )
 76    }
 77}
 78
 79impl DeepSeekLanguageModelProvider {
 80    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
 81        let state = cx.new(|cx| {
 82            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 83                let api_url = Self::api_url(cx);
 84                this.api_key_state.handle_url_change(
 85                    api_url,
 86                    &API_KEY_ENV_VAR,
 87                    |this| &mut this.api_key_state,
 88                    cx,
 89                );
 90                cx.notify();
 91            })
 92            .detach();
 93            State {
 94                api_key_state: ApiKeyState::new(Self::api_url(cx)),
 95            }
 96        });
 97
 98        Self { http_client, state }
 99    }
100
101    fn create_language_model(&self, model: deepseek::Model) -> Arc<dyn LanguageModel> {
102        Arc::new(DeepSeekLanguageModel {
103            id: LanguageModelId::from(model.id().to_string()),
104            model,
105            state: self.state.clone(),
106            http_client: self.http_client.clone(),
107            request_limiter: RateLimiter::new(4),
108        })
109    }
110
111    fn settings(cx: &App) -> &DeepSeekSettings {
112        &crate::AllLanguageModelSettings::get_global(cx).deepseek
113    }
114
115    fn api_url(cx: &App) -> SharedString {
116        let api_url = &Self::settings(cx).api_url;
117        if api_url.is_empty() {
118            DEEPSEEK_API_URL.into()
119        } else {
120            SharedString::new(api_url.as_str())
121        }
122    }
123}
124
125impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
126    type ObservableEntity = State;
127
128    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
129        Some(self.state.clone())
130    }
131}
132
133impl LanguageModelProvider for DeepSeekLanguageModelProvider {
134    fn id(&self) -> LanguageModelProviderId {
135        PROVIDER_ID
136    }
137
138    fn name(&self) -> LanguageModelProviderName {
139        PROVIDER_NAME
140    }
141
142    fn icon(&self) -> IconName {
143        IconName::AiDeepSeek
144    }
145
146    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
147        Some(self.create_language_model(deepseek::Model::default()))
148    }
149
150    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
151        Some(self.create_language_model(deepseek::Model::default_fast()))
152    }
153
154    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
155        let mut models = BTreeMap::default();
156
157        models.insert("deepseek-chat", deepseek::Model::Chat);
158        models.insert("deepseek-reasoner", deepseek::Model::Reasoner);
159
160        for available_model in &Self::settings(cx).available_models {
161            models.insert(
162                &available_model.name,
163                deepseek::Model::Custom {
164                    name: available_model.name.clone(),
165                    display_name: available_model.display_name.clone(),
166                    max_tokens: available_model.max_tokens,
167                    max_output_tokens: available_model.max_output_tokens,
168                },
169            );
170        }
171
172        models
173            .into_values()
174            .map(|model| self.create_language_model(model))
175            .collect()
176    }
177
178    fn is_authenticated(&self, cx: &App) -> bool {
179        self.state.read(cx).is_authenticated()
180    }
181
182    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
183        self.state.update(cx, |state, cx| state.authenticate(cx))
184    }
185
186    fn configuration_view(
187        &self,
188        _target_agent: language_model::ConfigurationViewTargetAgent,
189        window: &mut Window,
190        cx: &mut App,
191    ) -> AnyView {
192        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
193            .into()
194    }
195
196    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
197        self.state
198            .update(cx, |state, cx| state.set_api_key(None, cx))
199    }
200}
201
202pub struct DeepSeekLanguageModel {
203    id: LanguageModelId,
204    model: deepseek::Model,
205    state: Entity<State>,
206    http_client: Arc<dyn HttpClient>,
207    request_limiter: RateLimiter,
208}
209
210impl DeepSeekLanguageModel {
211    fn stream_completion(
212        &self,
213        request: deepseek::Request,
214        cx: &AsyncApp,
215    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<deepseek::StreamResponse>>>> {
216        let http_client = self.http_client.clone();
217
218        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
219            let api_url = DeepSeekLanguageModelProvider::api_url(cx);
220            (state.api_key_state.key(&api_url), api_url)
221        }) else {
222            return future::ready(Err(anyhow!("App state dropped"))).boxed();
223        };
224
225        let future = self.request_limiter.stream(async move {
226            let Some(api_key) = api_key else {
227                return Err(LanguageModelCompletionError::NoApiKey {
228                    provider: PROVIDER_NAME,
229                });
230            };
231            let request =
232                deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
233            let response = request.await?;
234            Ok(response)
235        });
236
237        async move { Ok(future.await?.boxed()) }.boxed()
238    }
239}
240
241impl LanguageModel for DeepSeekLanguageModel {
242    fn id(&self) -> LanguageModelId {
243        self.id.clone()
244    }
245
246    fn name(&self) -> LanguageModelName {
247        LanguageModelName::from(self.model.display_name().to_string())
248    }
249
250    fn provider_id(&self) -> LanguageModelProviderId {
251        PROVIDER_ID
252    }
253
254    fn provider_name(&self) -> LanguageModelProviderName {
255        PROVIDER_NAME
256    }
257
258    fn supports_tools(&self) -> bool {
259        true
260    }
261
262    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
263        true
264    }
265
266    fn supports_images(&self) -> bool {
267        false
268    }
269
270    fn telemetry_id(&self) -> String {
271        format!("deepseek/{}", self.model.id())
272    }
273
274    fn max_token_count(&self) -> u64 {
275        self.model.max_token_count()
276    }
277
278    fn max_output_tokens(&self) -> Option<u64> {
279        self.model.max_output_tokens()
280    }
281
282    fn count_tokens(
283        &self,
284        request: LanguageModelRequest,
285        cx: &App,
286    ) -> BoxFuture<'static, Result<u64>> {
287        cx.background_spawn(async move {
288            let messages = request
289                .messages
290                .into_iter()
291                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
292                    role: match message.role {
293                        Role::User => "user".into(),
294                        Role::Assistant => "assistant".into(),
295                        Role::System => "system".into(),
296                    },
297                    content: Some(message.string_contents()),
298                    name: None,
299                    function_call: None,
300                })
301                .collect::<Vec<_>>();
302
303            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
304        })
305        .boxed()
306    }
307
308    fn stream_completion(
309        &self,
310        request: LanguageModelRequest,
311        cx: &AsyncApp,
312    ) -> BoxFuture<
313        'static,
314        Result<
315            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
316            LanguageModelCompletionError,
317        >,
318    > {
319        let request = into_deepseek(request, &self.model, self.max_output_tokens());
320        let stream = self.stream_completion(request, cx);
321
322        async move {
323            let mapper = DeepSeekEventMapper::new();
324            Ok(mapper.map_stream(stream.await?).boxed())
325        }
326        .boxed()
327    }
328}
329
330pub fn into_deepseek(
331    request: LanguageModelRequest,
332    model: &deepseek::Model,
333    max_output_tokens: Option<u64>,
334) -> deepseek::Request {
335    let is_reasoner = *model == deepseek::Model::Reasoner;
336
337    let mut messages = Vec::new();
338    for message in request.messages {
339        for content in message.content {
340            match content {
341                MessageContent::Text(text) => messages.push(match message.role {
342                    Role::User => deepseek::RequestMessage::User { content: text },
343                    Role::Assistant => deepseek::RequestMessage::Assistant {
344                        content: Some(text),
345                        tool_calls: Vec::new(),
346                    },
347                    Role::System => deepseek::RequestMessage::System { content: text },
348                }),
349                MessageContent::Thinking { .. } => {}
350                MessageContent::RedactedThinking(_) => {}
351                MessageContent::Image(_) => {}
352                MessageContent::ToolUse(tool_use) => {
353                    let tool_call = deepseek::ToolCall {
354                        id: tool_use.id.to_string(),
355                        content: deepseek::ToolCallContent::Function {
356                            function: deepseek::FunctionContent {
357                                name: tool_use.name.to_string(),
358                                arguments: serde_json::to_string(&tool_use.input)
359                                    .unwrap_or_default(),
360                            },
361                        },
362                    };
363
364                    if let Some(deepseek::RequestMessage::Assistant { tool_calls, .. }) =
365                        messages.last_mut()
366                    {
367                        tool_calls.push(tool_call);
368                    } else {
369                        messages.push(deepseek::RequestMessage::Assistant {
370                            content: None,
371                            tool_calls: vec![tool_call],
372                        });
373                    }
374                }
375                MessageContent::ToolResult(tool_result) => {
376                    match &tool_result.content {
377                        LanguageModelToolResultContent::Text(text) => {
378                            messages.push(deepseek::RequestMessage::Tool {
379                                content: text.to_string(),
380                                tool_call_id: tool_result.tool_use_id.to_string(),
381                            });
382                        }
383                        LanguageModelToolResultContent::Image(_) => {}
384                    };
385                }
386            }
387        }
388    }
389
390    deepseek::Request {
391        model: model.id().to_string(),
392        messages,
393        stream: true,
394        max_tokens: max_output_tokens,
395        temperature: if is_reasoner {
396            None
397        } else {
398            request.temperature
399        },
400        response_format: None,
401        tools: request
402            .tools
403            .into_iter()
404            .map(|tool| deepseek::ToolDefinition::Function {
405                function: deepseek::FunctionDefinition {
406                    name: tool.name,
407                    description: Some(tool.description),
408                    parameters: Some(tool.input_schema),
409                },
410            })
411            .collect(),
412    }
413}
414
415pub struct DeepSeekEventMapper {
416    tool_calls_by_index: HashMap<usize, RawToolCall>,
417}
418
419impl DeepSeekEventMapper {
420    pub fn new() -> Self {
421        Self {
422            tool_calls_by_index: HashMap::default(),
423        }
424    }
425
426    pub fn map_stream(
427        mut self,
428        events: Pin<Box<dyn Send + Stream<Item = Result<deepseek::StreamResponse>>>>,
429    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
430    {
431        events.flat_map(move |event| {
432            futures::stream::iter(match event {
433                Ok(event) => self.map_event(event),
434                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
435            })
436        })
437    }
438
439    pub fn map_event(
440        &mut self,
441        event: deepseek::StreamResponse,
442    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
443        let Some(choice) = event.choices.first() else {
444            return vec![Err(LanguageModelCompletionError::from(anyhow!(
445                "Response contained no choices"
446            )))];
447        };
448
449        let mut events = Vec::new();
450        if let Some(content) = choice.delta.content.clone() {
451            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
452        }
453
454        if let Some(reasoning_content) = choice.delta.reasoning_content.clone() {
455            events.push(Ok(LanguageModelCompletionEvent::Thinking {
456                text: reasoning_content,
457                signature: None,
458            }));
459        }
460
461        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
462            for tool_call in tool_calls {
463                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
464
465                if let Some(tool_id) = tool_call.id.clone() {
466                    entry.id = tool_id;
467                }
468
469                if let Some(function) = tool_call.function.as_ref() {
470                    if let Some(name) = function.name.clone() {
471                        entry.name = name;
472                    }
473
474                    if let Some(arguments) = function.arguments.clone() {
475                        entry.arguments.push_str(&arguments);
476                    }
477                }
478            }
479        }
480
481        if let Some(usage) = event.usage {
482            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
483                input_tokens: usage.prompt_tokens,
484                output_tokens: usage.completion_tokens,
485                cache_creation_input_tokens: 0,
486                cache_read_input_tokens: 0,
487            })));
488        }
489
490        match choice.finish_reason.as_deref() {
491            Some("stop") => {
492                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
493            }
494            Some("tool_calls") => {
495                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
496                    match serde_json::Value::from_str(&tool_call.arguments) {
497                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
498                            LanguageModelToolUse {
499                                id: tool_call.id.clone().into(),
500                                name: tool_call.name.as_str().into(),
501                                is_input_complete: true,
502                                input,
503                                raw_input: tool_call.arguments.clone(),
504                            },
505                        )),
506                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
507                            id: tool_call.id.clone().into(),
508                            tool_name: tool_call.name.as_str().into(),
509                            raw_input: tool_call.arguments.into(),
510                            json_parse_error: error.to_string(),
511                        }),
512                    }
513                }));
514
515                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
516            }
517            Some(stop_reason) => {
518                log::error!("Unexpected DeepSeek stop_reason: {stop_reason:?}",);
519                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
520            }
521            None => {}
522        }
523
524        events
525    }
526}
527
528struct ConfigurationView {
529    api_key_editor: Entity<InputField>,
530    state: Entity<State>,
531    load_credentials_task: Option<Task<()>>,
532}
533
534impl ConfigurationView {
535    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
536        let api_key_editor =
537            cx.new(|cx| InputField::new(window, cx, "sk-00000000000000000000000000000000"));
538
539        cx.observe(&state, |_, _, cx| {
540            cx.notify();
541        })
542        .detach();
543
544        let load_credentials_task = Some(cx.spawn({
545            let state = state.clone();
546            async move |this, cx| {
547                if let Some(task) = state
548                    .update(cx, |state, cx| state.authenticate(cx))
549                    .log_err()
550                {
551                    let _ = task.await;
552                }
553
554                this.update(cx, |this, cx| {
555                    this.load_credentials_task = None;
556                    cx.notify();
557                })
558                .log_err();
559            }
560        }));
561
562        Self {
563            api_key_editor,
564            state,
565            load_credentials_task,
566        }
567    }
568
569    fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
570        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
571        if api_key.is_empty() {
572            return;
573        }
574
575        let state = self.state.clone();
576        cx.spawn(async move |_, cx| {
577            state
578                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
579                .await
580        })
581        .detach_and_log_err(cx);
582    }
583
584    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
585        self.api_key_editor
586            .update(cx, |editor, cx| editor.set_text("", window, cx));
587
588        let state = self.state.clone();
589        cx.spawn(async move |_, cx| {
590            state
591                .update(cx, |state, cx| state.set_api_key(None, cx))?
592                .await
593        })
594        .detach_and_log_err(cx);
595    }
596
597    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
598        !self.state.read(cx).is_authenticated()
599    }
600}
601
602impl Render for ConfigurationView {
603    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
604        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
605        let configured_card_label = if env_var_set {
606            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
607        } else {
608            let api_url = DeepSeekLanguageModelProvider::api_url(cx);
609            if api_url == DEEPSEEK_API_URL {
610                "API key configured".to_string()
611            } else {
612                format!("API key configured for {}", api_url)
613            }
614        };
615
616        if self.load_credentials_task.is_some() {
617            div()
618                .child(Label::new("Loading credentials..."))
619                .into_any_element()
620        } else if self.should_render_editor(cx) {
621            v_flex()
622                .size_full()
623                .on_action(cx.listener(Self::save_api_key))
624                .child(Label::new("To use DeepSeek in Zed, you need an API key:"))
625                .child(
626                    List::new()
627                        .child(InstructionListItem::new(
628                            "Get your API key from the",
629                            Some("DeepSeek console"),
630                            Some("https://platform.deepseek.com/api_keys"),
631                        ))
632                        .child(InstructionListItem::text_only(
633                            "Paste your API key below and hit enter to start using the assistant",
634                        )),
635                )
636                .child(self.api_key_editor.clone())
637                .child(
638                    Label::new(format!(
639                        "Or set the {API_KEY_ENV_VAR_NAME} environment variable."
640                    ))
641                    .size(LabelSize::Small)
642                    .color(Color::Muted),
643                )
644                .into_any_element()
645        } else {
646            ConfiguredApiCard::new(configured_card_label)
647                .disabled(env_var_set)
648                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
649                .into_any_element()
650        }
651    }
652}