1use crate::ExtensionSettings;
2use crate::wasm_host::WasmExtension;
3
4use crate::wasm_host::wit::{
5 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
6 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
7 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
8 LlmToolUse,
9};
10use anyhow::{Result, anyhow};
11use credentials_provider::CredentialsProvider;
12use editor::Editor;
13use extension::{LanguageModelAuthConfig, OAuthConfig};
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, StreamExt};
17use gpui::Focusable;
18use gpui::{
19 AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
20 MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
21};
22use language_model::tool_schema::LanguageModelToolSchemaFormat;
23use language_model::{
24 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
25 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
27 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
28 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
29};
30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
31use settings::Settings;
32use std::sync::Arc;
33use theme::ThemeSettings;
34use ui::{Label, LabelSize, prelude::*};
35use util::ResultExt as _;
36
/// An extension-based language model provider.
///
/// Bundles the WASM extension that implements the provider together with the
/// metadata it advertised and an observable state entity.
pub struct ExtensionLanguageModelProvider {
    /// The WASM extension that provider calls are forwarded to.
    pub extension: WasmExtension,
    /// Provider metadata; its `id` and `name` feed the provider id and display name.
    pub provider_info: LlmProviderInfo,
    /// Optional custom icon path, returned as-is from `icon_path()`.
    icon_path: Option<SharedString>,
    /// Optional authentication configuration (env var, OAuth, credential label).
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared state holding the authentication flags and the model list.
    state: Entity<ExtensionLlmProviderState>,
}
45
/// Mutable state shared between the provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether the provider is currently considered authenticated.
    is_authenticated: bool,
    /// Models advertised for this provider.
    available_models: Vec<LlmModelInfo>,
    /// Whether the provider id is listed in the `allowed_env_var_providers` setting.
    env_var_allowed: bool,
    /// Whether a non-empty API key was found in the configured environment variable.
    api_key_from_env: bool,
}
52
// The state emits unit events, which lets other entities `cx.subscribe` to it.
impl EventEmitter<()> for ExtensionLlmProviderState {}
54
55impl ExtensionLanguageModelProvider {
56 pub fn new(
57 extension: WasmExtension,
58 provider_info: LlmProviderInfo,
59 models: Vec<LlmModelInfo>,
60 is_authenticated: bool,
61 icon_path: Option<SharedString>,
62 auth_config: Option<LanguageModelAuthConfig>,
63 cx: &mut App,
64 ) -> Self {
65 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
66 let env_var_allowed = ExtensionSettings::get_global(cx)
67 .allowed_env_var_providers
68 .contains(provider_id_string.as_str());
69
70 let (is_authenticated, api_key_from_env) =
71 if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
72 let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
73 if let Ok(value) = std::env::var(env_var_name) {
74 if !value.is_empty() {
75 (true, true)
76 } else {
77 (is_authenticated, false)
78 }
79 } else {
80 (is_authenticated, false)
81 }
82 } else {
83 (is_authenticated, false)
84 };
85
86 let state = cx.new(|_| ExtensionLlmProviderState {
87 is_authenticated,
88 available_models: models,
89 env_var_allowed,
90 api_key_from_env,
91 });
92
93 Self {
94 extension,
95 provider_info,
96 icon_path,
97 auth_config,
98 state,
99 }
100 }
101
102 fn provider_id_string(&self) -> String {
103 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
104 }
105
106 /// The credential key used for storing the API key in the system keychain.
107 fn credential_key(&self) -> String {
108 format!("extension-llm-{}", self.provider_id_string())
109 }
110}
111
112impl LanguageModelProvider for ExtensionLanguageModelProvider {
113 fn id(&self) -> LanguageModelProviderId {
114 LanguageModelProviderId::from(self.provider_id_string())
115 }
116
117 fn name(&self) -> LanguageModelProviderName {
118 LanguageModelProviderName::from(self.provider_info.name.clone())
119 }
120
121 fn icon(&self) -> ui::IconName {
122 ui::IconName::ZedAssistant
123 }
124
125 fn icon_path(&self) -> Option<SharedString> {
126 self.icon_path.clone()
127 }
128
129 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
130 let state = self.state.read(cx);
131 state
132 .available_models
133 .iter()
134 .find(|m| m.is_default)
135 .or_else(|| state.available_models.first())
136 .map(|model_info| {
137 Arc::new(ExtensionLanguageModel {
138 extension: self.extension.clone(),
139 model_info: model_info.clone(),
140 provider_id: self.id(),
141 provider_name: self.name(),
142 provider_info: self.provider_info.clone(),
143 }) as Arc<dyn LanguageModel>
144 })
145 }
146
147 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
148 let state = self.state.read(cx);
149 state
150 .available_models
151 .iter()
152 .find(|m| m.is_default_fast)
153 .map(|model_info| {
154 Arc::new(ExtensionLanguageModel {
155 extension: self.extension.clone(),
156 model_info: model_info.clone(),
157 provider_id: self.id(),
158 provider_name: self.name(),
159 provider_info: self.provider_info.clone(),
160 }) as Arc<dyn LanguageModel>
161 })
162 }
163
164 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165 let state = self.state.read(cx);
166 state
167 .available_models
168 .iter()
169 .map(|model_info| {
170 Arc::new(ExtensionLanguageModel {
171 extension: self.extension.clone(),
172 model_info: model_info.clone(),
173 provider_id: self.id(),
174 provider_name: self.name(),
175 provider_info: self.provider_info.clone(),
176 }) as Arc<dyn LanguageModel>
177 })
178 .collect()
179 }
180
181 fn is_authenticated(&self, cx: &App) -> bool {
182 self.state.read(cx).is_authenticated
183 }
184
185 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
186 let extension = self.extension.clone();
187 let provider_id = self.provider_info.id.clone();
188 let state = self.state.clone();
189
190 cx.spawn(async move |cx| {
191 let result = extension
192 .call(|extension, store| {
193 async move {
194 extension
195 .call_llm_provider_authenticate(store, &provider_id)
196 .await
197 }
198 .boxed()
199 })
200 .await;
201
202 match result {
203 Ok(Ok(Ok(()))) => {
204 cx.update(|cx| {
205 state.update(cx, |state, _| {
206 state.is_authenticated = true;
207 });
208 })?;
209 Ok(())
210 }
211 Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
212 Ok(Err(e)) => Err(AuthenticateError::Other(e)),
213 Err(e) => Err(AuthenticateError::Other(e)),
214 }
215 })
216 }
217
218 fn configuration_view(
219 &self,
220 _target_agent: ConfigurationViewTargetAgent,
221 window: &mut Window,
222 cx: &mut App,
223 ) -> AnyView {
224 let credential_key = self.credential_key();
225 let extension = self.extension.clone();
226 let extension_provider_id = self.provider_info.id.clone();
227 let full_provider_id = self.provider_id_string();
228 let state = self.state.clone();
229 let auth_config = self.auth_config.clone();
230
231 cx.new(|cx| {
232 ExtensionProviderConfigurationView::new(
233 credential_key,
234 extension,
235 extension_provider_id,
236 full_provider_id,
237 auth_config,
238 state,
239 window,
240 cx,
241 )
242 })
243 .into()
244 }
245
246 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
247 let extension = self.extension.clone();
248 let provider_id = self.provider_info.id.clone();
249 let state = self.state.clone();
250 let credential_key = self.credential_key();
251
252 let credentials_provider = <dyn CredentialsProvider>::global(cx);
253
254 cx.spawn(async move |cx| {
255 // Delete from system keychain
256 credentials_provider
257 .delete_credentials(&credential_key, cx)
258 .await
259 .log_err();
260
261 // Call extension's reset_credentials
262 let result = extension
263 .call(|extension, store| {
264 async move {
265 extension
266 .call_llm_provider_reset_credentials(store, &provider_id)
267 .await
268 }
269 .boxed()
270 })
271 .await;
272
273 // Update state
274 cx.update(|cx| {
275 state.update(cx, |state, _| {
276 state.is_authenticated = false;
277 });
278 })?;
279
280 match result {
281 Ok(Ok(Ok(()))) => Ok(()),
282 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
283 Ok(Err(e)) => Err(e),
284 Err(e) => Err(e),
285 }
286 })
287 }
288}
289
290impl LanguageModelProviderState for ExtensionLanguageModelProvider {
291 type ObservableEntity = ExtensionLlmProviderState;
292
293 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
294 Some(self.state.clone())
295 }
296
297 fn subscribe<T: 'static>(
298 &self,
299 cx: &mut Context<T>,
300 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
301 ) -> Option<Subscription> {
302 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
303 }
304}
305
/// Configuration view for extension-based LLM providers.
struct ExtensionProviderConfigurationView {
    /// Key under which the API key is stored in the credentials provider.
    credential_key: String,
    /// Extension that implements the provider.
    extension: WasmExtension,
    /// Provider id as known to the extension (passed to extension calls).
    extension_provider_id: String,
    /// Fully qualified provider id, used for the `allowed_env_var_providers` setting.
    full_provider_id: String,
    /// Optional authentication configuration (env var, OAuth, credential label).
    auth_config: Option<LanguageModelAuthConfig>,
    /// Provider state shared with `ExtensionLanguageModelProvider`.
    state: Entity<ExtensionLlmProviderState>,
    /// Settings markdown supplied by the extension, once loaded.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor used to enter the API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings markdown request has finished.
    loading_settings: bool,
    /// True until the initial credential lookup has finished.
    loading_credentials: bool,
    /// True while a device-flow sign-in is running.
    oauth_in_progress: bool,
    /// Error message from the most recent sign-in attempt, if it failed.
    oauth_error: Option<String>,
    /// User code returned by the device flow, shown while waiting for authorization.
    device_user_code: Option<String>,
    /// Keeps the state subscription alive for the lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
323
324impl ExtensionProviderConfigurationView {
325 fn new(
326 credential_key: String,
327 extension: WasmExtension,
328 extension_provider_id: String,
329 full_provider_id: String,
330 auth_config: Option<LanguageModelAuthConfig>,
331 state: Entity<ExtensionLlmProviderState>,
332 window: &mut Window,
333 cx: &mut Context<Self>,
334 ) -> Self {
335 // Subscribe to state changes
336 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
337 cx.notify();
338 });
339
340 // Create API key editor
341 let api_key_editor = cx.new(|cx| {
342 let mut editor = Editor::single_line(window, cx);
343 editor.set_placeholder_text("Enter API key...", window, cx);
344 editor
345 });
346
347 let mut this = Self {
348 credential_key,
349 extension,
350 extension_provider_id,
351 full_provider_id,
352 auth_config,
353 state,
354 settings_markdown: None,
355 api_key_editor,
356 loading_settings: true,
357 loading_credentials: true,
358 oauth_in_progress: false,
359 oauth_error: None,
360 device_user_code: None,
361 _subscriptions: vec![state_subscription],
362 };
363
364 // Load settings text from extension
365 this.load_settings_text(cx);
366
367 // Load existing credentials
368 this.load_credentials(cx);
369
370 this
371 }
372
373 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
374 let extension = self.extension.clone();
375 let provider_id = self.extension_provider_id.clone();
376
377 cx.spawn(async move |this, cx| {
378 let result = extension
379 .call({
380 let provider_id = provider_id.clone();
381 |ext, store| {
382 async move {
383 ext.call_llm_provider_settings_markdown(store, &provider_id)
384 .await
385 }
386 .boxed()
387 }
388 })
389 .await;
390
391 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
392
393 this.update(cx, |this, cx| {
394 this.loading_settings = false;
395 if let Some(text) = settings_text {
396 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
397 this.settings_markdown = Some(markdown);
398 }
399 cx.notify();
400 })
401 .log_err();
402 })
403 .detach();
404 }
405
406 fn load_credentials(&mut self, cx: &mut Context<Self>) {
407 let credential_key = self.credential_key.clone();
408 let credentials_provider = <dyn CredentialsProvider>::global(cx);
409 let state = self.state.clone();
410
411 // Check if we should use env var (already set in state during provider construction)
412 let api_key_from_env = self.state.read(cx).api_key_from_env;
413
414 cx.spawn(async move |this, cx| {
415 // If using env var, we're already authenticated
416 if api_key_from_env {
417 this.update(cx, |this, cx| {
418 this.loading_credentials = false;
419 cx.notify();
420 })
421 .log_err();
422 return;
423 }
424
425 let credentials = credentials_provider
426 .read_credentials(&credential_key, cx)
427 .await
428 .log_err()
429 .flatten();
430
431 let has_credentials = credentials.is_some();
432
433 // Update authentication state based on stored credentials
434 let _ = cx.update(|cx| {
435 state.update(cx, |state, cx| {
436 state.is_authenticated = has_credentials;
437 cx.notify();
438 });
439 });
440
441 this.update(cx, |this, cx| {
442 this.loading_credentials = false;
443 cx.notify();
444 })
445 .log_err();
446 })
447 .detach();
448 }
449
450 fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
451 let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
452 let env_var_name = match &self.auth_config {
453 Some(config) => config.env_var.clone(),
454 None => return,
455 };
456
457 let state = self.state.clone();
458 let currently_allowed = self.state.read(cx).env_var_allowed;
459
460 // Update settings file
461 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
462 let providers = settings
463 .extension
464 .allowed_env_var_providers
465 .get_or_insert_with(Vec::new);
466
467 if currently_allowed {
468 providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
469 } else {
470 if !providers
471 .iter()
472 .any(|id| id.as_ref() == full_provider_id.as_ref())
473 {
474 providers.push(full_provider_id.clone());
475 }
476 }
477 });
478
479 // Update local state
480 let new_allowed = !currently_allowed;
481 let new_from_env = if new_allowed {
482 if let Some(var_name) = &env_var_name {
483 if let Ok(value) = std::env::var(var_name) {
484 !value.is_empty()
485 } else {
486 false
487 }
488 } else {
489 false
490 }
491 } else {
492 false
493 };
494
495 state.update(cx, |state, cx| {
496 state.env_var_allowed = new_allowed;
497 state.api_key_from_env = new_from_env;
498 if new_from_env {
499 state.is_authenticated = true;
500 }
501 cx.notify();
502 });
503
504 // If env var is being enabled, clear any stored keychain credentials
505 // so there's only one source of truth for the API key
506 if new_allowed {
507 let credential_key = self.credential_key.clone();
508 let credentials_provider = <dyn CredentialsProvider>::global(cx);
509 cx.spawn(async move |_this, cx| {
510 credentials_provider
511 .delete_credentials(&credential_key, cx)
512 .await
513 .log_err();
514 })
515 .detach();
516 }
517
518 // If env var is being disabled, reload credentials from keychain
519 if !new_allowed {
520 self.reload_keychain_credentials(cx);
521 }
522
523 cx.notify();
524 }
525
526 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
527 let credential_key = self.credential_key.clone();
528 let credentials_provider = <dyn CredentialsProvider>::global(cx);
529 let state = self.state.clone();
530
531 cx.spawn(async move |_this, cx| {
532 let credentials = credentials_provider
533 .read_credentials(&credential_key, cx)
534 .await
535 .log_err()
536 .flatten();
537
538 let has_credentials = credentials.is_some();
539
540 let _ = cx.update(|cx| {
541 state.update(cx, |state, cx| {
542 state.is_authenticated = has_credentials;
543 cx.notify();
544 });
545 });
546 })
547 .detach();
548 }
549
550 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
551 let api_key = self.api_key_editor.read(cx).text(cx);
552 if api_key.is_empty() {
553 return;
554 }
555
556 // Clear the editor
557 self.api_key_editor
558 .update(cx, |editor, cx| editor.set_text("", window, cx));
559
560 let credential_key = self.credential_key.clone();
561 let credentials_provider = <dyn CredentialsProvider>::global(cx);
562 let state = self.state.clone();
563
564 cx.spawn(async move |_this, cx| {
565 // Store in system keychain
566 credentials_provider
567 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
568 .await
569 .log_err();
570
571 // Update state to authenticated
572 let _ = cx.update(|cx| {
573 state.update(cx, |state, cx| {
574 state.is_authenticated = true;
575 cx.notify();
576 });
577 });
578 })
579 .detach();
580 }
581
582 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
583 // Clear the editor
584 self.api_key_editor
585 .update(cx, |editor, cx| editor.set_text("", window, cx));
586
587 let credential_key = self.credential_key.clone();
588 let credentials_provider = <dyn CredentialsProvider>::global(cx);
589 let state = self.state.clone();
590
591 cx.spawn(async move |_this, cx| {
592 // Delete from system keychain
593 credentials_provider
594 .delete_credentials(&credential_key, cx)
595 .await
596 .log_err();
597
598 // Update state to unauthenticated
599 let _ = cx.update(|cx| {
600 state.update(cx, |state, cx| {
601 state.is_authenticated = false;
602 cx.notify();
603 });
604 });
605 })
606 .detach();
607 }
608
609 fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
610 if self.oauth_in_progress {
611 return;
612 }
613
614 self.oauth_in_progress = true;
615 self.oauth_error = None;
616 self.device_user_code = None;
617 cx.notify();
618
619 let extension = self.extension.clone();
620 let provider_id = self.extension_provider_id.clone();
621 let state = self.state.clone();
622
623 cx.spawn(async move |this, cx| {
624 // Step 1: Start device flow - opens browser and returns user code
625 let start_result = extension
626 .call({
627 let provider_id = provider_id.clone();
628 |ext, store| {
629 async move {
630 ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
631 .await
632 }
633 .boxed()
634 }
635 })
636 .await;
637
638 let user_code = match start_result {
639 Ok(Ok(Ok(code))) => code,
640 Ok(Ok(Err(e))) => {
641 log::error!("Device flow start failed: {}", e);
642 this.update(cx, |this, cx| {
643 this.oauth_in_progress = false;
644 this.oauth_error = Some(e);
645 cx.notify();
646 })
647 .log_err();
648 return;
649 }
650 Ok(Err(e)) | Err(e) => {
651 log::error!("Device flow start error: {}", e);
652 this.update(cx, |this, cx| {
653 this.oauth_in_progress = false;
654 this.oauth_error = Some(e.to_string());
655 cx.notify();
656 })
657 .log_err();
658 return;
659 }
660 };
661
662 // Update UI to show the user code before polling
663 this.update(cx, |this, cx| {
664 this.device_user_code = Some(user_code);
665 cx.notify();
666 })
667 .log_err();
668
669 // Step 2: Poll for authentication completion
670 let poll_result = extension
671 .call({
672 let provider_id = provider_id.clone();
673 |ext, store| {
674 async move {
675 ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
676 .await
677 }
678 .boxed()
679 }
680 })
681 .await;
682
683 let error_message = match poll_result {
684 Ok(Ok(Ok(()))) => {
685 let _ = cx.update(|cx| {
686 state.update(cx, |state, cx| {
687 state.is_authenticated = true;
688 cx.notify();
689 });
690 });
691 None
692 }
693 Ok(Ok(Err(e))) => {
694 log::error!("Device flow poll failed: {}", e);
695 Some(e)
696 }
697 Ok(Err(e)) | Err(e) => {
698 log::error!("Device flow poll error: {}", e);
699 Some(e.to_string())
700 }
701 };
702
703 this.update(cx, |this, cx| {
704 this.oauth_in_progress = false;
705 this.oauth_error = error_message;
706 this.device_user_code = None;
707 cx.notify();
708 })
709 .log_err();
710 })
711 .detach();
712 }
713
714 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
715 self.state.read(cx).is_authenticated
716 }
717
718 fn has_oauth_config(&self) -> bool {
719 self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
720 }
721
722 fn oauth_config(&self) -> Option<&OAuthConfig> {
723 self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
724 }
725
726 fn has_api_key_config(&self) -> bool {
727 // API key is available if there's a credential_label or no oauth-only config
728 self.auth_config
729 .as_ref()
730 .map(|c| c.credential_label.is_some() || c.oauth.is_none())
731 .unwrap_or(true)
732 }
733}
734
impl gpui::Render for ExtensionProviderConfigurationView {
    /// Renders the configuration UI.
    ///
    /// Order of sections: a loading placeholder while either load is pending;
    /// otherwise the settings markdown (if any), the env-var checkbox and its
    /// status, then either the authenticated state with a reset button or the
    /// available sign-in options (OAuth button and/or API key input).
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_loading = self.loading_settings || self.loading_credentials;
        let is_authenticated = self.is_authenticated(cx);
        let env_var_allowed = self.state.read(cx).env_var_allowed;
        let api_key_from_env = self.state.read(cx).api_key_from_env;
        let has_oauth = self.has_oauth_config();
        let has_api_key = self.has_api_key_config();

