1use crate::ExtensionSettings;
2use crate::wasm_host::WasmExtension;
3
4use crate::wasm_host::wit::{
5 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
6 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
7 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
8 LlmToolUse,
9};
10use anyhow::{Result, anyhow};
11use credentials_provider::CredentialsProvider;
12use editor::Editor;
13use extension::{LanguageModelAuthConfig, OAuthConfig};
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, StreamExt};
17use gpui::Focusable;
18use gpui::{
19 AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
20 MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
21};
22use language_model::tool_schema::LanguageModelToolSchemaFormat;
23use language_model::{
24 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
25 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
27 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
28 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
29};
30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
31use settings::Settings;
32use std::sync::Arc;
33use theme::ThemeSettings;
34use ui::{Label, LabelSize, prelude::*};
35use util::ResultExt as _;
36
37/// An extension-based language model provider.
38pub struct ExtensionLanguageModelProvider {
39 pub extension: WasmExtension,
40 pub provider_info: LlmProviderInfo,
41 icon_path: Option<SharedString>,
42 auth_config: Option<LanguageModelAuthConfig>,
43 state: Entity<ExtensionLlmProviderState>,
44}
45
46pub struct ExtensionLlmProviderState {
47 is_authenticated: bool,
48 available_models: Vec<LlmModelInfo>,
49 env_var_allowed: bool,
50 api_key_from_env: bool,
51}
52
53impl EventEmitter<()> for ExtensionLlmProviderState {}
54
55impl ExtensionLanguageModelProvider {
56 pub fn new(
57 extension: WasmExtension,
58 provider_info: LlmProviderInfo,
59 models: Vec<LlmModelInfo>,
60 is_authenticated: bool,
61 icon_path: Option<SharedString>,
62 auth_config: Option<LanguageModelAuthConfig>,
63 cx: &mut App,
64 ) -> Self {
65 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
66 let env_var_allowed = ExtensionSettings::get_global(cx)
67 .allowed_env_var_providers
68 .contains(provider_id_string.as_str());
69
70 let (is_authenticated, api_key_from_env) =
71 if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
72 let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
73 if let Ok(value) = std::env::var(env_var_name) {
74 if !value.is_empty() {
75 (true, true)
76 } else {
77 (is_authenticated, false)
78 }
79 } else {
80 (is_authenticated, false)
81 }
82 } else {
83 (is_authenticated, false)
84 };
85
86 let state = cx.new(|_| ExtensionLlmProviderState {
87 is_authenticated,
88 available_models: models,
89 env_var_allowed,
90 api_key_from_env,
91 });
92
93 Self {
94 extension,
95 provider_info,
96 icon_path,
97 auth_config,
98 state,
99 }
100 }
101
102 fn provider_id_string(&self) -> String {
103 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
104 }
105
106 /// The credential key used for storing the API key in the system keychain.
107 fn credential_key(&self) -> String {
108 format!("extension-llm-{}", self.provider_id_string())
109 }
110}
111
112impl LanguageModelProvider for ExtensionLanguageModelProvider {
113 fn id(&self) -> LanguageModelProviderId {
114 LanguageModelProviderId::from(self.provider_id_string())
115 }
116
117 fn name(&self) -> LanguageModelProviderName {
118 LanguageModelProviderName::from(format!("(Extension) {}", self.provider_info.name))
119 }
120
121 fn icon(&self) -> ui::IconName {
122 ui::IconName::ZedAssistant
123 }
124
125 fn icon_path(&self) -> Option<SharedString> {
126 self.icon_path.clone()
127 }
128
129 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
130 let state = self.state.read(cx);
131 state
132 .available_models
133 .iter()
134 .find(|m| m.is_default)
135 .or_else(|| state.available_models.first())
136 .map(|model_info| {
137 Arc::new(ExtensionLanguageModel {
138 extension: self.extension.clone(),
139 model_info: model_info.clone(),
140 provider_id: self.id(),
141 provider_name: self.name(),
142 provider_info: self.provider_info.clone(),
143 }) as Arc<dyn LanguageModel>
144 })
145 }
146
147 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
148 let state = self.state.read(cx);
149 state
150 .available_models
151 .iter()
152 .find(|m| m.is_default_fast)
153 .map(|model_info| {
154 Arc::new(ExtensionLanguageModel {
155 extension: self.extension.clone(),
156 model_info: model_info.clone(),
157 provider_id: self.id(),
158 provider_name: self.name(),
159 provider_info: self.provider_info.clone(),
160 }) as Arc<dyn LanguageModel>
161 })
162 }
163
164 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165 let state = self.state.read(cx);
166 state
167 .available_models
168 .iter()
169 .map(|model_info| {
170 Arc::new(ExtensionLanguageModel {
171 extension: self.extension.clone(),
172 model_info: model_info.clone(),
173 provider_id: self.id(),
174 provider_name: self.name(),
175 provider_info: self.provider_info.clone(),
176 }) as Arc<dyn LanguageModel>
177 })
178 .collect()
179 }
180
181 fn is_authenticated(&self, cx: &App) -> bool {
182 // First check cached state
183 if self.state.read(cx).is_authenticated {
184 return true;
185 }
186
187 // Also check env var dynamically (in case migration happened after provider creation)
188 if let Some(ref auth_config) = self.auth_config {
189 if let Some(ref env_var_name) = auth_config.env_var {
190 let provider_id_string = self.provider_id_string();
191 let env_var_allowed = ExtensionSettings::get_global(cx)
192 .allowed_env_var_providers
193 .contains(provider_id_string.as_str());
194
195 if env_var_allowed {
196 if let Ok(value) = std::env::var(env_var_name) {
197 if !value.is_empty() {
198 return true;
199 }
200 }
201 }
202 }
203 }
204
205 false
206 }
207
208 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
209 let extension = self.extension.clone();
210 let provider_id = self.provider_info.id.clone();
211 let state = self.state.clone();
212
213 cx.spawn(async move |cx| {
214 let result = extension
215 .call(|extension, store| {
216 async move {
217 extension
218 .call_llm_provider_authenticate(store, &provider_id)
219 .await
220 }
221 .boxed()
222 })
223 .await;
224
225 match result {
226 Ok(Ok(Ok(()))) => {
227 cx.update(|cx| {
228 state.update(cx, |state, _| {
229 state.is_authenticated = true;
230 });
231 })?;
232 Ok(())
233 }
234 Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
235 Ok(Err(e)) => Err(AuthenticateError::Other(e)),
236 Err(e) => Err(AuthenticateError::Other(e)),
237 }
238 })
239 }
240
241 fn configuration_view(
242 &self,
243 _target_agent: ConfigurationViewTargetAgent,
244 window: &mut Window,
245 cx: &mut App,
246 ) -> AnyView {
247 let credential_key = self.credential_key();
248 let extension = self.extension.clone();
249 let extension_provider_id = self.provider_info.id.clone();
250 let full_provider_id = self.provider_id_string();
251 let state = self.state.clone();
252 let auth_config = self.auth_config.clone();
253
254 cx.new(|cx| {
255 ExtensionProviderConfigurationView::new(
256 credential_key,
257 extension,
258 extension_provider_id,
259 full_provider_id,
260 auth_config,
261 state,
262 window,
263 cx,
264 )
265 })
266 .into()
267 }
268
269 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
270 let extension = self.extension.clone();
271 let provider_id = self.provider_info.id.clone();
272 let state = self.state.clone();
273 let credential_key = self.credential_key();
274
275 let credentials_provider = <dyn CredentialsProvider>::global(cx);
276
277 cx.spawn(async move |cx| {
278 // Delete from system keychain
279 credentials_provider
280 .delete_credentials(&credential_key, cx)
281 .await
282 .log_err();
283
284 // Call extension's reset_credentials
285 let result = extension
286 .call(|extension, store| {
287 async move {
288 extension
289 .call_llm_provider_reset_credentials(store, &provider_id)
290 .await
291 }
292 .boxed()
293 })
294 .await;
295
296 // Update state
297 cx.update(|cx| {
298 state.update(cx, |state, _| {
299 state.is_authenticated = false;
300 });
301 })?;
302
303 match result {
304 Ok(Ok(Ok(()))) => Ok(()),
305 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
306 Ok(Err(e)) => Err(e),
307 Err(e) => Err(e),
308 }
309 })
310 }
311}
312
313impl LanguageModelProviderState for ExtensionLanguageModelProvider {
314 type ObservableEntity = ExtensionLlmProviderState;
315
316 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
317 Some(self.state.clone())
318 }
319
320 fn subscribe<T: 'static>(
321 &self,
322 cx: &mut Context<T>,
323 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
324 ) -> Option<Subscription> {
325 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
326 }
327}
328
329/// Configuration view for extension-based LLM providers.
330struct ExtensionProviderConfigurationView {
331 credential_key: String,
332 extension: WasmExtension,
333 extension_provider_id: String,
334 full_provider_id: String,
335 auth_config: Option<LanguageModelAuthConfig>,
336 state: Entity<ExtensionLlmProviderState>,
337 settings_markdown: Option<Entity<Markdown>>,
338 api_key_editor: Entity<Editor>,
339 loading_settings: bool,
340 loading_credentials: bool,
341 oauth_in_progress: bool,
342 oauth_error: Option<String>,
343 device_user_code: Option<String>,
344 _subscriptions: Vec<Subscription>,
345}
346
347impl ExtensionProviderConfigurationView {
348 fn new(
349 credential_key: String,
350 extension: WasmExtension,
351 extension_provider_id: String,
352 full_provider_id: String,
353 auth_config: Option<LanguageModelAuthConfig>,
354 state: Entity<ExtensionLlmProviderState>,
355 window: &mut Window,
356 cx: &mut Context<Self>,
357 ) -> Self {
358 // Subscribe to state changes
359 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
360 cx.notify();
361 });
362
363 // Create API key editor
364 let api_key_editor = cx.new(|cx| {
365 let mut editor = Editor::single_line(window, cx);
366 editor.set_placeholder_text("Enter API key...", window, cx);
367 editor
368 });
369
370 let mut this = Self {
371 credential_key,
372 extension,
373 extension_provider_id,
374 full_provider_id,
375 auth_config,
376 state,
377 settings_markdown: None,
378 api_key_editor,
379 loading_settings: true,
380 loading_credentials: true,
381 oauth_in_progress: false,
382 oauth_error: None,
383 device_user_code: None,
384 _subscriptions: vec![state_subscription],
385 };
386
387 // Load settings text from extension
388 this.load_settings_text(cx);
389
390 // Load existing credentials
391 this.load_credentials(cx);
392
393 this
394 }
395
396 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
397 let extension = self.extension.clone();
398 let provider_id = self.extension_provider_id.clone();
399
400 cx.spawn(async move |this, cx| {
401 let result = extension
402 .call({
403 let provider_id = provider_id.clone();
404 |ext, store| {
405 async move {
406 ext.call_llm_provider_settings_markdown(store, &provider_id)
407 .await
408 }
409 .boxed()
410 }
411 })
412 .await;
413
414 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
415
416 this.update(cx, |this, cx| {
417 this.loading_settings = false;
418 if let Some(text) = settings_text {
419 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
420 this.settings_markdown = Some(markdown);
421 }
422 cx.notify();
423 })
424 .log_err();
425 })
426 .detach();
427 }
428
429 fn load_credentials(&mut self, cx: &mut Context<Self>) {
430 let credential_key = self.credential_key.clone();
431 let credentials_provider = <dyn CredentialsProvider>::global(cx);
432 let state = self.state.clone();
433
434 // Check if we should use env var (already set in state during provider construction)
435 let api_key_from_env = self.state.read(cx).api_key_from_env;
436
437 cx.spawn(async move |this, cx| {
438 // If using env var, we're already authenticated
439 if api_key_from_env {
440 this.update(cx, |this, cx| {
441 this.loading_credentials = false;
442 cx.notify();
443 })
444 .log_err();
445 return;
446 }
447
448 let credentials = credentials_provider
449 .read_credentials(&credential_key, cx)
450 .await
451 .log_err()
452 .flatten();
453
454 let has_credentials = credentials.is_some();
455
456 // Update authentication state based on stored credentials
457 let _ = cx.update(|cx| {
458 state.update(cx, |state, cx| {
459 state.is_authenticated = has_credentials;
460 cx.notify();
461 });
462 });
463
464 this.update(cx, |this, cx| {
465 this.loading_credentials = false;
466 cx.notify();
467 })
468 .log_err();
469 })
470 .detach();
471 }
472
473 fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
474 let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
475 let env_var_name = match &self.auth_config {
476 Some(config) => config.env_var.clone(),
477 None => return,
478 };
479
480 let state = self.state.clone();
481 let currently_allowed = self.state.read(cx).env_var_allowed;
482
483 // Update settings file
484 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
485 let providers = settings
486 .extension
487 .allowed_env_var_providers
488 .get_or_insert_with(Vec::new);
489
490 if currently_allowed {
491 providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
492 } else {
493 if !providers
494 .iter()
495 .any(|id| id.as_ref() == full_provider_id.as_ref())
496 {
497 providers.push(full_provider_id.clone());
498 }
499 }
500 });
501
502 // Update local state
503 let new_allowed = !currently_allowed;
504 let new_from_env = if new_allowed {
505 if let Some(var_name) = &env_var_name {
506 if let Ok(value) = std::env::var(var_name) {
507 !value.is_empty()
508 } else {
509 false
510 }
511 } else {
512 false
513 }
514 } else {
515 false
516 };
517
518 state.update(cx, |state, cx| {
519 state.env_var_allowed = new_allowed;
520 state.api_key_from_env = new_from_env;
521 if new_from_env {
522 state.is_authenticated = true;
523 }
524 cx.notify();
525 });
526
527 // If env var is being enabled, clear any stored keychain credentials
528 // so there's only one source of truth for the API key
529 if new_allowed {
530 let credential_key = self.credential_key.clone();
531 let credentials_provider = <dyn CredentialsProvider>::global(cx);
532 cx.spawn(async move |_this, cx| {
533 credentials_provider
534 .delete_credentials(&credential_key, cx)
535 .await
536 .log_err();
537 })
538 .detach();
539 }
540
541 // If env var is being disabled, reload credentials from keychain
542 if !new_allowed {
543 self.reload_keychain_credentials(cx);
544 }
545
546 cx.notify();
547 }
548
549 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
550 let credential_key = self.credential_key.clone();
551 let credentials_provider = <dyn CredentialsProvider>::global(cx);
552 let state = self.state.clone();
553
554 cx.spawn(async move |_this, cx| {
555 let credentials = credentials_provider
556 .read_credentials(&credential_key, cx)
557 .await
558 .log_err()
559 .flatten();
560
561 let has_credentials = credentials.is_some();
562
563 let _ = cx.update(|cx| {
564 state.update(cx, |state, cx| {
565 state.is_authenticated = has_credentials;
566 cx.notify();
567 });
568 });
569 })
570 .detach();
571 }
572
573 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
574 let api_key = self.api_key_editor.read(cx).text(cx);
575 if api_key.is_empty() {
576 return;
577 }
578
579 // Clear the editor
580 self.api_key_editor
581 .update(cx, |editor, cx| editor.set_text("", window, cx));
582
583 let credential_key = self.credential_key.clone();
584 let credentials_provider = <dyn CredentialsProvider>::global(cx);
585 let state = self.state.clone();
586
587 cx.spawn(async move |_this, cx| {
588 // Store in system keychain
589 credentials_provider
590 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
591 .await
592 .log_err();
593
594 // Update state to authenticated
595 let _ = cx.update(|cx| {
596 state.update(cx, |state, cx| {
597 state.is_authenticated = true;
598 cx.notify();
599 });
600 });
601 })
602 .detach();
603 }
604
605 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
606 // Clear the editor
607 self.api_key_editor
608 .update(cx, |editor, cx| editor.set_text("", window, cx));
609
610 let credential_key = self.credential_key.clone();
611 let credentials_provider = <dyn CredentialsProvider>::global(cx);
612 let state = self.state.clone();
613
614 cx.spawn(async move |_this, cx| {
615 // Delete from system keychain
616 credentials_provider
617 .delete_credentials(&credential_key, cx)
618 .await
619 .log_err();
620
621 // Update state to unauthenticated
622 let _ = cx.update(|cx| {
623 state.update(cx, |state, cx| {
624 state.is_authenticated = false;
625 cx.notify();
626 });
627 });
628 })
629 .detach();
630 }
631
632 fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
633 if self.oauth_in_progress {
634 return;
635 }
636
637 self.oauth_in_progress = true;
638 self.oauth_error = None;
639 self.device_user_code = None;
640 cx.notify();
641
642 let extension = self.extension.clone();
643 let provider_id = self.extension_provider_id.clone();
644 let state = self.state.clone();
645
646 cx.spawn(async move |this, cx| {
647 // Step 1: Start device flow - opens browser and returns user code
648 let start_result = extension
649 .call({
650 let provider_id = provider_id.clone();
651 |ext, store| {
652 async move {
653 ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
654 .await
655 }
656 .boxed()
657 }
658 })
659 .await;
660
661 let user_code = match start_result {
662 Ok(Ok(Ok(code))) => code,
663 Ok(Ok(Err(e))) => {
664 log::error!("Device flow start failed: {}", e);
665 this.update(cx, |this, cx| {
666 this.oauth_in_progress = false;
667 this.oauth_error = Some(e);
668 cx.notify();
669 })
670 .log_err();
671 return;
672 }
673 Ok(Err(e)) | Err(e) => {
674 log::error!("Device flow start error: {}", e);
675 this.update(cx, |this, cx| {
676 this.oauth_in_progress = false;
677 this.oauth_error = Some(e.to_string());
678 cx.notify();
679 })
680 .log_err();
681 return;
682 }
683 };
684
685 // Update UI to show the user code before polling
686 this.update(cx, |this, cx| {
687 this.device_user_code = Some(user_code);
688 cx.notify();
689 })
690 .log_err();
691
692 // Step 2: Poll for authentication completion
693 let poll_result = extension
694 .call({
695 let provider_id = provider_id.clone();
696 |ext, store| {
697 async move {
698 ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
699 .await
700 }
701 .boxed()
702 }
703 })
704 .await;
705
706 let error_message = match poll_result {
707 Ok(Ok(Ok(()))) => {
708 let _ = cx.update(|cx| {
709 state.update(cx, |state, cx| {
710 state.is_authenticated = true;
711 cx.notify();
712 });
713 });
714 None
715 }
716 Ok(Ok(Err(e))) => {
717 log::error!("Device flow poll failed: {}", e);
718 Some(e)
719 }
720 Ok(Err(e)) | Err(e) => {
721 log::error!("Device flow poll error: {}", e);
722 Some(e.to_string())
723 }
724 };
725
726 this.update(cx, |this, cx| {
727 this.oauth_in_progress = false;
728 this.oauth_error = error_message;
729 this.device_user_code = None;
730 cx.notify();
731 })
732 .log_err();
733 })
734 .detach();
735 }
736
737 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
738 self.state.read(cx).is_authenticated
739 }
740
741 fn has_oauth_config(&self) -> bool {
742 self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
743 }
744
745 fn oauth_config(&self) -> Option<&OAuthConfig> {
746 self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
747 }
748
749 fn has_api_key_config(&self) -> bool {
750 // API key is available if there's a credential_label or no oauth-only config
751 self.auth_config
752 .as_ref()
753 .map(|c| c.credential_label.is_some() || c.oauth.is_none())
754 .unwrap_or(true)
755 }
756}
757
impl gpui::Render for ExtensionProviderConfigurationView {
    /// Renders, top to bottom:
    /// - the extension's settings text;
    /// - the env-var opt-in checkbox and its status;
    /// - either the authenticated state (with a reset / sign-out button) or the available
    ///   sign-in options (OAuth button and/or API-key entry).
    ///
    /// While settings or credentials are still loading, only a "Loading..." label is shown.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_loading = self.loading_settings || self.loading_credentials;
        let is_authenticated = self.is_authenticated(cx);
        let env_var_allowed = self.state.read(cx).env_var_allowed;
        let api_key_from_env = self.state.read(cx).api_key_from_env;
        let has_oauth = self.has_oauth_config();
        let has_api_key = self.has_api_key_config();

