deepseek.rs

  1use anyhow::{anyhow, Context as _, Result};
  2use collections::BTreeMap;
  3use editor::{Editor, EditorElement, EditorStyle};
  4use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  5use gpui::{
  6    AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle,
  7    WhiteSpace,
  8};
  9use http_client::HttpClient;
 10use language_model::{
 11    AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
 12    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
 13    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 14};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use settings::{Settings, SettingsStore};
 18use std::sync::Arc;
 19use theme::ThemeSettings;
 20use ui::{prelude::*, Icon, IconName};
 21use util::ResultExt;
 22
 23use crate::AllLanguageModelSettings;
 24
/// Stable provider identifier returned by `id()` / `provider_id()`.
const PROVIDER_ID: &str = "deepseek";
/// Human-readable provider name returned by `name()` / `provider_name()`.
const PROVIDER_NAME: &str = "DeepSeek";
/// Environment variable checked before the credential store when authenticating.
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
 28
/// DeepSeek settings, read in this module via
/// `AllLanguageModelSettings::get_global(cx).deepseek`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct DeepSeekSettings {
    /// Base URL passed to `deepseek::stream_completion`; it is also the key under
    /// which the API key is written to / read from the credential store.
    pub api_url: String,
    /// User-defined models merged with the built-in `deepseek-chat` and
    /// `deepseek-reasoner` models; an entry with the same name replaces the built-in.
    pub available_models: Vec<AvailableModel>,
}
 34
// A user-configured DeepSeek model entry; `provided_models` turns each entry into
// a `deepseek::Model::Custom`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Model name; used as the dedup key against the built-in models and copied into
    // `Model::Custom::name` (presumably the model id sent to the API — confirm in the
    // `deepseek` crate).
    pub name: String,
    // Optional label copied into `Model::Custom::display_name`.
    pub display_name: Option<String>,
    // Copied into `Model::Custom::max_tokens`; presumably the context window size.
    pub max_tokens: usize,
    // Copied into `Model::Custom::max_output_tokens`; surfaced through
    // `max_output_tokens()` and passed along when building requests.
    pub max_output_tokens: Option<u32>,
}
 42
/// Language-model provider exposing DeepSeek chat models.
pub struct DeepSeekLanguageModelProvider {
    /// HTTP client cloned into every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Authentication state shared by the provider, its models, and the configuration view.
    state: Entity<State>,
}
 47