        // Nothing else is shown until both the settings text and the credentials have loaded.
        if is_loading {
            return v_flex()
                .gap_2()
                .child(Label::new("Loading...").color(Color::Muted))
                .into_any_element();
        }

        let mut content = v_flex().gap_4().size_full();

        // Render settings markdown if available
        if let Some(markdown) = &self.settings_markdown {
            // Note: `_window` is used here despite its underscore prefix.
            let style = settings_markdown_style(_window, cx);
            content = content.child(
                div()
                    .p_2()
                    .rounded_md()
                    .bg(cx.theme().colors().surface_background)
                    .child(MarkdownElement::new(markdown.clone(), style)),
            );
        }

        // Render env var checkbox if the extension specifies an env var
        if let Some(auth_config) = &self.auth_config {
            if let Some(env_var_name) = &auth_config.env_var {
                let env_var_name = env_var_name.clone();
                let checkbox_label =
                    format!("Read API key from {} environment variable", env_var_name);

                content = content.child(
                    h_flex()
                        .gap_2()
                        .child(
                            ui::Checkbox::new("env-var-permission", env_var_allowed.into())
                                .on_click(cx.listener(|this, _, _window, cx| {
                                    this.toggle_env_var_permission(cx);
                                })),
                        )
                        .child(Label::new(checkbox_label).size(LabelSize::Small)),
                );

                // Show status if env var is allowed
                if env_var_allowed {
                    if api_key_from_env {
                        content = content.child(
                            h_flex()
                                .gap_2()
                                .child(
                                    ui::Icon::new(ui::IconName::Check)
                                        .color(Color::Success)
                                        .size(ui::IconSize::Small),
                                )
                                .child(
                                    Label::new(format!("API key loaded from {}", env_var_name))
                                        .color(Color::Success),
                                ),
                        );
                        // The env key is in use: no other auth options are rendered.
                        return content.into_any_element();
                    } else {
                        content = content.child(
                            h_flex()
                                .gap_2()
                                .child(
                                    ui::Icon::new(ui::IconName::Warning)
                                        .color(Color::Warning)
                                        .size(ui::IconSize::Small),
                                )
                                .child(
                                    Label::new(format!(
                                        "{} is not set or empty. You can set it and restart Zed, or use another authentication method below.",
                                        env_var_name
                                    ))
                                    .color(Color::Warning)
                                    .size(LabelSize::Small),
                                ),
                        );
                    }
                }
            }
        }