        if is_loading {
            return v_flex()
                .gap_2()
                .child(Label::new("Loading...").color(Color::Muted))
                .into_any_element();
        }

        let mut content = v_flex().gap_4().size_full();

        // Render settings markdown if available
        if let Some(markdown) = &self.settings_markdown {
            let style = settings_markdown_style(_window, cx);
            content = content.child(
                div()
                    .p_2()
                    .rounded_md()
                    .bg(cx.theme().colors().surface_background)
                    .child(MarkdownElement::new(markdown.clone(), style)),
            );
        }

        // Render env var checkbox if the extension specifies an env var
        if let Some(auth_config) = &self.auth_config {
            if let Some(env_var_name) = &auth_config.env_var {
                let env_var_name = env_var_name.clone();
                let checkbox_label =
                    format!("Read API key from {} environment variable", env_var_name);

                content = content.child(
                    h_flex()
                        .gap_2()
                        .child(
                            ui::Checkbox::new("env-var-permission", env_var_allowed.into())
                                .on_click(cx.listener(|this, _, _window, cx| {
                                    this.toggle_env_var_permission(cx);
                                })),
                        )
                        .child(Label::new(checkbox_label).size(LabelSize::Small)),
                );

                // Show status if env var is allowed
                if env_var_allowed {
                    if api_key_from_env {
                        // The env var supplies the key: show only its status and stop here;
                        // no other authentication options are rendered.
                        content = content.child(
                            h_flex()
                                .gap_2()
                                .child(
                                    ui::Icon::new(ui::IconName::Check)
                                        .color(Color::Success)
                                        .size(ui::IconSize::Small),
                                )
                                .child(
                                    Label::new(format!("API key loaded from {}", env_var_name))
                                        .color(Color::Success),
                                ),
                        );
                        return content.into_any_element();
                    } else {
                        // Allowed but unset/empty: warn, then fall through to the other options.
                        content = content.child(
                            h_flex()
                                .gap_2()
                                .child(
                                    ui::Icon::new(ui::IconName::Warning)
                                        .color(Color::Warning)
                                        .size(ui::IconSize::Small),
                                )
                                .child(
                                    Label::new(format!(
                                        "{} is not set or empty. You can set it and restart Zed, or use another authentication method below.",
                                        env_var_name
                                    ))
                                    .color(Color::Warning)
                                    .size(LabelSize::Small),
                                ),
                        );
                    }
                }
            }
        }

