1use anyhow::{anyhow, Context as _, Result};
2use collections::BTreeMap;
3use credentials_provider::CredentialsProvider;
4use editor::{Editor, EditorElement, EditorStyle};
5use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
6use gpui::{
7 AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle,
8 WhiteSpace,
9};
10use http_client::HttpClient;
11use language_model::{
12 AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
13 LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
14 LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
15};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::{Settings, SettingsStore};
19use std::sync::Arc;
20use theme::ThemeSettings;
21use ui::{prelude::*, Icon, IconName};
22use util::ResultExt;
23
24use crate::AllLanguageModelSettings;
25
/// Identifier under which this provider and its models are registered.
const PROVIDER_ID: &str = "deepseek";
/// Human-readable provider name shown in the UI and in error messages.
const PROVIDER_NAME: &str = "DeepSeek";
/// Environment variable consulted for an API key before stored credentials.
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
29
30#[derive(Default, Clone, Debug, PartialEq)]
31pub struct DeepSeekSettings {
32 pub api_url: String,
33 pub available_models: Vec<AvailableModel>,
34}
35
36#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
37pub struct AvailableModel {
38 pub name: String,
39 pub display_name: Option<String>,
40 pub max_tokens: usize,
41 pub max_output_tokens: Option<u32>,
42}
43
44pub struct DeepSeekLanguageModelProvider {
45 http_client: Arc<dyn HttpClient>,
46 state: Entity<State>,
47}
48
49pub struct State {
50 api_key: Option<String>,
51 api_key_from_env: bool,
52 _subscription: Subscription,
53}
54
55impl State {
56 fn is_authenticated(&self) -> bool {
57 self.api_key.is_some()
58 }
59
60 fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
61 let credentials_provider = <dyn CredentialsProvider>::global(cx);
62 let api_url = AllLanguageModelSettings::get_global(cx)
63 .deepseek
64 .api_url
65 .clone();
66 cx.spawn(|this, mut cx| async move {
67 credentials_provider
68 .delete_credentials(&api_url, &cx)
69 .await
70 .log_err();
71 this.update(&mut cx, |this, cx| {
72 this.api_key = None;
73 this.api_key_from_env = false;
74 cx.notify();
75 })
76 })
77 }
78
79 fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
80 let credentials_provider = <dyn CredentialsProvider>::global(cx);
81 let api_url = AllLanguageModelSettings::get_global(cx)
82 .deepseek
83 .api_url
84 .clone();
85 cx.spawn(|this, mut cx| async move {
86 credentials_provider
87 .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
88 .await?;
89 this.update(&mut cx, |this, cx| {
90 this.api_key = Some(api_key);
91 cx.notify();
92 })
93 })
94 }
95
96 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
97 if self.is_authenticated() {
98 return Task::ready(Ok(()));
99 }
100
101 let credentials_provider = <dyn CredentialsProvider>::global(cx);
102 let api_url = AllLanguageModelSettings::get_global(cx)
103 .deepseek
104 .api_url
105 .clone();
106 cx.spawn(|this, mut cx| async move {
107 let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) {
108 (api_key, true)
109 } else {
110 let (_, api_key) = credentials_provider
111 .read_credentials(&api_url, &cx)
112 .await?
113 .ok_or(AuthenticateError::CredentialsNotFound)?;
114 (
115 String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
116 false,
117 )
118 };
119
120 this.update(&mut cx, |this, cx| {
121 this.api_key = Some(api_key);
122 this.api_key_from_env = from_env;
123 cx.notify();
124 })?;
125
126 Ok(())
127 })
128 }
129}
130
131impl DeepSeekLanguageModelProvider {
132 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
133 let state = cx.new(|cx| State {
134 api_key: None,
135 api_key_from_env: false,
136 _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
137 cx.notify();
138 }),
139 });
140
141 Self { http_client, state }
142 }
143}
144
145impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
146 type ObservableEntity = State;
147
148 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
149 Some(self.state.clone())
150 }
151}
152
153impl LanguageModelProvider for DeepSeekLanguageModelProvider {
154 fn id(&self) -> LanguageModelProviderId {
155 LanguageModelProviderId(PROVIDER_ID.into())
156 }
157
158 fn name(&self) -> LanguageModelProviderName {
159 LanguageModelProviderName(PROVIDER_NAME.into())
160 }
161
162 fn icon(&self) -> IconName {
163 IconName::AiDeepSeek
164 }
165
166 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
167 let mut models = BTreeMap::default();
168
169 models.insert("deepseek-chat", deepseek::Model::Chat);
170 models.insert("deepseek-reasoner", deepseek::Model::Reasoner);
171
172 for available_model in AllLanguageModelSettings::get_global(cx)
173 .deepseek
174 .available_models
175 .iter()
176 {
177 models.insert(
178 &available_model.name,
179 deepseek::Model::Custom {
180 name: available_model.name.clone(),
181 display_name: available_model.display_name.clone(),
182 max_tokens: available_model.max_tokens,
183 max_output_tokens: available_model.max_output_tokens,
184 },
185 );
186 }
187
188 models
189 .into_values()
190 .map(|model| {
191 Arc::new(DeepSeekLanguageModel {
192 id: LanguageModelId::from(model.id().to_string()),
193 model,
194 state: self.state.clone(),
195 http_client: self.http_client.clone(),
196 request_limiter: RateLimiter::new(4),
197 }) as Arc<dyn LanguageModel>
198 })
199 .collect()
200 }
201
202 fn is_authenticated(&self, cx: &App) -> bool {
203 self.state.read(cx).is_authenticated()
204 }
205
206 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
207 self.state.update(cx, |state, cx| state.authenticate(cx))
208 }
209
210 fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
211 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
212 .into()
213 }
214
215 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
216 self.state.update(cx, |state, cx| state.reset_api_key(cx))
217 }
218}
219
220pub struct DeepSeekLanguageModel {
221 id: LanguageModelId,
222 model: deepseek::Model,
223 state: Entity<State>,
224 http_client: Arc<dyn HttpClient>,
225 request_limiter: RateLimiter,
226}
227
228impl DeepSeekLanguageModel {
229 fn stream_completion(
230 &self,
231 request: deepseek::Request,
232 cx: &AsyncApp,
233 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<deepseek::StreamResponse>>>> {
234 let http_client = self.http_client.clone();
235 let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
236 let settings = &AllLanguageModelSettings::get_global(cx).deepseek;
237 (state.api_key.clone(), settings.api_url.clone())
238 }) else {
239 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
240 };
241
242 let future = self.request_limiter.stream(async move {
243 let api_key = api_key.ok_or_else(|| anyhow!("Missing DeepSeek API Key"))?;
244 let request =
245 deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
246 let response = request.await?;
247 Ok(response)
248 });
249
250 async move { Ok(future.await?.boxed()) }.boxed()
251 }
252}
253
254impl LanguageModel for DeepSeekLanguageModel {
255 fn id(&self) -> LanguageModelId {
256 self.id.clone()
257 }
258
259 fn name(&self) -> LanguageModelName {
260 LanguageModelName::from(self.model.display_name().to_string())
261 }
262
263 fn provider_id(&self) -> LanguageModelProviderId {
264 LanguageModelProviderId(PROVIDER_ID.into())
265 }
266
267 fn provider_name(&self) -> LanguageModelProviderName {
268 LanguageModelProviderName(PROVIDER_NAME.into())
269 }
270
271 fn telemetry_id(&self) -> String {
272 format!("deepseek/{}", self.model.id())
273 }
274
275 fn max_token_count(&self) -> usize {
276 self.model.max_token_count()
277 }
278
279 fn max_output_tokens(&self) -> Option<u32> {
280 self.model.max_output_tokens()
281 }
282
283 fn count_tokens(
284 &self,
285 request: LanguageModelRequest,
286 cx: &App,
287 ) -> BoxFuture<'static, Result<usize>> {
288 cx.background_spawn(async move {
289 let messages = request
290 .messages
291 .into_iter()
292 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
293 role: match message.role {
294 Role::User => "user".into(),
295 Role::Assistant => "assistant".into(),
296 Role::System => "system".into(),
297 },
298 content: Some(message.string_contents()),
299 name: None,
300 function_call: None,
301 })
302 .collect::<Vec<_>>();
303
304 tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
305 })
306 .boxed()
307 }
308
309 fn stream_completion(
310 &self,
311 request: LanguageModelRequest,
312 cx: &AsyncApp,
313 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
314 let request = request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
315 let stream = self.stream_completion(request, cx);
316
317 async move {
318 let stream = stream.await?;
319 Ok(stream
320 .map(|result| {
321 result.and_then(|response| {
322 response
323 .choices
324 .first()
325 .ok_or_else(|| anyhow!("Empty response"))
326 .map(|choice| {
327 choice
328 .delta
329 .content
330 .clone()
331 .unwrap_or_default()
332 .map(LanguageModelCompletionEvent::Text)
333 })
334 })
335 })
336 .boxed())
337 }
338 .boxed()
339 }
340
341 fn use_any_tool(
342 &self,
343 request: LanguageModelRequest,
344 name: String,
345 description: String,
346 schema: serde_json::Value,
347 cx: &AsyncApp,
348 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
349 let mut deepseek_request =
350 request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
351
352 deepseek_request.tools = vec![deepseek::ToolDefinition::Function {
353 function: deepseek::FunctionDefinition {
354 name: name.clone(),
355 description: Some(description),
356 parameters: Some(schema),
357 },
358 }];
359
360 let response_stream = self.stream_completion(deepseek_request, cx);
361
362 self.request_limiter
363 .run(async move {
364 let stream = response_stream.await?;
365
366 let tool_args_stream = stream
367 .filter_map(move |response| async move {
368 match response {
369 Ok(response) => {
370 for choice in response.choices {
371 if let Some(tool_calls) = choice.delta.tool_calls {
372 for tool_call in tool_calls {
373 if let Some(function) = tool_call.function {
374 if let Some(args) = function.arguments {
375 return Some(Ok(args));
376 }
377 }
378 }
379 }
380 }
381 None
382 }
383 Err(e) => Some(Err(e)),
384 }
385 })
386 .boxed();
387
388 Ok(tool_args_stream)
389 })
390 .boxed()
391 }
392}
393
394struct ConfigurationView {
395 api_key_editor: Entity<Editor>,
396 state: Entity<State>,
397 load_credentials_task: Option<Task<()>>,
398}
399
400impl ConfigurationView {
401 fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
402 let api_key_editor = cx.new(|cx| {
403 let mut editor = Editor::single_line(window, cx);
404 editor.set_placeholder_text("sk-00000000000000000000000000000000", cx);
405 editor
406 });
407
408 cx.observe(&state, |_, _, cx| {
409 cx.notify();
410 })
411 .detach();
412
413 let load_credentials_task = Some(cx.spawn({
414 let state = state.clone();
415 |this, mut cx| async move {
416 if let Some(task) = state
417 .update(&mut cx, |state, cx| state.authenticate(cx))
418 .log_err()
419 {
420 let _ = task.await;
421 }
422
423 this.update(&mut cx, |this, cx| {
424 this.load_credentials_task = None;
425 cx.notify();
426 })
427 .log_err();
428 }
429 }));
430
431 Self {
432 api_key_editor,
433 state,
434 load_credentials_task,
435 }
436 }
437
438 fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
439 let api_key = self.api_key_editor.read(cx).text(cx);
440 if api_key.is_empty() {
441 return;
442 }
443
444 let state = self.state.clone();
445 cx.spawn(|_, mut cx| async move {
446 state
447 .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
448 .await
449 })
450 .detach_and_log_err(cx);
451
452 cx.notify();
453 }
454
455 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
456 self.api_key_editor
457 .update(cx, |editor, cx| editor.set_text("", window, cx));
458
459 let state = self.state.clone();
460 cx.spawn(|_, mut cx| async move {
461 state
462 .update(&mut cx, |state, cx| state.reset_api_key(cx))?
463 .await
464 })
465 .detach_and_log_err(cx);
466
467 cx.notify();
468 }
469
470 fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
471 let settings = ThemeSettings::get_global(cx);
472 let text_style = TextStyle {
473 color: cx.theme().colors().text,
474 font_family: settings.ui_font.family.clone(),
475 font_features: settings.ui_font.features.clone(),
476 font_fallbacks: settings.ui_font.fallbacks.clone(),
477 font_size: rems(0.875).into(),
478 font_weight: settings.ui_font.weight,
479 font_style: FontStyle::Normal,
480 line_height: relative(1.3),
481 background_color: None,
482 underline: None,
483 strikethrough: None,
484 white_space: WhiteSpace::Normal,
485 ..Default::default()
486 };
487 EditorElement::new(
488 &self.api_key_editor,
489 EditorStyle {
490 background: cx.theme().colors().editor_background,
491 local_player: cx.theme().players().local(),
492 text: text_style,
493 ..Default::default()
494 },
495 )
496 }
497
498 fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
499 !self.state.read(cx).is_authenticated()
500 }
501}
502
503impl Render for ConfigurationView {
504 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
505 const DEEPSEEK_CONSOLE_URL: &str = "https://platform.deepseek.com/api_keys";
506 const INSTRUCTIONS: [&str; 3] = [
507 "To use DeepSeek in Zed, you need an API key:",
508 "- Get your API key from:",
509 "- Paste it below and press enter:",
510 ];
511
512 let env_var_set = self.state.read(cx).api_key_from_env;
513
514 if self.load_credentials_task.is_some() {
515 div().child(Label::new("Loading credentials...")).into_any()
516 } else if self.should_render_editor(cx) {
517 v_flex()
518 .size_full()
519 .on_action(cx.listener(Self::save_api_key))
520 .child(Label::new(INSTRUCTIONS[0]))
521 .child(
522 h_flex().child(Label::new(INSTRUCTIONS[1])).child(
523 Button::new("deepseek_console", DEEPSEEK_CONSOLE_URL)
524 .style(ButtonStyle::Subtle)
525 .icon(IconName::ArrowUpRight)
526 .icon_size(IconSize::XSmall)
527 .icon_color(Color::Muted)
528 .on_click(move |_, _window, cx| cx.open_url(DEEPSEEK_CONSOLE_URL)),
529 ),
530 )
531 .child(Label::new(INSTRUCTIONS[2]))
532 .child(
533 h_flex()
534 .w_full()
535 .my_2()
536 .px_2()
537 .py_1()
538 .bg(cx.theme().colors().editor_background)
539 .border_1()
540 .border_color(cx.theme().colors().border_variant)
541 .rounded_md()
542 .child(self.render_api_key_editor(cx)),
543 )
544 .child(
545 Label::new(format!(
546 "Or set the {} environment variable.",
547 DEEPSEEK_API_KEY_VAR
548 ))
549 .size(LabelSize::Small),
550 )
551 .into_any()
552 } else {
553 h_flex()
554 .size_full()
555 .justify_between()
556 .child(
557 h_flex()
558 .gap_1()
559 .child(Icon::new(IconName::Check).color(Color::Success))
560 .child(Label::new(if env_var_set {
561 format!("API key set in {}", DEEPSEEK_API_KEY_VAR)
562 } else {
563 "API key configured".to_string()
564 })),
565 )
566 .child(
567 Button::new("reset-key", "Reset")
568 .icon(IconName::Trash)
569 .disabled(env_var_set)
570 .on_click(
571 cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)),
572 ),
573 )
574 .into_any()
575 }
576 }
577}