        // If authenticated, show success state with sign out option
        if is_authenticated && !api_key_from_env {
            // OAuth-only providers get "sign in/out" wording; everything else uses
            // credential wording.
            let reset_label = if has_oauth && !has_api_key {
                "Sign Out"
            } else {
                "Reset Credentials"
            };

            let status_label = if has_oauth && !has_api_key {
                "Signed in"
            } else {
                "Authenticated"
            };

            content = content.child(
                v_flex()
                    .gap_2()
                    .child(
                        h_flex()
                            .gap_2()
                            .child(
                                ui::Icon::new(ui::IconName::Check)
                                    .color(Color::Success)
                                    .size(ui::IconSize::Small),
                            )
                            .child(Label::new(status_label).color(Color::Success)),
                    )
                    .child(
                        ui::Button::new("reset-credentials", reset_label)
                            .style(ui::ButtonStyle::Subtle)
                            .on_click(cx.listener(|this, _, window, cx| {
                                this.reset_api_key(window, cx);
                            })),
                    ),
            );

            return content.into_any_element();
        }

        // Not authenticated - show available auth options
        if !api_key_from_env {
            // Render OAuth sign-in button if configured
            if has_oauth {
                let oauth_config = self.oauth_config();
                let button_label = oauth_config
                    .and_then(|c| c.sign_in_button_label.clone())
                    .unwrap_or_else(|| "Sign In".to_string());

                let oauth_in_progress = self.oauth_in_progress;

                let oauth_error = self.oauth_error.clone();

                content = content.child(
                    v_flex()
                        .gap_2()
                        .child(
                            ui::Button::new("oauth-sign-in", button_label)
                                .style(ui::ButtonStyle::Filled)
                                .disabled(oauth_in_progress)
                                .on_click(cx.listener(|this, _, _window, cx| {
                                    this.start_oauth_sign_in(cx);
                                })),
                        )
                        // While the device flow runs, show the user code (once known) and a waiting hint.
                        .when(oauth_in_progress, |this| {
                            let user_code = self.device_user_code.clone();
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .when_some(user_code, |this, code| {
                                        // The icon switches to a check mark when the clipboard
                                        // currently holds exactly this code.
                                        let copied = cx
                                            .read_from_clipboard()
                                            .map(|item| item.text().as_ref() == Some(&code))
                                            .unwrap_or(false);
                                        let code_for_click = code.clone();
                                        this.child(
                                            h_flex()
                                                .gap_1()
                                                .child(
                                                    Label::new("Enter code:")
                                                        .size(LabelSize::Small)
                                                        .color(Color::Muted),
                                                )
                                                .child(
                                                    h_flex()
                                                        .gap_1()
                                                        .px_1()
                                                        .border_1()
                                                        .border_color(cx.theme().colors().border)
                                                        .rounded_sm()
                                                        .cursor_pointer()
                                                        // Clicking the code copies it to the clipboard.
                                                        .on_mouse_down(
                                                            MouseButton::Left,
                                                            move |_, window, cx| {
                                                                cx.write_to_clipboard(
                                                                    ClipboardItem::new_string(
                                                                        code_for_click.clone(),
                                                                    ),
                                                                );
                                                                window.refresh();
                                                            },
                                                        )
                                                        .child(
                                                            Label::new(code)
                                                                .size(LabelSize::Small)
                                                                .color(Color::Accent),
                                                        )
                                                        .child(
                                                            ui::Icon::new(if copied {
                                                                ui::IconName::Check
                                                            } else {
                                                                ui::IconName::Copy
                                                            })
                                                            .size(ui::IconSize::Small)
                                                            .color(if copied {
                                                                Color::Success
                                                            } else {
                                                                Color::Muted
                                                            }),
                                                        ),
                                                ),
                                        )
                                    })
                                    .child(
                                        Label::new("Waiting for authorization in browser...")
                                            .size(LabelSize::Small)
                                            .color(Color::Muted),
                                    ),
                            )
                        })
                        // Error from the most recent sign-in attempt, if any.
                        .when_some(oauth_error, |this, error| {
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .child(
                                        h_flex()
                                            .gap_2()
                                            .child(
                                                ui::Icon::new(ui::IconName::Warning)
                                                    .color(Color::Error)
                                                    .size(ui::IconSize::Small),
                                            )
                                            .child(
                                                Label::new("Authentication failed")
                                                    .color(Color::Error)
                                                    .size(LabelSize::Small),
                                            ),
                                    )
                                    .child(
                                        div().pl_6().child(
                                            Label::new(error)
                                                .color(Color::Error)
                                                .size(LabelSize::Small),
                                        ),
                                    ),
                            )
                        }),
                );
            }

            // Render API key input if configured (and we have both options, show a separator)
            if has_api_key {
                if has_oauth {
                    content = content.child(
                        h_flex()
                            .gap_2()
                            .items_center()
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border))
                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
                    );
                }

                // Extension-provided label for the credential, defaulting to "API Key".
                let credential_label = self
                    .auth_config
                    .as_ref()
                    .and_then(|c| c.credential_label.clone())
                    .unwrap_or_else(|| "API Key".to_string());

                content = content.child(
                    v_flex()
                        .gap_2()
                        // The `menu::Confirm` action saves the key.
                        .on_action(cx.listener(Self::save_api_key))
                        .child(
                            Label::new(credential_label)
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        )
                        .child(self.api_key_editor.clone())
                        .child(
                            Label::new("Enter your API key and press Enter to save")
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        ),
                );
            }
        }

        content.into_any_element()
    }
}
1024
1025impl Focusable for ExtensionProviderConfigurationView {
1026 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1027 self.api_key_editor.focus_handle(cx)
1028 }
1029}
1030
1031fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1032 let theme_settings = ThemeSettings::get_global(cx);
1033 let colors = cx.theme().colors();
1034 let mut text_style = window.text_style();
1035 text_style.refine(&TextStyleRefinement {
1036 font_family: Some(theme_settings.ui_font.family.clone()),
1037 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1038 font_features: Some(theme_settings.ui_font.features.clone()),
1039 color: Some(colors.text),
1040 ..Default::default()
1041 });
1042
1043 MarkdownStyle {
1044 base_text_style: text_style,
1045 selection_background_color: colors.element_selection_background,
1046 inline_code: TextStyleRefinement {
1047 background_color: Some(colors.editor_background),
1048 ..Default::default()
1049 },
1050 link: TextStyleRefinement {
1051 color: Some(colors.text_accent),
1052 underline: Some(UnderlineStyle {
1053 color: Some(colors.text_accent.opacity(0.5)),
1054 thickness: px(1.),
1055 ..Default::default()
1056 }),
1057 ..Default::default()
1058 },
1059 syntax: cx.theme().syntax().clone(),
1060 ..Default::default()
1061 }
1062}
1063
/// An extension-based language model.
pub struct ExtensionLanguageModel {
    /// Extension that token counting and completion calls are forwarded to.
    extension: WasmExtension,
    /// Model metadata: id, name, capabilities and token limits.
    model_info: LlmModelInfo,
    /// Id of the owning provider, as reported by `provider_id()`.
    provider_id: LanguageModelProviderId,
    /// Display name of the owning provider, as reported by `provider_name()`.
    provider_name: LanguageModelProviderName,
    /// Provider metadata; its `id` is passed to extension calls.
    provider_info: LlmProviderInfo,
}
1072
1073impl LanguageModel for ExtensionLanguageModel {
1074 fn id(&self) -> LanguageModelId {
1075 LanguageModelId::from(self.model_info.id.clone())
1076 }
1077
1078 fn name(&self) -> LanguageModelName {
1079 LanguageModelName::from(self.model_info.name.clone())
1080 }
1081
1082 fn provider_id(&self) -> LanguageModelProviderId {
1083 self.provider_id.clone()
1084 }
1085
1086 fn provider_name(&self) -> LanguageModelProviderName {
1087 self.provider_name.clone()
1088 }
1089
1090 fn telemetry_id(&self) -> String {
1091 format!("extension-{}", self.model_info.id)
1092 }
1093
1094 fn supports_images(&self) -> bool {
1095 self.model_info.capabilities.supports_images
1096 }
1097
1098 fn supports_tools(&self) -> bool {
1099 self.model_info.capabilities.supports_tools
1100 }
1101
1102 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1103 match choice {
1104 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1105 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1106 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1107 }
1108 }
1109
1110 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1111 match self.model_info.capabilities.tool_input_format {
1112 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1113 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1114 }
1115 }
1116
1117 fn max_token_count(&self) -> u64 {
1118 self.model_info.max_token_count
1119 }
1120
1121 fn max_output_tokens(&self) -> Option<u64> {
1122 self.model_info.max_output_tokens
1123 }
1124
1125 fn count_tokens(
1126 &self,
1127 request: LanguageModelRequest,
1128 cx: &App,
1129 ) -> BoxFuture<'static, Result<u64>> {
1130 let extension = self.extension.clone();
1131 let provider_id = self.provider_info.id.clone();
1132 let model_id = self.model_info.id.clone();
1133
1134 let wit_request = convert_request_to_wit(request);
1135
1136 cx.background_spawn(async move {
1137 extension
1138 .call({
1139 let provider_id = provider_id.clone();
1140 let model_id = model_id.clone();
1141 let wit_request = wit_request.clone();
1142 |ext, store| {
1143 async move {
1144 let count = ext
1145 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1146 .await?
1147 .map_err(|e| anyhow!("{}", e))?;
1148 Ok(count)
1149 }
1150 .boxed()
1151 }
1152 })
1153 .await?
1154 })
1155 .boxed()
1156 }
1157
1158 fn stream_completion(
1159 &self,
1160 request: LanguageModelRequest,
1161 _cx: &AsyncApp,
1162 ) -> BoxFuture<
1163 'static,
1164 Result<
1165 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1166 LanguageModelCompletionError,
1167 >,
1168 > {
1169 let extension = self.extension.clone();
1170 let provider_id = self.provider_info.id.clone();
1171 let model_id = self.model_info.id.clone();
1172
1173 let wit_request = convert_request_to_wit(request);
1174
1175 async move {
1176 // Start the stream
1177 let stream_id_result = extension
1178 .call({
1179 let provider_id = provider_id.clone();
1180 let model_id = model_id.clone();
1181 let wit_request = wit_request.clone();
1182 |ext, store| {
1183 async move {
1184 let id = ext
1185 .call_llm_stream_completion_start(
1186 store,
1187 &provider_id,
1188 &model_id,
1189 &wit_request,
1190 )
1191 .await?
1192 .map_err(|e| anyhow!("{}", e))?;
1193 Ok(id)
1194 }
1195 .boxed()
1196 }
1197 })
1198 .await;
1199
1200 let stream_id = stream_id_result
1201 .map_err(LanguageModelCompletionError::Other)?
1202 .map_err(LanguageModelCompletionError::Other)?;
1203
1204 // Create a stream that polls for events
1205 let stream = futures::stream::unfold(
1206 (extension.clone(), stream_id, false),
1207 move |(extension, stream_id, done)| async move {
1208 if done {
1209 return None;
1210 }
1211
1212 let result = extension
1213 .call({
1214 let stream_id = stream_id.clone();
1215 |ext, store| {
1216 async move {
1217 let event = ext
1218 .call_llm_stream_completion_next(store, &stream_id)
1219 .await?
1220 .map_err(|e| anyhow!("{}", e))?;
1221 Ok(event)
1222 }
1223 .boxed()
1224 }
1225 })
1226 .await
1227 .and_then(|inner| inner);
1228
1229 match result {
1230 Ok(Some(event)) => {
1231 let converted = convert_completion_event(event);
1232 let is_done =
1233 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1234 Some((converted, (extension, stream_id, is_done)))
1235 }
1236 Ok(None) => {
1237 // Stream complete, close it
1238 let _ = extension
1239 .call({
1240 let stream_id = stream_id.clone();
1241 |ext, store| {
1242 async move {
1243 ext.call_llm_stream_completion_close(store, &stream_id)
1244 .await?;
1245 Ok::<(), anyhow::Error>(())
1246 }
1247 .boxed()
1248 }
1249 })
1250 .await;
1251 None
1252 }
1253 Err(e) => Some((
1254 Err(LanguageModelCompletionError::Other(e)),
1255 (extension, stream_id, true),
1256 )),
1257 }
1258 },
1259 );
1260
1261 Ok(stream.boxed())
1262 }
1263 .boxed()
1264 }
1265
1266 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1267 // Extensions can implement this via llm_cache_configuration
1268 None
1269 }
1270}
1271
1272fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1273 use language_model::{MessageContent, Role};
1274
1275 let messages: Vec<LlmRequestMessage> = request
1276 .messages
1277 .into_iter()
1278 .map(|msg| {
1279 let role = match msg.role {
1280 Role::User => LlmMessageRole::User,
1281 Role::Assistant => LlmMessageRole::Assistant,
1282 Role::System => LlmMessageRole::System,
1283 };
1284
1285 let content: Vec<LlmMessageContent> = msg
1286 .content
1287 .into_iter()
1288 .map(|c| match c {
1289 MessageContent::Text(text) => LlmMessageContent::Text(text),
1290 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1291 source: image.source.to_string(),
1292 width: Some(image.size.width.0 as u32),
1293 height: Some(image.size.height.0 as u32),
1294 }),
1295 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1296 id: tool_use.id.to_string(),
1297 name: tool_use.name.to_string(),
1298 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1299 thought_signature: tool_use.thought_signature,
1300 }),
1301 MessageContent::ToolResult(tool_result) => {
1302 let content = match tool_result.content {
1303 language_model::LanguageModelToolResultContent::Text(text) => {
1304 LlmToolResultContent::Text(text.to_string())
1305 }
1306 language_model::LanguageModelToolResultContent::Image(image) => {
1307 LlmToolResultContent::Image(LlmImageData {
1308 source: image.source.to_string(),
1309 width: Some(image.size.width.0 as u32),
1310 height: Some(image.size.height.0 as u32),
1311 })
1312 }
1313 };
1314 LlmMessageContent::ToolResult(LlmToolResult {
1315 tool_use_id: tool_result.tool_use_id.to_string(),
1316 tool_name: tool_result.tool_name.to_string(),
1317 is_error: tool_result.is_error,
1318 content,
1319 })
1320 }
1321 MessageContent::Thinking { text, signature } => {
1322 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1323 }
1324 MessageContent::RedactedThinking(data) => {
1325 LlmMessageContent::RedactedThinking(data)
1326 }
1327 })
1328 .collect();
1329
1330 LlmRequestMessage {
1331 role,
1332 content,
1333 cache: msg.cache,
1334 }
1335 })
1336 .collect();
1337
1338 let tools: Vec<LlmToolDefinition> = request
1339 .tools
1340 .into_iter()
1341 .map(|tool| LlmToolDefinition {
1342 name: tool.name,
1343 description: tool.description,
1344 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1345 })
1346 .collect();
1347
1348 let tool_choice = request.tool_choice.map(|tc| match tc {
1349 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1350 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1351 LanguageModelToolChoice::None => LlmToolChoice::None,
1352 });
1353
1354 LlmCompletionRequest {
1355 messages,
1356 tools,
1357 tool_choice,
1358 stop_sequences: request.stop,
1359 temperature: request.temperature,
1360 thinking_allowed: false,
1361 max_tokens: None,
1362 }
1363}
1364
1365fn convert_completion_event(
1366 event: LlmCompletionEvent,
1367) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1368 match event {
1369 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1370 message_id: String::new(),
1371 }),
1372 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1373 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1374 text: thinking.text,
1375 signature: thinking.signature,
1376 }),
1377 LlmCompletionEvent::RedactedThinking(data) => {
1378 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1379 }
1380 LlmCompletionEvent::ToolUse(tool_use) => {
1381 let raw_input = tool_use.input.clone();
1382 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1383 Ok(LanguageModelCompletionEvent::ToolUse(
1384 LanguageModelToolUse {
1385 id: LanguageModelToolUseId::from(tool_use.id),
1386 name: tool_use.name.into(),
1387 raw_input,
1388 input,
1389 is_input_complete: true,
1390 thought_signature: tool_use.thought_signature,
1391 },
1392 ))
1393 }
1394 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1395 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1396 id: LanguageModelToolUseId::from(error.id),
1397 tool_name: error.tool_name.into(),
1398 raw_input: error.raw_input.into(),
1399 json_parse_error: error.error,
1400 })
1401 }
1402 LlmCompletionEvent::Stop(reason) => {
1403 let stop_reason = match reason {
1404 LlmStopReason::EndTurn => StopReason::EndTurn,
1405 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1406 LlmStopReason::ToolUse => StopReason::ToolUse,
1407 LlmStopReason::Refusal => StopReason::Refusal,
1408 };
1409 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1410 }
1411 LlmCompletionEvent::Usage(usage) => {
1412 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1413 input_tokens: usage.input_tokens,
1414 output_tokens: usage.output_tokens,
1415 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1416 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1417 }))
1418 }
1419 LlmCompletionEvent::ReasoningDetails(json) => {
1420 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1421 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1422 ))
1423 }
1424 }
1425}