1use crate::ExtensionSettings;
2use crate::wasm_host::WasmExtension;
3
4use crate::wasm_host::wit::{
5 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
6 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
7 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
8 LlmToolUse,
9};
10use anyhow::{Result, anyhow};
11use credentials_provider::CredentialsProvider;
12use editor::Editor;
13use extension::{LanguageModelAuthConfig, OAuthConfig};
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, StreamExt};
17use gpui::Focusable;
18use gpui::{
19 AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
20 MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
21};
22use language_model::tool_schema::LanguageModelToolSchemaFormat;
23use language_model::{
24 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
25 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
27 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
28 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
29};
30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
31use settings::Settings;
32use std::sync::Arc;
33use theme::ThemeSettings;
34use ui::{Label, LabelSize, prelude::*};
35use util::ResultExt as _;
36
37/// An extension-based language model provider.
38pub struct ExtensionLanguageModelProvider {
39 pub extension: WasmExtension,
40 pub provider_info: LlmProviderInfo,
41 icon_path: Option<SharedString>,
42 auth_config: Option<LanguageModelAuthConfig>,
43 state: Entity<ExtensionLlmProviderState>,
44}
45
46pub struct ExtensionLlmProviderState {
47 is_authenticated: bool,
48 available_models: Vec<LlmModelInfo>,
49 env_var_allowed: bool,
50 api_key_from_env: bool,
51}
52
53impl EventEmitter<()> for ExtensionLlmProviderState {}
54
55impl ExtensionLanguageModelProvider {
56 pub fn new(
57 extension: WasmExtension,
58 provider_info: LlmProviderInfo,
59 models: Vec<LlmModelInfo>,
60 is_authenticated: bool,
61 icon_path: Option<SharedString>,
62 auth_config: Option<LanguageModelAuthConfig>,
63 cx: &mut App,
64 ) -> Self {
65 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
66 let env_var_allowed = ExtensionSettings::get_global(cx)
67 .allowed_env_var_providers
68 .contains(provider_id_string.as_str());
69
70 let (is_authenticated, api_key_from_env) =
71 if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
72 let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
73 if let Ok(value) = std::env::var(env_var_name) {
74 if !value.is_empty() {
75 (true, true)
76 } else {
77 (is_authenticated, false)
78 }
79 } else {
80 (is_authenticated, false)
81 }
82 } else {
83 (is_authenticated, false)
84 };
85
86 let state = cx.new(|_| ExtensionLlmProviderState {
87 is_authenticated,
88 available_models: models,
89 env_var_allowed,
90 api_key_from_env,
91 });
92
93 Self {
94 extension,
95 provider_info,
96 icon_path,
97 auth_config,
98 state,
99 }
100 }
101
102 fn provider_id_string(&self) -> String {
103 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
104 }
105
106 /// The credential key used for storing the API key in the system keychain.
107 fn credential_key(&self) -> String {
108 format!("extension-llm-{}", self.provider_id_string())
109 }
110}
111
112impl LanguageModelProvider for ExtensionLanguageModelProvider {
113 fn id(&self) -> LanguageModelProviderId {
114 LanguageModelProviderId::from(self.provider_id_string())
115 }
116
117 fn name(&self) -> LanguageModelProviderName {
118 LanguageModelProviderName::from(format!("(Extension) {}", self.provider_info.name))
119 }
120
121 fn icon(&self) -> ui::IconName {
122 ui::IconName::ZedAssistant
123 }
124
125 fn icon_path(&self) -> Option<SharedString> {
126 self.icon_path.clone()
127 }
128
129 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
130 let state = self.state.read(cx);
131 state
132 .available_models
133 .iter()
134 .find(|m| m.is_default)
135 .or_else(|| state.available_models.first())
136 .map(|model_info| {
137 Arc::new(ExtensionLanguageModel {
138 extension: self.extension.clone(),
139 model_info: model_info.clone(),
140 provider_id: self.id(),
141 provider_name: self.name(),
142 provider_info: self.provider_info.clone(),
143 }) as Arc<dyn LanguageModel>
144 })
145 }
146
147 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
148 let state = self.state.read(cx);
149 state
150 .available_models
151 .iter()
152 .find(|m| m.is_default_fast)
153 .map(|model_info| {
154 Arc::new(ExtensionLanguageModel {
155 extension: self.extension.clone(),
156 model_info: model_info.clone(),
157 provider_id: self.id(),
158 provider_name: self.name(),
159 provider_info: self.provider_info.clone(),
160 }) as Arc<dyn LanguageModel>
161 })
162 }
163
164 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165 let state = self.state.read(cx);
166 state
167 .available_models
168 .iter()
169 .map(|model_info| {
170 Arc::new(ExtensionLanguageModel {
171 extension: self.extension.clone(),
172 model_info: model_info.clone(),
173 provider_id: self.id(),
174 provider_name: self.name(),
175 provider_info: self.provider_info.clone(),
176 }) as Arc<dyn LanguageModel>
177 })
178 .collect()
179 }
180
181 fn is_authenticated(&self, cx: &App) -> bool {
182 // First check cached state
183 if self.state.read(cx).is_authenticated {
184 return true;
185 }
186
187 // Also check env var dynamically (in case migration happened after provider creation)
188 if let Some(ref auth_config) = self.auth_config {
189 if let Some(ref env_var_name) = auth_config.env_var {
190 let provider_id_string = self.provider_id_string();
191 let env_var_allowed = ExtensionSettings::get_global(cx)
192 .allowed_env_var_providers
193 .contains(provider_id_string.as_str());
194
195 if env_var_allowed {
196 if let Ok(value) = std::env::var(env_var_name) {
197 if !value.is_empty() {
198 return true;
199 }
200 }
201 }
202 }
203 }
204
205 false
206 }
207
208 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
209 // Check if already authenticated via is_authenticated
210 if self.is_authenticated(cx) {
211 return Task::ready(Ok(()));
212 }
213
214 // Not authenticated - return error indicating credentials not found
215 Task::ready(Err(AuthenticateError::CredentialsNotFound))
216 }
217
218 fn configuration_view(
219 &self,
220 _target_agent: ConfigurationViewTargetAgent,
221 window: &mut Window,
222 cx: &mut App,
223 ) -> AnyView {
224 let credential_key = self.credential_key();
225 let extension = self.extension.clone();
226 let extension_provider_id = self.provider_info.id.clone();
227 let full_provider_id = self.provider_id_string();
228 let state = self.state.clone();
229 let auth_config = self.auth_config.clone();
230
231 cx.new(|cx| {
232 ExtensionProviderConfigurationView::new(
233 credential_key,
234 extension,
235 extension_provider_id,
236 full_provider_id,
237 auth_config,
238 state,
239 window,
240 cx,
241 )
242 })
243 .into()
244 }
245
246 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
247 let extension = self.extension.clone();
248 let provider_id = self.provider_info.id.clone();
249 let state = self.state.clone();
250 let credential_key = self.credential_key();
251
252 let credentials_provider = <dyn CredentialsProvider>::global(cx);
253
254 cx.spawn(async move |cx| {
255 // Delete from system keychain
256 credentials_provider
257 .delete_credentials(&credential_key, cx)
258 .await
259 .log_err();
260
261 // Call extension's reset_credentials
262 let result = extension
263 .call(|extension, store| {
264 async move {
265 extension
266 .call_llm_provider_reset_credentials(store, &provider_id)
267 .await
268 }
269 .boxed()
270 })
271 .await;
272
273 // Update state
274 cx.update(|cx| {
275 state.update(cx, |state, _| {
276 state.is_authenticated = false;
277 });
278 })?;
279
280 match result {
281 Ok(Ok(Ok(()))) => Ok(()),
282 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
283 Ok(Err(e)) => Err(e),
284 Err(e) => Err(e),
285 }
286 })
287 }
288}
289
290impl LanguageModelProviderState for ExtensionLanguageModelProvider {
291 type ObservableEntity = ExtensionLlmProviderState;
292
293 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
294 Some(self.state.clone())
295 }
296
297 fn subscribe<T: 'static>(
298 &self,
299 cx: &mut Context<T>,
300 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
301 ) -> Option<Subscription> {
302 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
303 }
304}
305
306/// Configuration view for extension-based LLM providers.
307struct ExtensionProviderConfigurationView {
308 credential_key: String,
309 extension: WasmExtension,
310 extension_provider_id: String,
311 full_provider_id: String,
312 auth_config: Option<LanguageModelAuthConfig>,
313 state: Entity<ExtensionLlmProviderState>,
314 settings_markdown: Option<Entity<Markdown>>,
315 api_key_editor: Entity<Editor>,
316 loading_settings: bool,
317 loading_credentials: bool,
318 oauth_in_progress: bool,
319 oauth_error: Option<String>,
320 device_user_code: Option<String>,
321 _subscriptions: Vec<Subscription>,
322}
323
324impl ExtensionProviderConfigurationView {
325 fn new(
326 credential_key: String,
327 extension: WasmExtension,
328 extension_provider_id: String,
329 full_provider_id: String,
330 auth_config: Option<LanguageModelAuthConfig>,
331 state: Entity<ExtensionLlmProviderState>,
332 window: &mut Window,
333 cx: &mut Context<Self>,
334 ) -> Self {
335 // Subscribe to state changes
336 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
337 cx.notify();
338 });
339
340 // Create API key editor
341 let api_key_editor = cx.new(|cx| {
342 let mut editor = Editor::single_line(window, cx);
343 editor.set_placeholder_text("Enter API key...", window, cx);
344 editor
345 });
346
347 let mut this = Self {
348 credential_key,
349 extension,
350 extension_provider_id,
351 full_provider_id,
352 auth_config,
353 state,
354 settings_markdown: None,
355 api_key_editor,
356 loading_settings: true,
357 loading_credentials: true,
358 oauth_in_progress: false,
359 oauth_error: None,
360 device_user_code: None,
361 _subscriptions: vec![state_subscription],
362 };
363
364 // Load settings text from extension
365 this.load_settings_text(cx);
366
367 // Load existing credentials
368 this.load_credentials(cx);
369
370 this
371 }
372
373 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
374 let extension = self.extension.clone();
375 let provider_id = self.extension_provider_id.clone();
376
377 cx.spawn(async move |this, cx| {
378 let result = extension
379 .call({
380 let provider_id = provider_id.clone();
381 |ext, store| {
382 async move {
383 ext.call_llm_provider_settings_markdown(store, &provider_id)
384 .await
385 }
386 .boxed()
387 }
388 })
389 .await;
390
391 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
392
393 this.update(cx, |this, cx| {
394 this.loading_settings = false;
395 if let Some(text) = settings_text {
396 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
397 this.settings_markdown = Some(markdown);
398 }
399 cx.notify();
400 })
401 .log_err();
402 })
403 .detach();
404 }
405
406 fn load_credentials(&mut self, cx: &mut Context<Self>) {
407 let credential_key = self.credential_key.clone();
408 let credentials_provider = <dyn CredentialsProvider>::global(cx);
409 let state = self.state.clone();
410
411 // Check if we should use env var (already set in state during provider construction)
412 let api_key_from_env = self.state.read(cx).api_key_from_env;
413
414 cx.spawn(async move |this, cx| {
415 // If using env var, we're already authenticated
416 if api_key_from_env {
417 this.update(cx, |this, cx| {
418 this.loading_credentials = false;
419 cx.notify();
420 })
421 .log_err();
422 return;
423 }
424
425 let credentials = credentials_provider
426 .read_credentials(&credential_key, cx)
427 .await
428 .log_err()
429 .flatten();
430
431 let has_credentials = credentials.is_some();
432
433 // Update authentication state based on stored credentials
434 cx.update(|cx| {
435 state.update(cx, |state, cx| {
436 state.is_authenticated = has_credentials;
437 cx.notify();
438 });
439 })
440 .log_err();
441
442 this.update(cx, |this, cx| {
443 this.loading_credentials = false;
444 cx.notify();
445 })
446 .log_err();
447 })
448 .detach();
449 }
450
451 fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
452 let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
453 let env_var_name = match &self.auth_config {
454 Some(config) => config.env_var.clone(),
455 None => return,
456 };
457
458 let state = self.state.clone();
459 let currently_allowed = self.state.read(cx).env_var_allowed;
460
461 // Update settings file
462 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
463 let providers = settings
464 .extension
465 .allowed_env_var_providers
466 .get_or_insert_with(Vec::new);
467
468 if currently_allowed {
469 providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
470 } else {
471 if !providers
472 .iter()
473 .any(|id| id.as_ref() == full_provider_id.as_ref())
474 {
475 providers.push(full_provider_id.clone());
476 }
477 }
478 });
479
480 // Update local state
481 let new_allowed = !currently_allowed;
482 let new_from_env = if new_allowed {
483 if let Some(var_name) = &env_var_name {
484 if let Ok(value) = std::env::var(var_name) {
485 !value.is_empty()
486 } else {
487 false
488 }
489 } else {
490 false
491 }
492 } else {
493 false
494 };
495
496 state.update(cx, |state, cx| {
497 state.env_var_allowed = new_allowed;
498 state.api_key_from_env = new_from_env;
499 if new_from_env {
500 state.is_authenticated = true;
501 }
502 cx.notify();
503 });
504
505 // If env var is being enabled, clear any stored keychain credentials
506 // so there's only one source of truth for the API key
507 if new_allowed {
508 let credential_key = self.credential_key.clone();
509 let credentials_provider = <dyn CredentialsProvider>::global(cx);
510 cx.spawn(async move |_this, cx| {
511 credentials_provider
512 .delete_credentials(&credential_key, cx)
513 .await
514 .log_err();
515 })
516 .detach();
517 }
518
519 // If env var is being disabled, reload credentials from keychain
520 if !new_allowed {
521 self.reload_keychain_credentials(cx);
522 }
523
524 cx.notify();
525 }
526
527 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
528 let credential_key = self.credential_key.clone();
529 let credentials_provider = <dyn CredentialsProvider>::global(cx);
530 let state = self.state.clone();
531
532 cx.spawn(async move |_this, cx| {
533 let credentials = credentials_provider
534 .read_credentials(&credential_key, cx)
535 .await
536 .log_err()
537 .flatten();
538
539 let has_credentials = credentials.is_some();
540
541 cx.update(|cx| {
542 state.update(cx, |state, cx| {
543 state.is_authenticated = has_credentials;
544 cx.notify();
545 });
546 })
547 .log_err();
548 })
549 .detach();
550 }
551
552 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
553 let api_key = self.api_key_editor.read(cx).text(cx);
554 if api_key.is_empty() {
555 return;
556 }
557
558 // Clear the editor
559 self.api_key_editor
560 .update(cx, |editor, cx| editor.set_text("", window, cx));
561
562 let credential_key = self.credential_key.clone();
563 let credentials_provider = <dyn CredentialsProvider>::global(cx);
564 let state = self.state.clone();
565
566 cx.spawn(async move |_this, cx| {
567 // Store in system keychain
568 credentials_provider
569 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
570 .await
571 .log_err();
572
573 // Update state to authenticated
574 cx.update(|cx| {
575 state.update(cx, |state, cx| {
576 state.is_authenticated = true;
577 cx.notify();
578 });
579 })
580 .log_err();
581 })
582 .detach();
583 }
584
585 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
586 // Clear the editor
587 self.api_key_editor
588 .update(cx, |editor, cx| editor.set_text("", window, cx));
589
590 let credential_key = self.credential_key.clone();
591 let credentials_provider = <dyn CredentialsProvider>::global(cx);
592 let state = self.state.clone();
593
594 cx.spawn(async move |_this, cx| {
595 // Delete from system keychain
596 credentials_provider
597 .delete_credentials(&credential_key, cx)
598 .await
599 .log_err();
600
601 // Update state to unauthenticated
602 cx.update(|cx| {
603 state.update(cx, |state, cx| {
604 state.is_authenticated = false;
605 cx.notify();
606 });
607 })
608 .log_err();
609 })
610 .detach();
611 }
612
613 fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
614 if self.oauth_in_progress {
615 return;
616 }
617
618 self.oauth_in_progress = true;
619 self.oauth_error = None;
620 self.device_user_code = None;
621 cx.notify();
622
623 let extension = self.extension.clone();
624 let provider_id = self.extension_provider_id.clone();
625 let state = self.state.clone();
626
627 cx.spawn(async move |this, cx| {
628 // Step 1: Start device flow - opens browser and returns user code
629 let start_result = extension
630 .call({
631 let provider_id = provider_id.clone();
632 |ext, store| {
633 async move {
634 ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
635 .await
636 }
637 .boxed()
638 }
639 })
640 .await;
641
642 let user_code = match start_result {
643 Ok(Ok(Ok(code))) => code,
644 Ok(Ok(Err(e))) => {
645 log::error!("Device flow start failed: {}", e);
646 this.update(cx, |this, cx| {
647 this.oauth_in_progress = false;
648 this.oauth_error = Some(e);
649 cx.notify();
650 })
651 .log_err();
652 return;
653 }
654 Ok(Err(e)) | Err(e) => {
655 log::error!("Device flow start error: {}", e);
656 this.update(cx, |this, cx| {
657 this.oauth_in_progress = false;
658 this.oauth_error = Some(e.to_string());
659 cx.notify();
660 })
661 .log_err();
662 return;
663 }
664 };
665
666 // Update UI to show the user code before polling
667 this.update(cx, |this, cx| {
668 this.device_user_code = Some(user_code);
669 cx.notify();
670 })
671 .log_err();
672
673 // Step 2: Poll for authentication completion
674 let poll_result = extension
675 .call({
676 let provider_id = provider_id.clone();
677 |ext, store| {
678 async move {
679 ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
680 .await
681 }
682 .boxed()
683 }
684 })
685 .await;
686
687 let error_message = match poll_result {
688 Ok(Ok(Ok(()))) => {
689 // After successful auth, refresh the models list
690 let models_result = extension
691 .call({
692 let provider_id = provider_id.clone();
693 |ext, store| {
694 async move {
695 ext.call_llm_provider_models(store, &provider_id).await
696 }
697 .boxed()
698 }
699 })
700 .await;
701
702 let new_models: Vec<LlmModelInfo> = match models_result {
703 Ok(Ok(Ok(models))) => models,
704 _ => Vec::new(),
705 };
706
707 cx.update(|cx| {
708 state.update(cx, |state, cx| {
709 state.is_authenticated = true;
710 state.available_models = new_models;
711 cx.notify();
712 });
713 })
714 .log_err();
715 None
716 }
717 Ok(Ok(Err(e))) => {
718 log::error!("Device flow poll failed: {}", e);
719 Some(e)
720 }
721 Ok(Err(e)) | Err(e) => {
722 log::error!("Device flow poll error: {}", e);
723 Some(e.to_string())
724 }
725 };
726
727 this.update(cx, |this, cx| {
728 this.oauth_in_progress = false;
729 this.oauth_error = error_message;
730 this.device_user_code = None;
731 cx.notify();
732 })
733 .log_err();
734 })
735 .detach();
736 }
737
738 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
739 self.state.read(cx).is_authenticated
740 }
741
742 fn has_oauth_config(&self) -> bool {
743 self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
744 }
745
746 fn oauth_config(&self) -> Option<&OAuthConfig> {
747 self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
748 }
749
750 fn has_api_key_config(&self) -> bool {
751 // API key is available if there's a credential_label or no oauth-only config
752 self.auth_config
753 .as_ref()
754 .map(|c| c.credential_label.is_some() || c.oauth.is_none())
755 .unwrap_or(true)
756 }
757}
758
impl gpui::Render for ExtensionProviderConfigurationView {
    /// Renders the provider configuration UI.
    ///
    /// While settings or credentials are still loading, only a "Loading..." label is
    /// shown. Otherwise the sections appear top to bottom:
    ///
    /// 1. The optional settings markdown.
    /// 2. The env-var opt-in checkbox, when the auth config names an env var.
    /// 3. One of:
    ///    * a status row for an existing credential, or
    ///    * the sign-in options (OAuth button and/or API key input).
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_loading = self.loading_settings || self.loading_credentials;
        let is_authenticated = self.is_authenticated(cx);
        let env_var_allowed = self.state.read(cx).env_var_allowed;
        let api_key_from_env = self.state.read(cx).api_key_from_env;
        let has_oauth = self.has_oauth_config();
        let has_api_key = self.has_api_key_config();