        // If authenticated, show success state with sign out option
        if is_authenticated && !api_key_from_env {
            // OAuth-only providers use sign-in wording; everything else says "credentials".
            let reset_label = if has_oauth && !has_api_key {
                "Sign Out"
            } else {
                "Reset Credentials"
            };

            let status_label = if has_oauth && !has_api_key {
                "Signed in"
            } else {
                "Authenticated"
            };

            content = content.child(
                v_flex()
                    .gap_2()
                    .child(
                        h_flex()
                            .gap_2()
                            .child(
                                ui::Icon::new(ui::IconName::Check)
                                    .color(Color::Success)
                                    .size(ui::IconSize::Small),
                            )
                            .child(Label::new(status_label).color(Color::Success)),
                    )
                    .child(
                        ui::Button::new("reset-credentials", reset_label)
                            .style(ui::ButtonStyle::Subtle)
                            .on_click(cx.listener(|this, _, window, cx| {
                                this.reset_api_key(window, cx);
                            })),
                    ),
            );

            return content.into_any_element();
        }

        // Not authenticated - show available auth options
        if !api_key_from_env {
            // Render OAuth sign-in button if configured
            if has_oauth {
                let oauth_config = self.oauth_config();
                let button_label = oauth_config
                    .and_then(|c| c.sign_in_button_label.clone())
                    .unwrap_or_else(|| "Sign In".to_string());

                let oauth_in_progress = self.oauth_in_progress;

                let oauth_error = self.oauth_error.clone();

                content = content.child(
                    v_flex()
                        .gap_2()
                        .child(
                            // Disabled while a sign-in is running to prevent concurrent flows.
                            ui::Button::new("oauth-sign-in", button_label)
                                .style(ui::ButtonStyle::Filled)
                                .disabled(oauth_in_progress)
                                .on_click(cx.listener(|this, _, _window, cx| {
                                    this.start_oauth_sign_in(cx);
                                })),
                        )
                        .when(oauth_in_progress, |this| {
                            let user_code = self.device_user_code.clone();
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .when_some(user_code, |this, code| {
                                        // The icon switches to a check mark when the clipboard
                                        // already holds this exact code.
                                        let copied = cx
                                            .read_from_clipboard()
                                            .map(|item| item.text().as_ref() == Some(&code))
                                            .unwrap_or(false);
                                        let code_for_click = code.clone();
                                        this.child(
                                            h_flex()
                                                .gap_1()
                                                .child(
                                                    Label::new("Enter code:")
                                                        .size(LabelSize::Small)
                                                        .color(Color::Muted),
                                                )
                                                .child(
                                                    // Clicking the code copies it to the clipboard.
                                                    h_flex()
                                                        .gap_1()
                                                        .px_1()
                                                        .border_1()
                                                        .border_color(cx.theme().colors().border)
                                                        .rounded_sm()
                                                        .cursor_pointer()
                                                        .on_mouse_down(
                                                            MouseButton::Left,
                                                            move |_, window, cx| {
                                                                cx.write_to_clipboard(
                                                                    ClipboardItem::new_string(
                                                                        code_for_click.clone(),
                                                                    ),
                                                                );
                                                                // Re-render so the copy icon reflects
                                                                // the new clipboard contents.
                                                                window.refresh();
                                                            },
                                                        )
                                                        .child(
                                                            Label::new(code)
                                                                .size(LabelSize::Small)
                                                                .color(Color::Accent),
                                                        )
                                                        .child(
                                                            ui::Icon::new(if copied {
                                                                ui::IconName::Check
                                                            } else {
                                                                ui::IconName::Copy
                                                            })
                                                            .size(ui::IconSize::Small)
                                                            .color(if copied {
                                                                Color::Success
                                                            } else {
                                                                Color::Muted
                                                            }),
                                                        ),
                                                ),
                                        )
                                    })
                                    .child(
                                        Label::new("Waiting for authorization in browser...")
                                            .size(LabelSize::Small)
                                            .color(Color::Muted),
                                    ),
                            )
                        })
                        .when_some(oauth_error, |this, error| {
                            // Last sign-in failure, with the extension's message indented below.
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .child(
                                        h_flex()
                                            .gap_2()
                                            .child(
                                                ui::Icon::new(ui::IconName::Warning)
                                                    .color(Color::Error)
                                                    .size(ui::IconSize::Small),
                                            )
                                            .child(
                                                Label::new("Authentication failed")
                                                    .color(Color::Error)
                                                    .size(LabelSize::Small),
                                            ),
                                    )
                                    .child(
                                        div().pl_6().child(
                                            Label::new(error)
                                                .color(Color::Error)
                                                .size(LabelSize::Small),
                                        ),
                                    ),
                            )
                        }),
                );
            }

            // Render API key input if configured (and we have both options, show a separator)
            if has_api_key {
                if has_oauth {
                    content = content.child(
                        h_flex()
                            .gap_2()
                            .items_center()
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border))
                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
                    );
                }

                let credential_label = self
                    .auth_config
                    .as_ref()
                    .and_then(|c| c.credential_label.clone())
                    .unwrap_or_else(|| "API Key".to_string());

                content = content.child(
                    // `menu::Confirm` (pressing Enter) dispatched within this container is
                    // handled by `save_api_key`.
                    v_flex()
                        .gap_2()
                        .on_action(cx.listener(Self::save_api_key))
                        .child(
                            Label::new(credential_label)
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        )
                        .child(self.api_key_editor.clone())
                        .child(
                            Label::new("Enter your API key and press Enter to save")
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        ),
                );
            }
        }

        content.into_any_element()
    }
}
1047
1048impl Focusable for ExtensionProviderConfigurationView {
1049 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1050 self.api_key_editor.focus_handle(cx)
1051 }
1052}
1053
1054fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1055 let theme_settings = ThemeSettings::get_global(cx);
1056 let colors = cx.theme().colors();
1057 let mut text_style = window.text_style();
1058 text_style.refine(&TextStyleRefinement {
1059 font_family: Some(theme_settings.ui_font.family.clone()),
1060 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1061 font_features: Some(theme_settings.ui_font.features.clone()),
1062 color: Some(colors.text),
1063 ..Default::default()
1064 });
1065
1066 MarkdownStyle {
1067 base_text_style: text_style,
1068 selection_background_color: colors.element_selection_background,
1069 inline_code: TextStyleRefinement {
1070 background_color: Some(colors.editor_background),
1071 ..Default::default()
1072 },
1073 link: TextStyleRefinement {
1074 color: Some(colors.text_accent),
1075 underline: Some(UnderlineStyle {
1076 color: Some(colors.text_accent.opacity(0.5)),
1077 thickness: px(1.),
1078 ..Default::default()
1079 }),
1080 ..Default::default()
1081 },
1082 syntax: cx.theme().syntax().clone(),
1083 ..Default::default()
1084 }
1085}
1086
1087/// An extension-based language model.
1088pub struct ExtensionLanguageModel {
1089 extension: WasmExtension,
1090 model_info: LlmModelInfo,
1091 provider_id: LanguageModelProviderId,
1092 provider_name: LanguageModelProviderName,
1093 provider_info: LlmProviderInfo,
1094}
1095
1096impl LanguageModel for ExtensionLanguageModel {
1097 fn id(&self) -> LanguageModelId {
1098 LanguageModelId::from(self.model_info.id.clone())
1099 }
1100
1101 fn name(&self) -> LanguageModelName {
1102 LanguageModelName::from(self.model_info.name.clone())
1103 }
1104
1105 fn provider_id(&self) -> LanguageModelProviderId {
1106 self.provider_id.clone()
1107 }
1108
1109 fn provider_name(&self) -> LanguageModelProviderName {
1110 self.provider_name.clone()
1111 }
1112
1113 fn telemetry_id(&self) -> String {
1114 format!("extension-{}", self.model_info.id)
1115 }
1116
1117 fn supports_images(&self) -> bool {
1118 self.model_info.capabilities.supports_images
1119 }
1120
1121 fn supports_tools(&self) -> bool {
1122 self.model_info.capabilities.supports_tools
1123 }
1124
1125 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1126 match choice {
1127 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1128 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1129 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1130 }
1131 }
1132
1133 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1134 match self.model_info.capabilities.tool_input_format {
1135 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1136 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1137 }
1138 }
1139
1140 fn max_token_count(&self) -> u64 {
1141 self.model_info.max_token_count
1142 }
1143
1144 fn max_output_tokens(&self) -> Option<u64> {
1145 self.model_info.max_output_tokens
1146 }
1147
1148 fn count_tokens(
1149 &self,
1150 request: LanguageModelRequest,
1151 cx: &App,
1152 ) -> BoxFuture<'static, Result<u64>> {
1153 let extension = self.extension.clone();
1154 let provider_id = self.provider_info.id.clone();
1155 let model_id = self.model_info.id.clone();
1156
1157 let wit_request = convert_request_to_wit(request);
1158
1159 cx.background_spawn(async move {
1160 extension
1161 .call({
1162 let provider_id = provider_id.clone();
1163 let model_id = model_id.clone();
1164 let wit_request = wit_request.clone();
1165 |ext, store| {
1166 async move {
1167 let count = ext
1168 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1169 .await?
1170 .map_err(|e| anyhow!("{}", e))?;
1171 Ok(count)
1172 }
1173 .boxed()
1174 }
1175 })
1176 .await?
1177 })
1178 .boxed()
1179 }
1180
1181 fn stream_completion(
1182 &self,
1183 request: LanguageModelRequest,
1184 _cx: &AsyncApp,
1185 ) -> BoxFuture<
1186 'static,
1187 Result<
1188 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1189 LanguageModelCompletionError,
1190 >,
1191 > {
1192 let extension = self.extension.clone();
1193 let provider_id = self.provider_info.id.clone();
1194 let model_id = self.model_info.id.clone();
1195
1196 let wit_request = convert_request_to_wit(request);
1197
1198 async move {
1199 // Start the stream
1200 let stream_id_result = extension
1201 .call({
1202 let provider_id = provider_id.clone();
1203 let model_id = model_id.clone();
1204 let wit_request = wit_request.clone();
1205 |ext, store| {
1206 async move {
1207 let id = ext
1208 .call_llm_stream_completion_start(
1209 store,
1210 &provider_id,
1211 &model_id,
1212 &wit_request,
1213 )
1214 .await?
1215 .map_err(|e| anyhow!("{}", e))?;
1216 Ok(id)
1217 }
1218 .boxed()
1219 }
1220 })
1221 .await;
1222
1223 let stream_id = stream_id_result
1224 .map_err(LanguageModelCompletionError::Other)?
1225 .map_err(LanguageModelCompletionError::Other)?;
1226
1227 // Create a stream that polls for events
1228 let stream = futures::stream::unfold(
1229 (extension.clone(), stream_id, false),
1230 move |(extension, stream_id, done)| async move {
1231 if done {
1232 return None;
1233 }
1234
1235 let result = extension
1236 .call({
1237 let stream_id = stream_id.clone();
1238 |ext, store| {
1239 async move {
1240 let event = ext
1241 .call_llm_stream_completion_next(store, &stream_id)
1242 .await?
1243 .map_err(|e| anyhow!("{}", e))?;
1244 Ok(event)
1245 }
1246 .boxed()
1247 }
1248 })
1249 .await
1250 .and_then(|inner| inner);
1251
1252 match result {
1253 Ok(Some(event)) => {
1254 let converted = convert_completion_event(event);
1255 let is_done =
1256 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1257 Some((converted, (extension, stream_id, is_done)))
1258 }
1259 Ok(None) => {
1260 // Stream complete, close it
1261 let _ = extension
1262 .call({
1263 let stream_id = stream_id.clone();
1264 |ext, store| {
1265 async move {
1266 ext.call_llm_stream_completion_close(store, &stream_id)
1267 .await?;
1268 Ok::<(), anyhow::Error>(())
1269 }
1270 .boxed()
1271 }
1272 })
1273 .await;
1274 None
1275 }
1276 Err(e) => Some((
1277 Err(LanguageModelCompletionError::Other(e)),
1278 (extension, stream_id, true),
1279 )),
1280 }
1281 },
1282 );
1283
1284 Ok(stream.boxed())
1285 }
1286 .boxed()
1287 }
1288
1289 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1290 // Extensions can implement this via llm_cache_configuration
1291 None
1292 }
1293}
1294
1295fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1296 use language_model::{MessageContent, Role};
1297
1298 let messages: Vec<LlmRequestMessage> = request
1299 .messages
1300 .into_iter()
1301 .map(|msg| {
1302 let role = match msg.role {
1303 Role::User => LlmMessageRole::User,
1304 Role::Assistant => LlmMessageRole::Assistant,
1305 Role::System => LlmMessageRole::System,
1306 };
1307
1308 let content: Vec<LlmMessageContent> = msg
1309 .content
1310 .into_iter()
1311 .map(|c| match c {
1312 MessageContent::Text(text) => LlmMessageContent::Text(text),
1313 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1314 source: image.source.to_string(),
1315 width: Some(image.size.width.0 as u32),
1316 height: Some(image.size.height.0 as u32),
1317 }),
1318 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1319 id: tool_use.id.to_string(),
1320 name: tool_use.name.to_string(),
1321 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1322 thought_signature: tool_use.thought_signature,
1323 }),
1324 MessageContent::ToolResult(tool_result) => {
1325 let content = match tool_result.content {
1326 language_model::LanguageModelToolResultContent::Text(text) => {
1327 LlmToolResultContent::Text(text.to_string())
1328 }
1329 language_model::LanguageModelToolResultContent::Image(image) => {
1330 LlmToolResultContent::Image(LlmImageData {
1331 source: image.source.to_string(),
1332 width: Some(image.size.width.0 as u32),
1333 height: Some(image.size.height.0 as u32),
1334 })
1335 }
1336 };
1337 LlmMessageContent::ToolResult(LlmToolResult {
1338 tool_use_id: tool_result.tool_use_id.to_string(),
1339 tool_name: tool_result.tool_name.to_string(),
1340 is_error: tool_result.is_error,
1341 content,
1342 })
1343 }
1344 MessageContent::Thinking { text, signature } => {
1345 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1346 }
1347 MessageContent::RedactedThinking(data) => {
1348 LlmMessageContent::RedactedThinking(data)
1349 }
1350 })
1351 .collect();
1352
1353 LlmRequestMessage {
1354 role,
1355 content,
1356 cache: msg.cache,
1357 }
1358 })
1359 .collect();
1360
1361 let tools: Vec<LlmToolDefinition> = request
1362 .tools
1363 .into_iter()
1364 .map(|tool| LlmToolDefinition {
1365 name: tool.name,
1366 description: tool.description,
1367 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1368 })
1369 .collect();
1370
1371 let tool_choice = request.tool_choice.map(|tc| match tc {
1372 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1373 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1374 LanguageModelToolChoice::None => LlmToolChoice::None,
1375 });
1376
1377 LlmCompletionRequest {
1378 messages,
1379 tools,
1380 tool_choice,
1381 stop_sequences: request.stop,
1382 temperature: request.temperature,
1383 thinking_allowed: false,
1384 max_tokens: None,
1385 }
1386}
1387
1388fn convert_completion_event(
1389 event: LlmCompletionEvent,
1390) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1391 match event {
1392 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1393 message_id: String::new(),
1394 }),
1395 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1396 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1397 text: thinking.text,
1398 signature: thinking.signature,
1399 }),
1400 LlmCompletionEvent::RedactedThinking(data) => {
1401 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1402 }
1403 LlmCompletionEvent::ToolUse(tool_use) => {
1404 let raw_input = tool_use.input.clone();
1405 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1406 Ok(LanguageModelCompletionEvent::ToolUse(
1407 LanguageModelToolUse {
1408 id: LanguageModelToolUseId::from(tool_use.id),
1409 name: tool_use.name.into(),
1410 raw_input,
1411 input,
1412 is_input_complete: true,
1413 thought_signature: tool_use.thought_signature,
1414 },
1415 ))
1416 }
1417 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1418 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1419 id: LanguageModelToolUseId::from(error.id),
1420 tool_name: error.tool_name.into(),
1421 raw_input: error.raw_input.into(),
1422 json_parse_error: error.error,
1423 })
1424 }
1425 LlmCompletionEvent::Stop(reason) => {
1426 let stop_reason = match reason {
1427 LlmStopReason::EndTurn => StopReason::EndTurn,
1428 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1429 LlmStopReason::ToolUse => StopReason::ToolUse,
1430 LlmStopReason::Refusal => StopReason::Refusal,
1431 };
1432 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1433 }
1434 LlmCompletionEvent::Usage(usage) => {
1435 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1436 input_tokens: usage.input_tokens,
1437 output_tokens: usage.output_tokens,
1438 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1439 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1440 }))
1441 }
1442 LlmCompletionEvent::ReasoningDetails(json) => {
1443 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1444 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1445 ))
1446 }
1447 }
1448}