/// Authentication state for the DeepSeek provider.
pub struct State {
    /// The API key in use; `None` means the provider is not authenticated.
    api_key: Option<String>,
    /// Whether `api_key` came from the `DEEPSEEK_API_KEY` environment variable rather
    /// than the credential store (the configuration view disables "Reset" in that case).
    api_key_from_env: bool,
    /// Keeps the settings-store observer alive; it calls `cx.notify()` on settings changes.
    _subscription: Subscription,
}
 53
 54impl State {
 55    fn is_authenticated(&self) -> bool {
 56        self.api_key.is_some()
 57    }
 58
 59    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 60        let settings = &AllLanguageModelSettings::get_global(cx).deepseek;
 61        let delete_credentials = cx.delete_credentials(&settings.api_url);
 62        cx.spawn(|this, mut cx| async move {
 63            delete_credentials.await.log_err();
 64            this.update(&mut cx, |this, cx| {
 65                this.api_key = None;
 66                this.api_key_from_env = false;
 67                cx.notify();
 68            })
 69        })
 70    }
 71
 72    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
 73        let settings = &AllLanguageModelSettings::get_global(cx).deepseek;
 74        let write_credentials =
 75            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
 76
 77        cx.spawn(|this, mut cx| async move {
 78            write_credentials.await?;
 79            this.update(&mut cx, |this, cx| {
 80                this.api_key = Some(api_key);
 81                cx.notify();
 82            })
 83        })
 84    }
 85
 86    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 87        if self.is_authenticated() {
 88            return Task::ready(Ok(()));
 89        }
 90
 91        let api_url = AllLanguageModelSettings::get_global(cx)
 92            .deepseek
 93            .api_url
 94            .clone();
 95
 96        cx.spawn(|this, mut cx| async move {
 97            let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) {
 98                (api_key, true)
 99            } else {
100                let (_, api_key) = cx
101                    .update(|cx| cx.read_credentials(&api_url))?
102                    .await?
103                    .ok_or(AuthenticateError::CredentialsNotFound)?;
104                (
105                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
106                    false,
107                )
108            };
109
110            this.update(&mut cx, |this, cx| {
111                this.api_key = Some(api_key);
112                this.api_key_from_env = from_env;
113                cx.notify();
114            })?;
115
116            Ok(())
117        })
118    }
119}
120
121impl DeepSeekLanguageModelProvider {
122    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
123        let state = cx.new(|cx| State {
124            api_key: None,
125            api_key_from_env: false,
126            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
127                cx.notify();
128            }),
129        });
130
131        Self { http_client, state }
132    }
133}
134
135impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
136    type ObservableEntity = State;
137
138    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
139        Some(self.state.clone())
140    }
141}
142
143impl LanguageModelProvider for DeepSeekLanguageModelProvider {
144    fn id(&self) -> LanguageModelProviderId {
145        LanguageModelProviderId(PROVIDER_ID.into())
146    }
147
148    fn name(&self) -> LanguageModelProviderName {
149        LanguageModelProviderName(PROVIDER_NAME.into())
150    }
151
152    fn icon(&self) -> IconName {
153        IconName::AiDeepSeek
154    }
155
156    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
157        let mut models = BTreeMap::default();
158
159        models.insert("deepseek-chat", deepseek::Model::Chat);
160        models.insert("deepseek-reasoner", deepseek::Model::Reasoner);
161
162        for available_model in AllLanguageModelSettings::get_global(cx)
163            .deepseek
164            .available_models
165            .iter()
166        {
167            models.insert(
168                &available_model.name,
169                deepseek::Model::Custom {
170                    name: available_model.name.clone(),
171                    display_name: available_model.display_name.clone(),
172                    max_tokens: available_model.max_tokens,
173                    max_output_tokens: available_model.max_output_tokens,
174                },
175            );
176        }
177
178        models
179            .into_values()
180            .map(|model| {
181                Arc::new(DeepSeekLanguageModel {
182                    id: LanguageModelId::from(model.id().to_string()),
183                    model,
184                    state: self.state.clone(),
185                    http_client: self.http_client.clone(),
186                    request_limiter: RateLimiter::new(4),
187                }) as Arc<dyn LanguageModel>
188            })
189            .collect()
190    }
191
192    fn is_authenticated(&self, cx: &App) -> bool {
193        self.state.read(cx).is_authenticated()
194    }
195
196    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
197        self.state.update(cx, |state, cx| state.authenticate(cx))
198    }
199
200    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
201        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
202            .into()
203    }
204
205    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
206        self.state.update(cx, |state, cx| state.reset_api_key(cx))
207    }
208}
209
/// One DeepSeek model (built-in or custom) exposed as a `LanguageModel`.
pub struct DeepSeekLanguageModel {
    /// Identifier built from `model.id()`.
    id: LanguageModelId,
    /// The wrapped DeepSeek model description.
    model: deepseek::Model,
    /// Shared provider state; the API key is read from it when a request starts.
    state: Entity<State>,
    /// HTTP client used for completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Per-model limiter (created with a limit of 4 in `provided_models`) that
    /// completion requests are routed through.
    request_limiter: RateLimiter,
}
217
218impl DeepSeekLanguageModel {
219    fn stream_completion(
220        &self,
221        request: deepseek::Request,
222        cx: &AsyncApp,
223    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<deepseek::StreamResponse>>>> {
224        let http_client = self.http_client.clone();
225        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
226            let settings = &AllLanguageModelSettings::get_global(cx).deepseek;
227            (state.api_key.clone(), settings.api_url.clone())
228        }) else {
229            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
230        };
231
232        let future = self.request_limiter.stream(async move {
233            let api_key = api_key.ok_or_else(|| anyhow!("Missing DeepSeek API Key"))?;
234            let request =
235                deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
236            let response = request.await?;
237            Ok(response)
238        });
239
240        async move { Ok(future.await?.boxed()) }.boxed()
241    }
242}
243
244impl LanguageModel for DeepSeekLanguageModel {
245    fn id(&self) -> LanguageModelId {
246        self.id.clone()
247    }
248
249    fn name(&self) -> LanguageModelName {
250        LanguageModelName::from(self.model.display_name().to_string())
251    }
252
253    fn provider_id(&self) -> LanguageModelProviderId {
254        LanguageModelProviderId(PROVIDER_ID.into())
255    }
256
257    fn provider_name(&self) -> LanguageModelProviderName {
258        LanguageModelProviderName(PROVIDER_NAME.into())
259    }
260
261    fn telemetry_id(&self) -> String {
262        format!("deepseek/{}", self.model.id())
263    }
264
265    fn max_token_count(&self) -> usize {
266        self.model.max_token_count()
267    }
268
269    fn max_output_tokens(&self) -> Option<u32> {
270        self.model.max_output_tokens()
271    }
272
273    fn count_tokens(
274        &self,
275        request: LanguageModelRequest,
276        cx: &App,
277    ) -> BoxFuture<'static, Result<usize>> {
278        cx.background_spawn(async move {
279            let messages = request
280                .messages
281                .into_iter()
282                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
283                    role: match message.role {
284                        Role::User => "user".into(),
285                        Role::Assistant => "assistant".into(),
286                        Role::System => "system".into(),
287                    },
288                    content: Some(message.string_contents()),
289                    name: None,
290                    function_call: None,
291                })
292                .collect::<Vec<_>>();
293
294            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
295        })
296        .boxed()
297    }
298
299    fn stream_completion(
300        &self,
301        request: LanguageModelRequest,
302        cx: &AsyncApp,
303    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
304        let request = request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
305        let stream = self.stream_completion(request, cx);
306
307        async move {
308            let stream = stream.await?;
309            Ok(stream
310                .map(|result| {
311                    result.and_then(|response| {
312                        response
313                            .choices
314                            .first()
315                            .ok_or_else(|| anyhow!("Empty response"))
316                            .map(|choice| {
317                                choice
318                                    .delta
319                                    .content
320                                    .clone()
321                                    .unwrap_or_default()
322                                    .map(LanguageModelCompletionEvent::Text)
323                            })
324                    })
325                })
326                .boxed())
327        }
328        .boxed()
329    }
330
331    fn use_any_tool(
332        &self,
333        request: LanguageModelRequest,
334        name: String,
335        description: String,
336        schema: serde_json::Value,
337        cx: &AsyncApp,
338    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
339        let mut deepseek_request =
340            request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
341
342        deepseek_request.tools = vec![deepseek::ToolDefinition::Function {
343            function: deepseek::FunctionDefinition {
344                name: name.clone(),
345                description: Some(description),
346                parameters: Some(schema),
347            },
348        }];
349
350        let response_stream = self.stream_completion(deepseek_request, cx);
351
352        self.request_limiter
353            .run(async move {
354                let stream = response_stream.await?;
355
356                let tool_args_stream = stream
357                    .filter_map(move |response| async move {
358                        match response {
359                            Ok(response) => {
360                                for choice in response.choices {
361                                    if let Some(tool_calls) = choice.delta.tool_calls {
362                                        for tool_call in tool_calls {
363                                            if let Some(function) = tool_call.function {
364                                                if let Some(args) = function.arguments {
365                                                    return Some(Ok(args));
366                                                }
367                                            }
368                                        }
369                                    }
370                                }
371                                None
372                            }
373                            Err(e) => Some(Err(e)),
374                        }
375                    })
376                    .boxed();
377
378                Ok(tool_args_stream)
379            })
380            .boxed()
381    }
382}
383
/// Settings UI for entering, displaying, and resetting the DeepSeek API key.
struct ConfigurationView {
    /// Single-line editor the API key is typed into.
    api_key_editor: Entity<Editor>,
    /// Shared provider authentication state.
    state: Entity<State>,
    /// Initial credential load; `Some` while in flight, which makes `render` show a
    /// loading label.
    load_credentials_task: Option<Task<()>>,
}
389
390impl ConfigurationView {
391    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
392        let api_key_editor = cx.new(|cx| {
393            let mut editor = Editor::single_line(window, cx);
394            editor.set_placeholder_text("sk-00000000000000000000000000000000", cx);
395            editor
396        });
397
398        cx.observe(&state, |_, _, cx| {
399            cx.notify();
400        })
401        .detach();
402
403        let load_credentials_task = Some(cx.spawn({
404            let state = state.clone();
405            |this, mut cx| async move {
406                if let Some(task) = state
407                    .update(&mut cx, |state, cx| state.authenticate(cx))
408                    .log_err()
409                {
410                    let _ = task.await;
411                }
412
413                this.update(&mut cx, |this, cx| {
414                    this.load_credentials_task = None;
415                    cx.notify();
416                })
417                .log_err();
418            }
419        }));
420
421        Self {
422            api_key_editor,
423            state,
424            load_credentials_task,
425        }
426    }
427
428    fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
429        let api_key = self.api_key_editor.read(cx).text(cx);
430        if api_key.is_empty() {
431            return;
432        }
433
434        let state = self.state.clone();
435        cx.spawn(|_, mut cx| async move {
436            state
437                .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
438                .await
439        })
440        .detach_and_log_err(cx);
441
442        cx.notify();
443    }
444
445    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
446        self.api_key_editor
447            .update(cx, |editor, cx| editor.set_text("", window, cx));
448
449        let state = self.state.clone();
450        cx.spawn(|_, mut cx| async move {
451            state
452                .update(&mut cx, |state, cx| state.reset_api_key(cx))?
453                .await
454        })
455        .detach_and_log_err(cx);
456
457        cx.notify();
458    }
459
460    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
461        let settings = ThemeSettings::get_global(cx);
462        let text_style = TextStyle {
463            color: cx.theme().colors().text,
464            font_family: settings.ui_font.family.clone(),
465            font_features: settings.ui_font.features.clone(),
466            font_fallbacks: settings.ui_font.fallbacks.clone(),
467            font_size: rems(0.875).into(),
468            font_weight: settings.ui_font.weight,
469            font_style: FontStyle::Normal,
470            line_height: relative(1.3),
471            background_color: None,
472            underline: None,
473            strikethrough: None,
474            white_space: WhiteSpace::Normal,
475            ..Default::default()
476        };
477        EditorElement::new(
478            &self.api_key_editor,
479            EditorStyle {
480                background: cx.theme().colors().editor_background,
481                local_player: cx.theme().players().local(),
482                text: text_style,
483                ..Default::default()
484            },
485        )
486    }
487
488    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
489        !self.state.read(cx).is_authenticated()
490    }
491}
492
493impl Render for ConfigurationView {
494    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
495        const DEEPSEEK_CONSOLE_URL: &str = "https://platform.deepseek.com/api_keys";
496        const INSTRUCTIONS: [&str; 3] = [
497            "To use DeepSeek in Zed, you need an API key:",
498            "- Get your API key from:",
499            "- Paste it below and press enter:",
500        ];
501
502        let env_var_set = self.state.read(cx).api_key_from_env;
503
504        if self.load_credentials_task.is_some() {
505            div().child(Label::new("Loading credentials...")).into_any()
506        } else if self.should_render_editor(cx) {
507            v_flex()
508                .size_full()
509                .on_action(cx.listener(Self::save_api_key))
510                .child(Label::new(INSTRUCTIONS[0]))
511                .child(
512                    h_flex().child(Label::new(INSTRUCTIONS[1])).child(
513                        Button::new("deepseek_console", DEEPSEEK_CONSOLE_URL)
514                            .style(ButtonStyle::Subtle)
515                            .icon(IconName::ArrowUpRight)
516                            .icon_size(IconSize::XSmall)
517                            .icon_color(Color::Muted)
518                            .on_click(move |_, _window, cx| cx.open_url(DEEPSEEK_CONSOLE_URL)),
519                    ),
520                )
521                .child(Label::new(INSTRUCTIONS[2]))
522                .child(
523                    h_flex()
524                        .w_full()
525                        .my_2()
526                        .px_2()
527                        .py_1()
528                        .bg(cx.theme().colors().editor_background)
529                        .border_1()
530                        .border_color(cx.theme().colors().border_variant)
531                        .rounded_md()
532                        .child(self.render_api_key_editor(cx)),
533                )
534                .child(
535                    Label::new(format!(
536                        "Or set the {} environment variable.",
537                        DEEPSEEK_API_KEY_VAR
538                    ))
539                    .size(LabelSize::Small),
540                )
541                .into_any()
542        } else {
543            h_flex()
544                .size_full()
545                .justify_between()
546                .child(
547                    h_flex()
548                        .gap_1()
549                        .child(Icon::new(IconName::Check).color(Color::Success))
550                        .child(Label::new(if env_var_set {
551                            format!("API key set in {}", DEEPSEEK_API_KEY_VAR)
552                        } else {
553                            "API key configured".to_string()
554                        })),
555                )
556                .child(
557                    Button::new("reset-key", "Reset")
558                        .icon(IconName::Trash)
559                        .disabled(env_var_set)
560                        .on_click(
561                            cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)),
562                        ),
563                )
564                .into_any()
565        }
566    }
567}