        if is_loading {
            return v_flex()
                .gap_2()
                .child(Label::new("Loading...").color(Color::Muted))
                .into_any_element();
        }

        let mut content = v_flex().gap_4().size_full();

        // Render settings markdown if available
        if let Some(markdown) = &self.settings_markdown {
            let style = settings_markdown_style(_window, cx);
            content = content.child(MarkdownElement::new(markdown.clone(), style));
        }

        // Render env var checkbox if the extension specifies an env var
        if let Some(auth_config) = &self.auth_config {
            if let Some(env_var_name) = &auth_config.env_var {
                let env_var_name = env_var_name.clone();
                let checkbox_label =
                    format!("Read API key from {} environment variable", env_var_name);

                content = content.child(
                    h_flex()
                        .gap_2()
                        .child(
                            ui::Checkbox::new("env-var-permission", env_var_allowed.into())
                                .on_click(cx.listener(|this, _, _window, cx| {
                                    this.toggle_env_var_permission(cx);
                                })),
                        )
                        .child(Label::new(checkbox_label).size(LabelSize::Small)),
                );

                // Show status if env var is allowed
                if env_var_allowed {
                    if api_key_from_env {
                        // The key comes from the environment, so it can't be reset from
                        // here: the reset button is disabled and a tooltip explains why.
                        let tooltip_label = format!(
                            "To reset this API key, unset the {} environment variable.",
                            env_var_name
                        );
                        content = content.child(
                            h_flex()
                                .mt_0p5()
                                .p_1()
                                .justify_between()
                                .rounded_md()
                                .border_1()
                                .border_color(cx.theme().colors().border)
                                .bg(cx.theme().colors().background)
                                .child(
                                    h_flex()
                                        .flex_1()
                                        .min_w_0()
                                        .gap_1()
                                        .child(
                                            ui::Icon::new(ui::IconName::Check)
                                                .color(Color::Success),
                                        )
                                        .child(
                                            Label::new(format!(
                                                "API key set in {} environment variable",
                                                env_var_name
                                            ))
                                            .truncate(),
                                        ),
                                )
                                .child(
                                    ui::Button::new("reset-key", "Reset Key")
                                        .label_size(LabelSize::Small)
                                        .icon(ui::IconName::Undo)
                                        .icon_size(ui::IconSize::Small)
                                        .icon_color(Color::Muted)
                                        .icon_position(ui::IconPosition::Start)
                                        .disabled(true)
                                        .tooltip(ui::Tooltip::text(tooltip_label)),
                                ),
                        );
                        // Early return: with an env-provided key, none of the other
                        // auth sections below are rendered.
                        return content.into_any_element();
                    } else {
                        // Allowed but the variable is unset or empty: warn, and fall
                        // through so the other auth options are still shown.
                        content = content.child(
                            h_flex()
                                .gap_2()
                                .child(
                                    ui::Icon::new(ui::IconName::Warning)
                                        .color(Color::Warning)
                                        .size(ui::IconSize::Small),
                                )
                                .child(
                                    Label::new(format!(
                                        "{} is not set or empty. You can set it and restart Zed, or use another authentication method below.",
                                        env_var_name
                                    ))
                                    .color(Color::Warning)
                                    .size(LabelSize::Small),
                                ),
                        );
                    }
                }
            }
        }

        // If authenticated, show success state with sign out option
        if is_authenticated && !api_key_from_env {
            // OAuth-only providers talk about signing in/out; everything else is
            // presented as an API key.
            let reset_label = if has_oauth && !has_api_key {
                "Sign Out"
            } else {
                "Reset Key"
            };

            let status_label = if has_oauth && !has_api_key {
                "Signed in"
            } else {
                "API key configured"
            };

            content = content.child(
                h_flex()
                    .mt_0p5()
                    .p_1()
                    .justify_between()
                    .rounded_md()
                    .border_1()
                    .border_color(cx.theme().colors().border)
                    .bg(cx.theme().colors().background)
                    .child(
                        h_flex()
                            .flex_1()
                            .min_w_0()
                            .gap_1()
                            .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
                            .child(Label::new(status_label).truncate()),
                    )
                    .child(
                        ui::Button::new("reset-key", reset_label)
                            .label_size(LabelSize::Small)
                            .icon(ui::IconName::Undo)
                            .icon_size(ui::IconSize::Small)
                            .icon_color(Color::Muted)
                            .icon_position(ui::IconPosition::Start)
                            .on_click(cx.listener(|this, _, window, cx| {
                                this.reset_api_key(window, cx);
                            })),
                    ),
            );

            return content.into_any_element();
        }

        // Not authenticated - show available auth options
        // NOTE(review): the env-key case already returned above whenever the auth
        // config names an env var, so this guard looks defensive.
        if !api_key_from_env {
            // Render OAuth sign-in button if configured
            if has_oauth {
                let oauth_config = self.oauth_config();
                let button_label = oauth_config
                    .and_then(|c| c.sign_in_button_label.clone())
                    .unwrap_or_else(|| "Sign In".to_string());
                // Only "github" is mapped to an icon; other names render no icon.
                let button_icon = oauth_config
                    .and_then(|c| c.sign_in_button_icon.as_ref())
                    .and_then(|icon_name| match icon_name.as_str() {
                        "github" => Some(ui::IconName::Github),
                        _ => None,
                    });

                let oauth_in_progress = self.oauth_in_progress;

                let oauth_error = self.oauth_error.clone();

                content = content.child(
                    v_flex()
                        .gap_2()
                        .child({
                            let mut button = ui::Button::new("oauth-sign-in", button_label)
                                .full_width()
                                .style(ui::ButtonStyle::Outlined)
                                .disabled(oauth_in_progress)
                                .on_click(cx.listener(|this, _, _window, cx| {
                                    this.start_oauth_sign_in(cx);
                                }));
                            if let Some(icon) = button_icon {
                                button = button
                                    .icon(icon)
                                    .icon_position(ui::IconPosition::Start)
                                    .icon_size(ui::IconSize::Small)
                                    .icon_color(Color::Muted);
                            }
                            button
                        })
                        .when(oauth_in_progress, |this| {
                            let user_code = self.device_user_code.clone();
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .when_some(user_code, |this, code| {
                                        // Show a check icon once the clipboard holds
                                        // exactly this code.
                                        let copied = cx
                                            .read_from_clipboard()
                                            .map(|item| item.text().as_ref() == Some(&code))
                                            .unwrap_or(false);
                                        let code_for_click = code.clone();
                                        this.child(
                                            h_flex()
                                                .gap_1()
                                                .child(
                                                    Label::new("Enter code:")
                                                        .size(LabelSize::Small)
                                                        .color(Color::Muted),
                                                )
                                                .child(
                                                    h_flex()
                                                        .gap_1()
                                                        .px_1()
                                                        .border_1()
                                                        .border_color(cx.theme().colors().border)
                                                        .rounded_sm()
                                                        .cursor_pointer()
                                                        // Clicking copies the code and refreshes
                                                        // the window so the icon can update.
                                                        .on_mouse_down(
                                                            MouseButton::Left,
                                                            move |_, window, cx| {
                                                                cx.write_to_clipboard(
                                                                    ClipboardItem::new_string(
                                                                        code_for_click.clone(),
                                                                    ),
                                                                );
                                                                window.refresh();
                                                            },
                                                        )
                                                        .child(
                                                            Label::new(code)
                                                                .size(LabelSize::Small)
                                                                .color(Color::Accent),
                                                        )
                                                        .child(
                                                            ui::Icon::new(if copied {
                                                                ui::IconName::Check
                                                            } else {
                                                                ui::IconName::Copy
                                                            })
                                                            .size(ui::IconSize::Small)
                                                            .color(if copied {
                                                                Color::Success
                                                            } else {
                                                                Color::Muted
                                                            }),
                                                        ),
                                                ),
                                        )
                                    })
                                    .child(
                                        Label::new("Waiting for authorization in browser...")
                                            .size(LabelSize::Small)
                                            .color(Color::Muted),
                                    ),
                            )
                        })
                        .when_some(oauth_error, |this, error| {
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .child(
                                        h_flex()
                                            .gap_2()
                                            .child(
                                                ui::Icon::new(ui::IconName::Warning)
                                                    .color(Color::Error)
                                                    .size(ui::IconSize::Small),
                                            )
                                            .child(
                                                Label::new("Authentication failed")
                                                    .color(Color::Error)
                                                    .size(LabelSize::Small),
                                            ),
                                    )
                                    .child(
                                        div().pl_6().child(
                                            Label::new(error)
                                                .color(Color::Error)
                                                .size(LabelSize::Small),
                                        ),
                                    ),
                            )
                        }),
                );
            }

            // Render API key input if configured (and we have both options, show a separator)
            if has_api_key {
                if has_oauth {
                    content = content.child(
                        h_flex()
                            .gap_2()
                            .items_center()
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border))
                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
                    );
                }

                let credential_label = self
                    .auth_config
                    .as_ref()
                    .and_then(|c| c.credential_label.clone())
                    .unwrap_or_else(|| "API Key".to_string());

                // Pressing Enter in the editor dispatches `menu::Confirm`, which
                // `save_api_key` handles via this container's action listener.
                content = content.child(
                    v_flex()
                        .gap_2()
                        .on_action(cx.listener(Self::save_api_key))
                        .child(
                            Label::new(credential_label)
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        )
                        .child(self.api_key_editor.clone())
                        .child(
                            Label::new("Enter your API key and press Enter to save")
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        ),
                );
            }
        }

        content.into_any_element()
    }
}
1093
1094impl Focusable for ExtensionProviderConfigurationView {
1095 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1096 self.api_key_editor.focus_handle(cx)
1097 }
1098}
1099
1100fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1101 let theme_settings = ThemeSettings::get_global(cx);
1102 let colors = cx.theme().colors();
1103 let mut text_style = window.text_style();
1104 text_style.refine(&TextStyleRefinement {
1105 font_family: Some(theme_settings.ui_font.family.clone()),
1106 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1107 font_features: Some(theme_settings.ui_font.features.clone()),
1108 color: Some(colors.text),
1109 ..Default::default()
1110 });
1111
1112 MarkdownStyle {
1113 base_text_style: text_style,
1114 selection_background_color: colors.element_selection_background,
1115 inline_code: TextStyleRefinement {
1116 background_color: Some(colors.editor_background),
1117 ..Default::default()
1118 },
1119 link: TextStyleRefinement {
1120 color: Some(colors.text_accent),
1121 underline: Some(UnderlineStyle {
1122 color: Some(colors.text_accent.opacity(0.5)),
1123 thickness: px(1.),
1124 ..Default::default()
1125 }),
1126 ..Default::default()
1127 },
1128 syntax: cx.theme().syntax().clone(),
1129 ..Default::default()
1130 }
1131}
1132
1133/// An extension-based language model.
1134pub struct ExtensionLanguageModel {
1135 extension: WasmExtension,
1136 model_info: LlmModelInfo,
1137 provider_id: LanguageModelProviderId,
1138 provider_name: LanguageModelProviderName,
1139 provider_info: LlmProviderInfo,
1140}
1141
1142impl LanguageModel for ExtensionLanguageModel {
1143 fn id(&self) -> LanguageModelId {
1144 LanguageModelId::from(self.model_info.id.clone())
1145 }
1146
1147 fn name(&self) -> LanguageModelName {
1148 LanguageModelName::from(self.model_info.name.clone())
1149 }
1150
1151 fn provider_id(&self) -> LanguageModelProviderId {
1152 self.provider_id.clone()
1153 }
1154
1155 fn provider_name(&self) -> LanguageModelProviderName {
1156 self.provider_name.clone()
1157 }
1158
1159 fn telemetry_id(&self) -> String {
1160 format!("extension-{}", self.model_info.id)
1161 }
1162
1163 fn supports_images(&self) -> bool {
1164 self.model_info.capabilities.supports_images
1165 }
1166
1167 fn supports_tools(&self) -> bool {
1168 self.model_info.capabilities.supports_tools
1169 }
1170
1171 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1172 match choice {
1173 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1174 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1175 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1176 }
1177 }
1178
1179 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1180 match self.model_info.capabilities.tool_input_format {
1181 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1182 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1183 }
1184 }
1185
1186 fn max_token_count(&self) -> u64 {
1187 self.model_info.max_token_count
1188 }
1189
1190 fn max_output_tokens(&self) -> Option<u64> {
1191 self.model_info.max_output_tokens
1192 }
1193
1194 fn count_tokens(
1195 &self,
1196 request: LanguageModelRequest,
1197 cx: &App,
1198 ) -> BoxFuture<'static, Result<u64>> {
1199 let extension = self.extension.clone();
1200 let provider_id = self.provider_info.id.clone();
1201 let model_id = self.model_info.id.clone();
1202
1203 let wit_request = convert_request_to_wit(request);
1204
1205 cx.background_spawn(async move {
1206 extension
1207 .call({
1208 let provider_id = provider_id.clone();
1209 let model_id = model_id.clone();
1210 let wit_request = wit_request.clone();
1211 |ext, store| {
1212 async move {
1213 let count = ext
1214 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1215 .await?
1216 .map_err(|e| anyhow!("{}", e))?;
1217 Ok(count)
1218 }
1219 .boxed()
1220 }
1221 })
1222 .await?
1223 })
1224 .boxed()
1225 }
1226
1227 fn stream_completion(
1228 &self,
1229 request: LanguageModelRequest,
1230 _cx: &AsyncApp,
1231 ) -> BoxFuture<
1232 'static,
1233 Result<
1234 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1235 LanguageModelCompletionError,
1236 >,
1237 > {
1238 let extension = self.extension.clone();
1239 let provider_id = self.provider_info.id.clone();
1240 let model_id = self.model_info.id.clone();
1241
1242 let wit_request = convert_request_to_wit(request);
1243
1244 async move {
1245 // Start the stream
1246 let stream_id_result = extension
1247 .call({
1248 let provider_id = provider_id.clone();
1249 let model_id = model_id.clone();
1250 let wit_request = wit_request.clone();
1251 |ext, store| {
1252 async move {
1253 let id = ext
1254 .call_llm_stream_completion_start(
1255 store,
1256 &provider_id,
1257 &model_id,
1258 &wit_request,
1259 )
1260 .await?
1261 .map_err(|e| anyhow!("{}", e))?;
1262 Ok(id)
1263 }
1264 .boxed()
1265 }
1266 })
1267 .await;
1268
1269 let stream_id = stream_id_result
1270 .map_err(LanguageModelCompletionError::Other)?
1271 .map_err(LanguageModelCompletionError::Other)?;
1272
1273 // Create a stream that polls for events
1274 let stream = futures::stream::unfold(
1275 (extension.clone(), stream_id, false),
1276 move |(extension, stream_id, done)| async move {
1277 if done {
1278 return None;
1279 }
1280
1281 let result = extension
1282 .call({
1283 let stream_id = stream_id.clone();
1284 |ext, store| {
1285 async move {
1286 let event = ext
1287 .call_llm_stream_completion_next(store, &stream_id)
1288 .await?
1289 .map_err(|e| anyhow!("{}", e))?;
1290 Ok(event)
1291 }
1292 .boxed()
1293 }
1294 })
1295 .await
1296 .and_then(|inner| inner);
1297
1298 match result {
1299 Ok(Some(event)) => {
1300 let converted = convert_completion_event(event);
1301 let is_done =
1302 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1303 Some((converted, (extension, stream_id, is_done)))
1304 }
1305 Ok(None) => {
1306 // Stream complete, close it
1307 let _ = extension
1308 .call({
1309 let stream_id = stream_id.clone();
1310 |ext, store| {
1311 async move {
1312 ext.call_llm_stream_completion_close(store, &stream_id)
1313 .await?;
1314 Ok::<(), anyhow::Error>(())
1315 }
1316 .boxed()
1317 }
1318 })
1319 .await;
1320 None
1321 }
1322 Err(e) => Some((
1323 Err(LanguageModelCompletionError::Other(e)),
1324 (extension, stream_id, true),
1325 )),
1326 }
1327 },
1328 );
1329
1330 Ok(stream.boxed())
1331 }
1332 .boxed()
1333 }
1334
1335 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1336 // Extensions can implement this via llm_cache_configuration
1337 None
1338 }
1339}
1340
1341fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1342 use language_model::{MessageContent, Role};
1343
1344 let messages: Vec<LlmRequestMessage> = request
1345 .messages
1346 .into_iter()
1347 .map(|msg| {
1348 let role = match msg.role {
1349 Role::User => LlmMessageRole::User,
1350 Role::Assistant => LlmMessageRole::Assistant,
1351 Role::System => LlmMessageRole::System,
1352 };
1353
1354 let content: Vec<LlmMessageContent> = msg
1355 .content
1356 .into_iter()
1357 .map(|c| match c {
1358 MessageContent::Text(text) => LlmMessageContent::Text(text),
1359 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1360 source: image.source.to_string(),
1361 width: Some(image.size.width.0 as u32),
1362 height: Some(image.size.height.0 as u32),
1363 }),
1364 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1365 id: tool_use.id.to_string(),
1366 name: tool_use.name.to_string(),
1367 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1368 thought_signature: tool_use.thought_signature,
1369 }),
1370 MessageContent::ToolResult(tool_result) => {
1371 let content = match tool_result.content {
1372 language_model::LanguageModelToolResultContent::Text(text) => {
1373 LlmToolResultContent::Text(text.to_string())
1374 }
1375 language_model::LanguageModelToolResultContent::Image(image) => {
1376 LlmToolResultContent::Image(LlmImageData {
1377 source: image.source.to_string(),
1378 width: Some(image.size.width.0 as u32),
1379 height: Some(image.size.height.0 as u32),
1380 })
1381 }
1382 };
1383 LlmMessageContent::ToolResult(LlmToolResult {
1384 tool_use_id: tool_result.tool_use_id.to_string(),
1385 tool_name: tool_result.tool_name.to_string(),
1386 is_error: tool_result.is_error,
1387 content,
1388 })
1389 }
1390 MessageContent::Thinking { text, signature } => {
1391 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1392 }
1393 MessageContent::RedactedThinking(data) => {
1394 LlmMessageContent::RedactedThinking(data)
1395 }
1396 })
1397 .collect();
1398
1399 LlmRequestMessage {
1400 role,
1401 content,
1402 cache: msg.cache,
1403 }
1404 })
1405 .collect();
1406
1407 let tools: Vec<LlmToolDefinition> = request
1408 .tools
1409 .into_iter()
1410 .map(|tool| LlmToolDefinition {
1411 name: tool.name,
1412 description: tool.description,
1413 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1414 })
1415 .collect();
1416
1417 let tool_choice = request.tool_choice.map(|tc| match tc {
1418 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1419 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1420 LanguageModelToolChoice::None => LlmToolChoice::None,
1421 });
1422
1423 LlmCompletionRequest {
1424 messages,
1425 tools,
1426 tool_choice,
1427 stop_sequences: request.stop,
1428 temperature: request.temperature,
1429 thinking_allowed: false,
1430 max_tokens: None,
1431 }
1432}
1433
1434fn convert_completion_event(
1435 event: LlmCompletionEvent,
1436) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1437 match event {
1438 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1439 message_id: String::new(),
1440 }),
1441 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1442 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1443 text: thinking.text,
1444 signature: thinking.signature,
1445 }),
1446 LlmCompletionEvent::RedactedThinking(data) => {
1447 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1448 }
1449 LlmCompletionEvent::ToolUse(tool_use) => {
1450 let raw_input = tool_use.input.clone();
1451 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1452 Ok(LanguageModelCompletionEvent::ToolUse(
1453 LanguageModelToolUse {
1454 id: LanguageModelToolUseId::from(tool_use.id),
1455 name: tool_use.name.into(),
1456 raw_input,
1457 input,
1458 is_input_complete: true,
1459 thought_signature: tool_use.thought_signature,
1460 },
1461 ))
1462 }
1463 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1464 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1465 id: LanguageModelToolUseId::from(error.id),
1466 tool_name: error.tool_name.into(),
1467 raw_input: error.raw_input.into(),
1468 json_parse_error: error.error,
1469 })
1470 }
1471 LlmCompletionEvent::Stop(reason) => {
1472 let stop_reason = match reason {
1473 LlmStopReason::EndTurn => StopReason::EndTurn,
1474 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1475 LlmStopReason::ToolUse => StopReason::ToolUse,
1476 LlmStopReason::Refusal => StopReason::Refusal,
1477 };
1478 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1479 }
1480 LlmCompletionEvent::Usage(usage) => {
1481 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1482 input_tokens: usage.input_tokens,
1483 output_tokens: usage.output_tokens,
1484 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1485 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1486 }))
1487 }
1488 LlmCompletionEvent::ReasoningDetails(json) => {
1489 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1490 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1491 ))
1492 }
1493 }
1494}