1use crate::ExtensionSettings;
2use crate::LEGACY_LLM_EXTENSION_IDS;
3use crate::wasm_host::WasmExtension;
4use collections::HashSet;
5
6use crate::wasm_host::wit::{
7 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
8 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
9 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
10 LlmToolUse,
11};
12use anyhow::{Result, anyhow};
13use credentials_provider::CredentialsProvider;
14use editor::Editor;
15use extension::{LanguageModelAuthConfig, OAuthConfig};
16use futures::future::BoxFuture;
17use futures::stream::BoxStream;
18use futures::{FutureExt, StreamExt};
19use gpui::Focusable;
20use gpui::{
21 AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
22 MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
23};
24use language_model::tool_schema::LanguageModelToolSchemaFormat;
25use language_model::{
26 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
27 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
28 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
29 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
30 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
31};
32use markdown::{Markdown, MarkdownElement, MarkdownStyle};
33use settings::Settings;
34use std::sync::Arc;
35use theme::ThemeSettings;
36use ui::{Label, LabelSize, prelude::*};
37use util::ResultExt as _;
38
39/// An extension-based language model provider.
pub struct ExtensionLanguageModelProvider {
    /// The WASM extension that implements this provider.
    pub extension: WasmExtension,
    /// Provider metadata reported by the extension (its `id` and display `name` are used here).
    pub provider_info: LlmProviderInfo,
    /// Optional icon path returned from `icon_path()`.
    icon_path: Option<SharedString>,
    /// Authentication options declared by the extension (env vars, OAuth, credential label).
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared, observable state (auth status, models, env-var permissions).
    state: Entity<ExtensionLlmProviderState>,
}
47
/// Mutable state shared between an extension LLM provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether credentials are currently available (env var, stored API key, or completed
    /// device-flow sign-in all set this to `true` in this file).
    is_authenticated: bool,
    /// Models offered by the provider; refreshed after a successful OAuth sign-in.
    available_models: Vec<LlmModelInfo>,
    /// Set of env var names that are allowed to be read for this provider.
    allowed_env_vars: HashSet<String>,
    /// If authenticated via env var, which one was used.
    env_var_name_used: Option<String>,
}
56
57impl EventEmitter<()> for ExtensionLlmProviderState {}
58
59impl ExtensionLanguageModelProvider {
60 pub fn new(
61 extension: WasmExtension,
62 provider_info: LlmProviderInfo,
63 models: Vec<LlmModelInfo>,
64 is_authenticated: bool,
65 icon_path: Option<SharedString>,
66 auth_config: Option<LanguageModelAuthConfig>,
67 cx: &mut App,
68 ) -> Self {
69 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
70
71 // Build set of allowed env vars for this provider
72 let settings = ExtensionSettings::get_global(cx);
73 let is_legacy_extension =
74 LEGACY_LLM_EXTENSION_IDS.contains(&extension.manifest.id.as_ref());
75
76 let mut allowed_env_vars = HashSet::default();
77 if let Some(env_vars) = auth_config.as_ref().and_then(|c| c.env_vars.as_ref()) {
78 for env_var_name in env_vars {
79 let key = format!("{}:{}", provider_id_string, env_var_name);
80 // For legacy extensions, auto-allow if env var is set (migration will persist this)
81 let env_var_is_set = std::env::var(env_var_name)
82 .map(|v| !v.is_empty())
83 .unwrap_or(false);
84 if settings.allowed_env_var_providers.contains(key.as_str())
85 || (is_legacy_extension && env_var_is_set)
86 {
87 allowed_env_vars.insert(env_var_name.clone());
88 }
89 }
90 }
91
92 // Check if any allowed env var is set
93 let env_var_name_used = allowed_env_vars.iter().find_map(|env_var_name| {
94 if let Ok(value) = std::env::var(env_var_name) {
95 if !value.is_empty() {
96 return Some(env_var_name.clone());
97 }
98 }
99 None
100 });
101
102 let is_authenticated = if env_var_name_used.is_some() {
103 true
104 } else {
105 is_authenticated
106 };
107
108 let state = cx.new(|_| ExtensionLlmProviderState {
109 is_authenticated,
110 available_models: models,
111 allowed_env_vars,
112 env_var_name_used,
113 });
114
115 Self {
116 extension,
117 provider_info,
118 icon_path,
119 auth_config,
120 state,
121 }
122 }
123
124 fn provider_id_string(&self) -> String {
125 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
126 }
127
128 /// The credential key used for storing the API key in the system keychain.
129 fn credential_key(&self) -> String {
130 format!("extension-llm-{}", self.provider_id_string())
131 }
132}
133
134impl LanguageModelProvider for ExtensionLanguageModelProvider {
135 fn id(&self) -> LanguageModelProviderId {
136 LanguageModelProviderId::from(self.provider_id_string())
137 }
138
139 fn name(&self) -> LanguageModelProviderName {
140 LanguageModelProviderName::from(format!("(Extension) {}", self.provider_info.name))
141 }
142
143 fn icon(&self) -> ui::IconName {
144 ui::IconName::ZedAssistant
145 }
146
147 fn icon_path(&self) -> Option<SharedString> {
148 self.icon_path.clone()
149 }
150
151 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
152 let state = self.state.read(cx);
153 state
154 .available_models
155 .iter()
156 .find(|m| m.is_default)
157 .or_else(|| state.available_models.first())
158 .map(|model_info| {
159 Arc::new(ExtensionLanguageModel {
160 extension: self.extension.clone(),
161 model_info: model_info.clone(),
162 provider_id: self.id(),
163 provider_name: self.name(),
164 provider_info: self.provider_info.clone(),
165 }) as Arc<dyn LanguageModel>
166 })
167 }
168
169 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
170 let state = self.state.read(cx);
171 state
172 .available_models
173 .iter()
174 .find(|m| m.is_default_fast)
175 .map(|model_info| {
176 Arc::new(ExtensionLanguageModel {
177 extension: self.extension.clone(),
178 model_info: model_info.clone(),
179 provider_id: self.id(),
180 provider_name: self.name(),
181 provider_info: self.provider_info.clone(),
182 }) as Arc<dyn LanguageModel>
183 })
184 }
185
186 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
187 let state = self.state.read(cx);
188 state
189 .available_models
190 .iter()
191 .map(|model_info| {
192 Arc::new(ExtensionLanguageModel {
193 extension: self.extension.clone(),
194 model_info: model_info.clone(),
195 provider_id: self.id(),
196 provider_name: self.name(),
197 provider_info: self.provider_info.clone(),
198 }) as Arc<dyn LanguageModel>
199 })
200 .collect()
201 }
202
203 fn is_authenticated(&self, cx: &App) -> bool {
204 // First check cached state
205 if self.state.read(cx).is_authenticated {
206 return true;
207 }
208
209 // Also check env var dynamically (in case settings changed after provider creation)
210 if let Some(ref auth_config) = self.auth_config {
211 if let Some(ref env_vars) = auth_config.env_vars {
212 let provider_id_string = self.provider_id_string();
213 let settings = ExtensionSettings::get_global(cx);
214 let is_legacy_extension =
215 LEGACY_LLM_EXTENSION_IDS.contains(&self.extension.manifest.id.as_ref());
216
217 for env_var_name in env_vars {
218 let key = format!("{}:{}", provider_id_string, env_var_name);
219 // For legacy extensions, auto-allow if env var is set
220 let env_var_is_set = std::env::var(env_var_name)
221 .map(|v| !v.is_empty())
222 .unwrap_or(false);
223 if settings.allowed_env_var_providers.contains(key.as_str())
224 || (is_legacy_extension && env_var_is_set)
225 {
226 if let Ok(value) = std::env::var(env_var_name) {
227 if !value.is_empty() {
228 return true;
229 }
230 }
231 }
232 }
233 }
234 }
235
236 false
237 }
238
239 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
240 // Check if already authenticated via is_authenticated
241 if self.is_authenticated(cx) {
242 return Task::ready(Ok(()));
243 }
244
245 // Not authenticated - return error indicating credentials not found
246 Task::ready(Err(AuthenticateError::CredentialsNotFound))
247 }
248
249 fn configuration_view(
250 &self,
251 _target_agent: ConfigurationViewTargetAgent,
252 window: &mut Window,
253 cx: &mut App,
254 ) -> AnyView {
255 let credential_key = self.credential_key();
256 let extension = self.extension.clone();
257 let extension_provider_id = self.provider_info.id.clone();
258 let full_provider_id = self.provider_id_string();
259 let state = self.state.clone();
260 let auth_config = self.auth_config.clone();
261
262 cx.new(|cx| {
263 ExtensionProviderConfigurationView::new(
264 credential_key,
265 extension,
266 extension_provider_id,
267 full_provider_id,
268 auth_config,
269 state,
270 window,
271 cx,
272 )
273 })
274 .into()
275 }
276
277 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
278 let extension = self.extension.clone();
279 let provider_id = self.provider_info.id.clone();
280 let state = self.state.clone();
281 let credential_key = self.credential_key();
282
283 let credentials_provider = <dyn CredentialsProvider>::global(cx);
284
285 cx.spawn(async move |cx| {
286 // Delete from system keychain
287 credentials_provider
288 .delete_credentials(&credential_key, cx)
289 .await
290 .log_err();
291
292 // Call extension's reset_credentials
293 let result = extension
294 .call(|extension, store| {
295 async move {
296 extension
297 .call_llm_provider_reset_credentials(store, &provider_id)
298 .await
299 }
300 .boxed()
301 })
302 .await;
303
304 // Update state
305 cx.update(|cx| {
306 state.update(cx, |state, _| {
307 state.is_authenticated = false;
308 });
309 })?;
310
311 match result {
312 Ok(Ok(Ok(()))) => Ok(()),
313 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
314 Ok(Err(e)) => Err(e),
315 Err(e) => Err(e),
316 }
317 })
318 }
319}
320
321impl LanguageModelProviderState for ExtensionLanguageModelProvider {
322 type ObservableEntity = ExtensionLlmProviderState;
323
324 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
325 Some(self.state.clone())
326 }
327
328 fn subscribe<T: 'static>(
329 &self,
330 cx: &mut Context<T>,
331 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
332 ) -> Option<Subscription> {
333 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
334 }
335}
336
337/// Configuration view for extension-based LLM providers.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is stored.
    credential_key: String,
    /// Extension used for settings markdown and device-flow sign-in calls.
    extension: WasmExtension,
    /// Provider id as known to the extension (passed to extension calls).
    extension_provider_id: String,
    /// `"{extension_id}:{provider_id}"`, used as the prefix of env-var permission keys.
    full_provider_id: String,
    /// Authentication options declared by the extension.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Provider state shared with `ExtensionLanguageModelProvider`.
    state: Entity<ExtensionLlmProviderState>,
    /// Rendered settings text from the extension, if it provided any.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor where the user types an API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings markdown request finishes.
    loading_settings: bool,
    /// True until the keychain lookup (or env-var short-circuit) finishes.
    loading_credentials: bool,
    /// True while a device-flow sign-in is running.
    oauth_in_progress: bool,
    /// Last sign-in error message, shown under the sign-in button.
    oauth_error: Option<String>,
    /// User code to enter in the browser during device-flow sign-in.
    device_user_code: Option<String>,
    /// Kept alive so the view stays subscribed for its lifetime.
    _subscriptions: Vec<Subscription>,
}
354
355impl ExtensionProviderConfigurationView {
356 fn new(
357 credential_key: String,
358 extension: WasmExtension,
359 extension_provider_id: String,
360 full_provider_id: String,
361 auth_config: Option<LanguageModelAuthConfig>,
362 state: Entity<ExtensionLlmProviderState>,
363 window: &mut Window,
364 cx: &mut Context<Self>,
365 ) -> Self {
366 // Subscribe to state changes
367 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
368 cx.notify();
369 });
370
371 // Create API key editor
372 let api_key_editor = cx.new(|cx| {
373 let mut editor = Editor::single_line(window, cx);
374 editor.set_placeholder_text("Enter API key...", window, cx);
375 editor
376 });
377
378 let mut this = Self {
379 credential_key,
380 extension,
381 extension_provider_id,
382 full_provider_id,
383 auth_config,
384 state,
385 settings_markdown: None,
386 api_key_editor,
387 loading_settings: true,
388 loading_credentials: true,
389 oauth_in_progress: false,
390 oauth_error: None,
391 device_user_code: None,
392 _subscriptions: vec![state_subscription],
393 };
394
395 // Load settings text from extension
396 this.load_settings_text(cx);
397
398 // Load existing credentials
399 this.load_credentials(cx);
400
401 this
402 }
403
404 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
405 let extension = self.extension.clone();
406 let provider_id = self.extension_provider_id.clone();
407
408 cx.spawn(async move |this, cx| {
409 let result = extension
410 .call({
411 let provider_id = provider_id.clone();
412 |ext, store| {
413 async move {
414 ext.call_llm_provider_settings_markdown(store, &provider_id)
415 .await
416 }
417 .boxed()
418 }
419 })
420 .await;
421
422 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
423
424 this.update(cx, |this, cx| {
425 this.loading_settings = false;
426 if let Some(text) = settings_text {
427 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
428 this.settings_markdown = Some(markdown);
429 }
430 cx.notify();
431 })
432 .log_err();
433 })
434 .detach();
435 }
436
437 fn load_credentials(&mut self, cx: &mut Context<Self>) {
438 let credential_key = self.credential_key.clone();
439 let credentials_provider = <dyn CredentialsProvider>::global(cx);
440 let state = self.state.clone();
441
442 // Check if we should use env var (already set in state during provider construction)
443 let using_env_var = self.state.read(cx).env_var_name_used.is_some();
444
445 cx.spawn(async move |this, cx| {
446 // If using env var, we're already authenticated
447 if using_env_var {
448 this.update(cx, |this, cx| {
449 this.loading_credentials = false;
450 cx.notify();
451 })
452 .log_err();
453 return;
454 }
455
456 let credentials = credentials_provider
457 .read_credentials(&credential_key, cx)
458 .await
459 .log_err()
460 .flatten();
461
462 let has_credentials = credentials.is_some();
463
464 // Update authentication state based on stored credentials
465 cx.update(|cx| {
466 state.update(cx, |state, cx| {
467 state.is_authenticated = has_credentials;
468 cx.notify();
469 });
470 })
471 .log_err();
472
473 this.update(cx, |this, cx| {
474 this.loading_credentials = false;
475 cx.notify();
476 })
477 .log_err();
478 })
479 .detach();
480 }
481
482 fn toggle_env_var_permission(&mut self, env_var_name: String, cx: &mut Context<Self>) {
483 let full_provider_id = self.full_provider_id.clone();
484 let settings_key: Arc<str> = format!("{}:{}", full_provider_id, env_var_name).into();
485
486 let state = self.state.clone();
487 let currently_allowed = self.state.read(cx).allowed_env_vars.contains(&env_var_name);
488
489 // Update settings file
490 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
491 let settings_key = settings_key.clone();
492 move |settings, _| {
493 let allowed = settings
494 .extension
495 .allowed_env_var_providers
496 .get_or_insert_with(Vec::new);
497
498 if currently_allowed {
499 allowed.retain(|id| id.as_ref() != settings_key.as_ref());
500 } else {
501 if !allowed
502 .iter()
503 .any(|id| id.as_ref() == settings_key.as_ref())
504 {
505 allowed.push(settings_key.clone());
506 }
507 }
508 }
509 });
510
511 // Update local state
512 let new_allowed = !currently_allowed;
513 let env_var_name_clone = env_var_name.clone();
514
515 state.update(cx, |state, cx| {
516 if new_allowed {
517 state.allowed_env_vars.insert(env_var_name_clone.clone());
518 // Check if this env var is set and update env_var_name_used
519 if let Ok(value) = std::env::var(&env_var_name_clone) {
520 if !value.is_empty() && state.env_var_name_used.is_none() {
521 state.env_var_name_used = Some(env_var_name_clone);
522 state.is_authenticated = true;
523 }
524 }
525 } else {
526 state.allowed_env_vars.remove(&env_var_name_clone);
527 // If this was the env var being used, clear it and find another
528 if state.env_var_name_used.as_ref() == Some(&env_var_name_clone) {
529 state.env_var_name_used = state.allowed_env_vars.iter().find_map(|var| {
530 if let Ok(value) = std::env::var(var) {
531 if !value.is_empty() {
532 return Some(var.clone());
533 }
534 }
535 None
536 });
537 if state.env_var_name_used.is_none() {
538 // No env var auth available, need to check keychain
539 state.is_authenticated = false;
540 }
541 }
542 }
543 cx.notify();
544 });
545
546 // If all env vars are being disabled, reload credentials from keychain
547 if !new_allowed && self.state.read(cx).allowed_env_vars.is_empty() {
548 self.reload_keychain_credentials(cx);
549 }
550
551 cx.notify();
552 }
553
554 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
555 let credential_key = self.credential_key.clone();
556 let credentials_provider = <dyn CredentialsProvider>::global(cx);
557 let state = self.state.clone();
558
559 cx.spawn(async move |_this, cx| {
560 let credentials = credentials_provider
561 .read_credentials(&credential_key, cx)
562 .await
563 .log_err()
564 .flatten();
565
566 let has_credentials = credentials.is_some();
567
568 cx.update(|cx| {
569 state.update(cx, |state, cx| {
570 state.is_authenticated = has_credentials;
571 cx.notify();
572 });
573 })
574 .log_err();
575 })
576 .detach();
577 }
578
579 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
580 let api_key = self.api_key_editor.read(cx).text(cx);
581 if api_key.is_empty() {
582 return;
583 }
584
585 // Clear the editor
586 self.api_key_editor
587 .update(cx, |editor, cx| editor.set_text("", window, cx));
588
589 let credential_key = self.credential_key.clone();
590 let credentials_provider = <dyn CredentialsProvider>::global(cx);
591 let state = self.state.clone();
592
593 cx.spawn(async move |_this, cx| {
594 // Store in system keychain
595 credentials_provider
596 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
597 .await
598 .log_err();
599
600 // Update state to authenticated
601 cx.update(|cx| {
602 state.update(cx, |state, cx| {
603 state.is_authenticated = true;
604 cx.notify();
605 });
606 })
607 .log_err();
608 })
609 .detach();
610 }
611
612 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
613 // Clear the editor
614 self.api_key_editor
615 .update(cx, |editor, cx| editor.set_text("", window, cx));
616
617 let credential_key = self.credential_key.clone();
618 let credentials_provider = <dyn CredentialsProvider>::global(cx);
619 let state = self.state.clone();
620
621 cx.spawn(async move |_this, cx| {
622 // Delete from system keychain
623 credentials_provider
624 .delete_credentials(&credential_key, cx)
625 .await
626 .log_err();
627
628 // Update state to unauthenticated
629 cx.update(|cx| {
630 state.update(cx, |state, cx| {
631 state.is_authenticated = false;
632 cx.notify();
633 });
634 })
635 .log_err();
636 })
637 .detach();
638 }
639
640 fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
641 if self.oauth_in_progress {
642 return;
643 }
644
645 self.oauth_in_progress = true;
646 self.oauth_error = None;
647 self.device_user_code = None;
648 cx.notify();
649
650 let extension = self.extension.clone();
651 let provider_id = self.extension_provider_id.clone();
652 let state = self.state.clone();
653
654 cx.spawn(async move |this, cx| {
655 // Step 1: Start device flow - opens browser and returns user code
656 let start_result = extension
657 .call({
658 let provider_id = provider_id.clone();
659 |ext, store| {
660 async move {
661 ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
662 .await
663 }
664 .boxed()
665 }
666 })
667 .await;
668
669 let user_code = match start_result {
670 Ok(Ok(Ok(code))) => code,
671 Ok(Ok(Err(e))) => {
672 log::error!("Device flow start failed: {}", e);
673 this.update(cx, |this, cx| {
674 this.oauth_in_progress = false;
675 this.oauth_error = Some(e);
676 cx.notify();
677 })
678 .log_err();
679 return;
680 }
681 Ok(Err(e)) | Err(e) => {
682 log::error!("Device flow start error: {}", e);
683 this.update(cx, |this, cx| {
684 this.oauth_in_progress = false;
685 this.oauth_error = Some(e.to_string());
686 cx.notify();
687 })
688 .log_err();
689 return;
690 }
691 };
692
693 // Update UI to show the user code before polling
694 this.update(cx, |this, cx| {
695 this.device_user_code = Some(user_code);
696 cx.notify();
697 })
698 .log_err();
699
700 // Step 2: Poll for authentication completion
701 let poll_result = extension
702 .call({
703 let provider_id = provider_id.clone();
704 |ext, store| {
705 async move {
706 ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
707 .await
708 }
709 .boxed()
710 }
711 })
712 .await;
713
714 let error_message = match poll_result {
715 Ok(Ok(Ok(()))) => {
716 // After successful auth, refresh the models list
717 let models_result = extension
718 .call({
719 let provider_id = provider_id.clone();
720 |ext, store| {
721 async move {
722 ext.call_llm_provider_models(store, &provider_id).await
723 }
724 .boxed()
725 }
726 })
727 .await;
728
729 let new_models: Vec<LlmModelInfo> = match models_result {
730 Ok(Ok(Ok(models))) => models,
731 _ => Vec::new(),
732 };
733
734 cx.update(|cx| {
735 state.update(cx, |state, cx| {
736 state.is_authenticated = true;
737 state.available_models = new_models;
738 cx.notify();
739 });
740 })
741 .log_err();
742 None
743 }
744 Ok(Ok(Err(e))) => {
745 log::error!("Device flow poll failed: {}", e);
746 Some(e)
747 }
748 Ok(Err(e)) | Err(e) => {
749 log::error!("Device flow poll error: {}", e);
750 Some(e.to_string())
751 }
752 };
753
754 this.update(cx, |this, cx| {
755 this.oauth_in_progress = false;
756 this.oauth_error = error_message;
757 this.device_user_code = None;
758 cx.notify();
759 })
760 .log_err();
761 })
762 .detach();
763 }
764
765 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
766 self.state.read(cx).is_authenticated
767 }
768
769 fn has_oauth_config(&self) -> bool {
770 self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
771 }
772
773 fn oauth_config(&self) -> Option<&OAuthConfig> {
774 self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
775 }
776
777 fn has_api_key_config(&self) -> bool {
778 // API key is available if there's a credential_label or no oauth-only config
779 self.auth_config
780 .as_ref()
781 .map(|c| c.credential_label.is_some() || c.oauth.is_none())
782 .unwrap_or(true)
783 }
784}
785
786impl gpui::Render for ExtensionProviderConfigurationView {
787 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
788 let is_loading = self.loading_settings || self.loading_credentials;
789 let is_authenticated = self.is_authenticated(cx);
790 let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone();
791 let env_var_name_used = self.state.read(cx).env_var_name_used.clone();
792 let has_oauth = self.has_oauth_config();
793 let has_api_key = self.has_api_key_config();
794
795 if is_loading {
796 return v_flex()
797 .gap_2()
798 .child(Label::new("Loading...").color(Color::Muted))
799 .into_any_element();
800 }
801
802 let mut content = v_flex().gap_4().size_full();
803
804 // Render settings markdown if available
805 if let Some(markdown) = &self.settings_markdown {
806 let style = settings_markdown_style(_window, cx);
807 content = content.child(MarkdownElement::new(markdown.clone(), style));
808 }
809
810 // Render env var checkboxes - one for each env var the extension declares
811 if let Some(auth_config) = &self.auth_config {
812 if let Some(env_vars) = &auth_config.env_vars {
813 for env_var_name in env_vars {
814 let is_allowed = allowed_env_vars.contains(env_var_name);
815 let checkbox_label =
816 format!("Read API key from {} environment variable", env_var_name);
817 let env_var_for_click = env_var_name.clone();
818
819 content = content.child(
820 h_flex()
821 .gap_2()
822 .child(
823 ui::Checkbox::new(
824 SharedString::from(format!("env-var-{}", env_var_name)),
825 is_allowed.into(),
826 )
827 .on_click(cx.listener(
828 move |this, _, _window, cx| {
829 this.toggle_env_var_permission(
830 env_var_for_click.clone(),
831 cx,
832 );
833 },
834 )),
835 )
836 .child(Label::new(checkbox_label).size(LabelSize::Small)),
837 );
838 }
839
840 // Show status if any env var is being used
841 if let Some(used_var) = &env_var_name_used {
842 let tooltip_label = format!(
843 "To reset this API key, unset the {} environment variable.",
844 used_var
845 );
846 content = content.child(
847 h_flex()
848 .mt_0p5()
849 .p_1()
850 .justify_between()
851 .rounded_md()
852 .border_1()
853 .border_color(cx.theme().colors().border)
854 .bg(cx.theme().colors().background)
855 .child(
856 h_flex()
857 .flex_1()
858 .min_w_0()
859 .gap_1()
860 .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
861 .child(
862 Label::new(format!(
863 "API key set in {} environment variable",
864 used_var
865 ))
866 .truncate(),
867 ),
868 )
869 .child(
870 ui::Button::new("reset-key", "Reset Key")
871 .label_size(LabelSize::Small)
872 .icon(ui::IconName::Undo)
873 .icon_size(ui::IconSize::Small)
874 .icon_color(Color::Muted)
875 .icon_position(ui::IconPosition::Start)
876 .disabled(true)
877 .tooltip(ui::Tooltip::text(tooltip_label)),
878 ),
879 );
880 return content.into_any_element();
881 }
882 }
883 }
884
885 // If authenticated, show success state with sign out option
886 if is_authenticated && env_var_name_used.is_none() {
887 let reset_label = if has_oauth && !has_api_key {
888 "Sign Out"
889 } else {
890 "Reset Key"
891 };
892
893 let status_label = if has_oauth && !has_api_key {
894 "Signed in"
895 } else {
896 "API key configured"
897 };
898
899 content = content.child(
900 h_flex()
901 .mt_0p5()
902 .p_1()
903 .justify_between()
904 .rounded_md()
905 .border_1()
906 .border_color(cx.theme().colors().border)
907 .bg(cx.theme().colors().background)
908 .child(
909 h_flex()
910 .flex_1()
911 .min_w_0()
912 .gap_1()
913 .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
914 .child(Label::new(status_label).truncate()),
915 )
916 .child(
917 ui::Button::new("reset-key", reset_label)
918 .label_size(LabelSize::Small)
919 .icon(ui::IconName::Undo)
920 .icon_size(ui::IconSize::Small)
921 .icon_color(Color::Muted)
922 .icon_position(ui::IconPosition::Start)
923 .on_click(cx.listener(|this, _, window, cx| {
924 this.reset_api_key(window, cx);
925 })),
926 ),
927 );
928
929 return content.into_any_element();
930 }
931
932 // Not authenticated - show available auth options
933 if env_var_name_used.is_none() {
934 // Render OAuth sign-in button if configured
935 if has_oauth {
936 let oauth_config = self.oauth_config();
937 let button_label = oauth_config
938 .and_then(|c| c.sign_in_button_label.clone())
939 .unwrap_or_else(|| "Sign In".to_string());
940 let button_icon = oauth_config
941 .and_then(|c| c.sign_in_button_icon.as_ref())
942 .and_then(|icon_name| match icon_name.as_str() {
943 "github" => Some(ui::IconName::Github),
944 _ => None,
945 });
946
947 let oauth_in_progress = self.oauth_in_progress;
948
949 let oauth_error = self.oauth_error.clone();
950
951 content = content.child(
952 v_flex()
953 .gap_2()
954 .child({
955 let mut button = ui::Button::new("oauth-sign-in", button_label)
956 .full_width()
957 .style(ui::ButtonStyle::Outlined)
958 .disabled(oauth_in_progress)
959 .on_click(cx.listener(|this, _, _window, cx| {
960 this.start_oauth_sign_in(cx);
961 }));
962 if let Some(icon) = button_icon {
963 button = button
964 .icon(icon)
965 .icon_position(ui::IconPosition::Start)
966 .icon_size(ui::IconSize::Small)
967 .icon_color(Color::Muted);
968 }
969 button
970 })
971 .when(oauth_in_progress, |this| {
972 let user_code = self.device_user_code.clone();
973 this.child(
974 v_flex()
975 .gap_1()
976 .when_some(user_code, |this, code| {
977 let copied = cx
978 .read_from_clipboard()
979 .map(|item| item.text().as_ref() == Some(&code))
980 .unwrap_or(false);
981 let code_for_click = code.clone();
982 this.child(
983 h_flex()
984 .gap_1()
985 .child(
986 Label::new("Enter code:")
987 .size(LabelSize::Small)
988 .color(Color::Muted),
989 )
990 .child(
991 h_flex()
992 .gap_1()
993 .px_1()
994 .border_1()
995 .border_color(cx.theme().colors().border)
996 .rounded_sm()
997 .cursor_pointer()
998 .on_mouse_down(
999 MouseButton::Left,
1000 move |_, window, cx| {
1001 cx.write_to_clipboard(
1002 ClipboardItem::new_string(
1003 code_for_click.clone(),
1004 ),
1005 );
1006 window.refresh();
1007 },
1008 )
1009 .child(
1010 Label::new(code)
1011 .size(LabelSize::Small)
1012 .color(Color::Accent),
1013 )
1014 .child(
1015 ui::Icon::new(if copied {
1016 ui::IconName::Check
1017 } else {
1018 ui::IconName::Copy
1019 })
1020 .size(ui::IconSize::Small)
1021 .color(if copied {
1022 Color::Success
1023 } else {
1024 Color::Muted
1025 }),
1026 ),
1027 ),
1028 )
1029 })
1030 .child(
1031 Label::new("Waiting for authorization in browser...")
1032 .size(LabelSize::Small)
1033 .color(Color::Muted),
1034 ),
1035 )
1036 })
1037 .when_some(oauth_error, |this, error| {
1038 this.child(
1039 v_flex()
1040 .gap_1()
1041 .child(
1042 h_flex()
1043 .gap_2()
1044 .child(
1045 ui::Icon::new(ui::IconName::Warning)
1046 .color(Color::Error)
1047 .size(ui::IconSize::Small),
1048 )
1049 .child(
1050 Label::new("Authentication failed")
1051 .color(Color::Error)
1052 .size(LabelSize::Small),
1053 ),
1054 )
1055 .child(
1056 div().pl_6().child(
1057 Label::new(error)
1058 .color(Color::Error)
1059 .size(LabelSize::Small),
1060 ),
1061 ),
1062 )
1063 }),
1064 );
1065 }
1066
1067 // Render API key input if configured (and we have both options, show a separator)
1068 if has_api_key {
1069 if has_oauth {
1070 content = content.child(
1071 h_flex()
1072 .gap_2()
1073 .items_center()
1074 .child(div().h_px().flex_1().bg(cx.theme().colors().border))
1075 .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
1076 .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
1077 );
1078 }
1079
1080 let credential_label = self
1081 .auth_config
1082 .as_ref()
1083 .and_then(|c| c.credential_label.clone())
1084 .unwrap_or_else(|| "API Key".to_string());
1085
1086 content = content.child(
1087 v_flex()
1088 .gap_2()
1089 .on_action(cx.listener(Self::save_api_key))
1090 .child(
1091 Label::new(credential_label)
1092 .size(LabelSize::Small)
1093 .color(Color::Muted),
1094 )
1095 .child(self.api_key_editor.clone())
1096 .child(
1097 Label::new("Enter your API key and press Enter to save")
1098 .size(LabelSize::Small)
1099 .color(Color::Muted),
1100 ),
1101 );
1102 }
1103 }
1104
1105 content.into_any_element()
1106 }
1107}
1108
1109impl Focusable for ExtensionProviderConfigurationView {
1110 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1111 self.api_key_editor.focus_handle(cx)
1112 }
1113}
1114
1115fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1116 let theme_settings = ThemeSettings::get_global(cx);
1117 let colors = cx.theme().colors();
1118 let mut text_style = window.text_style();
1119 text_style.refine(&TextStyleRefinement {
1120 font_family: Some(theme_settings.ui_font.family.clone()),
1121 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1122 font_features: Some(theme_settings.ui_font.features.clone()),
1123 color: Some(colors.text),
1124 ..Default::default()
1125 });
1126
1127 MarkdownStyle {
1128 base_text_style: text_style,
1129 selection_background_color: colors.element_selection_background,
1130 inline_code: TextStyleRefinement {
1131 background_color: Some(colors.editor_background),
1132 ..Default::default()
1133 },
1134 link: TextStyleRefinement {
1135 color: Some(colors.text_accent),
1136 underline: Some(UnderlineStyle {
1137 color: Some(colors.text_accent.opacity(0.5)),
1138 thickness: px(1.),
1139 ..Default::default()
1140 }),
1141 ..Default::default()
1142 },
1143 syntax: cx.theme().syntax().clone(),
1144 ..Default::default()
1145 }
1146}
1147
1148/// An extension-based language model.
pub struct ExtensionLanguageModel {
    /// The extension that serves requests for this model.
    extension: WasmExtension,
    /// Model metadata and capabilities reported by the extension.
    model_info: LlmModelInfo,
    /// Id of the owning provider (`"{extension_id}:{provider_id}"`).
    provider_id: LanguageModelProviderId,
    /// Display name of the owning provider.
    provider_name: LanguageModelProviderName,
    /// Provider metadata from the extension; its `id` is passed back on extension calls.
    provider_info: LlmProviderInfo,
}
1156
1157impl LanguageModel for ExtensionLanguageModel {
1158 fn id(&self) -> LanguageModelId {
1159 LanguageModelId::from(self.model_info.id.clone())
1160 }
1161
1162 fn name(&self) -> LanguageModelName {
1163 LanguageModelName::from(self.model_info.name.clone())
1164 }
1165
1166 fn provider_id(&self) -> LanguageModelProviderId {
1167 self.provider_id.clone()
1168 }
1169
1170 fn provider_name(&self) -> LanguageModelProviderName {
1171 self.provider_name.clone()
1172 }
1173
1174 fn telemetry_id(&self) -> String {
1175 format!("extension-{}", self.model_info.id)
1176 }
1177
1178 fn supports_images(&self) -> bool {
1179 self.model_info.capabilities.supports_images
1180 }
1181
1182 fn supports_tools(&self) -> bool {
1183 self.model_info.capabilities.supports_tools
1184 }
1185
1186 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1187 match choice {
1188 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1189 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1190 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1191 }
1192 }
1193
1194 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1195 match self.model_info.capabilities.tool_input_format {
1196 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1197 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1198 }
1199 }
1200
1201 fn max_token_count(&self) -> u64 {
1202 self.model_info.max_token_count
1203 }
1204
1205 fn max_output_tokens(&self) -> Option<u64> {
1206 self.model_info.max_output_tokens
1207 }
1208
1209 fn count_tokens(
1210 &self,
1211 request: LanguageModelRequest,
1212 cx: &App,
1213 ) -> BoxFuture<'static, Result<u64>> {
1214 let extension = self.extension.clone();
1215 let provider_id = self.provider_info.id.clone();
1216 let model_id = self.model_info.id.clone();
1217
1218 let wit_request = convert_request_to_wit(request);
1219
1220 cx.background_spawn(async move {
1221 extension
1222 .call({
1223 let provider_id = provider_id.clone();
1224 let model_id = model_id.clone();
1225 let wit_request = wit_request.clone();
1226 |ext, store| {
1227 async move {
1228 let count = ext
1229 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1230 .await?
1231 .map_err(|e| anyhow!("{}", e))?;
1232 Ok(count)
1233 }
1234 .boxed()
1235 }
1236 })
1237 .await?
1238 })
1239 .boxed()
1240 }
1241
1242 fn stream_completion(
1243 &self,
1244 request: LanguageModelRequest,
1245 _cx: &AsyncApp,
1246 ) -> BoxFuture<
1247 'static,
1248 Result<
1249 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1250 LanguageModelCompletionError,
1251 >,
1252 > {
1253 let extension = self.extension.clone();
1254 let provider_id = self.provider_info.id.clone();
1255 let model_id = self.model_info.id.clone();
1256
1257 let wit_request = convert_request_to_wit(request);
1258
1259 async move {
1260 // Start the stream
1261 let stream_id_result = extension
1262 .call({
1263 let provider_id = provider_id.clone();
1264 let model_id = model_id.clone();
1265 let wit_request = wit_request.clone();
1266 |ext, store| {
1267 async move {
1268 let id = ext
1269 .call_llm_stream_completion_start(
1270 store,
1271 &provider_id,
1272 &model_id,
1273 &wit_request,
1274 )
1275 .await?
1276 .map_err(|e| anyhow!("{}", e))?;
1277 Ok(id)
1278 }
1279 .boxed()
1280 }
1281 })
1282 .await;
1283
1284 let stream_id = stream_id_result
1285 .map_err(LanguageModelCompletionError::Other)?
1286 .map_err(LanguageModelCompletionError::Other)?;
1287
1288 // Create a stream that polls for events
1289 let stream = futures::stream::unfold(
1290 (extension.clone(), stream_id, false),
1291 move |(extension, stream_id, done)| async move {
1292 if done {
1293 return None;
1294 }
1295
1296 let result = extension
1297 .call({
1298 let stream_id = stream_id.clone();
1299 |ext, store| {
1300 async move {
1301 let event = ext
1302 .call_llm_stream_completion_next(store, &stream_id)
1303 .await?
1304 .map_err(|e| anyhow!("{}", e))?;
1305 Ok(event)
1306 }
1307 .boxed()
1308 }
1309 })
1310 .await
1311 .and_then(|inner| inner);
1312
1313 match result {
1314 Ok(Some(event)) => {
1315 let converted = convert_completion_event(event);
1316 let is_done =
1317 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1318 Some((converted, (extension, stream_id, is_done)))
1319 }
1320 Ok(None) => {
1321 // Stream complete, close it
1322 let _ = extension
1323 .call({
1324 let stream_id = stream_id.clone();
1325 |ext, store| {
1326 async move {
1327 ext.call_llm_stream_completion_close(store, &stream_id)
1328 .await?;
1329 Ok::<(), anyhow::Error>(())
1330 }
1331 .boxed()
1332 }
1333 })
1334 .await;
1335 None
1336 }
1337 Err(e) => Some((
1338 Err(LanguageModelCompletionError::Other(e)),
1339 (extension, stream_id, true),
1340 )),
1341 }
1342 },
1343 );
1344
1345 Ok(stream.boxed())
1346 }
1347 .boxed()
1348 }
1349
1350 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1351 // Extensions can implement this via llm_cache_configuration
1352 None
1353 }
1354}
1355
1356fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1357 use language_model::{MessageContent, Role};
1358
1359 let messages: Vec<LlmRequestMessage> = request
1360 .messages
1361 .into_iter()
1362 .map(|msg| {
1363 let role = match msg.role {
1364 Role::User => LlmMessageRole::User,
1365 Role::Assistant => LlmMessageRole::Assistant,
1366 Role::System => LlmMessageRole::System,
1367 };
1368
1369 let content: Vec<LlmMessageContent> = msg
1370 .content
1371 .into_iter()
1372 .map(|c| match c {
1373 MessageContent::Text(text) => LlmMessageContent::Text(text),
1374 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1375 source: image.source.to_string(),
1376 width: Some(image.size.width.0 as u32),
1377 height: Some(image.size.height.0 as u32),
1378 }),
1379 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1380 id: tool_use.id.to_string(),
1381 name: tool_use.name.to_string(),
1382 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1383 thought_signature: tool_use.thought_signature,
1384 }),
1385 MessageContent::ToolResult(tool_result) => {
1386 let content = match tool_result.content {
1387 language_model::LanguageModelToolResultContent::Text(text) => {
1388 LlmToolResultContent::Text(text.to_string())
1389 }
1390 language_model::LanguageModelToolResultContent::Image(image) => {
1391 LlmToolResultContent::Image(LlmImageData {
1392 source: image.source.to_string(),
1393 width: Some(image.size.width.0 as u32),
1394 height: Some(image.size.height.0 as u32),
1395 })
1396 }
1397 };
1398 LlmMessageContent::ToolResult(LlmToolResult {
1399 tool_use_id: tool_result.tool_use_id.to_string(),
1400 tool_name: tool_result.tool_name.to_string(),
1401 is_error: tool_result.is_error,
1402 content,
1403 })
1404 }
1405 MessageContent::Thinking { text, signature } => {
1406 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1407 }
1408 MessageContent::RedactedThinking(data) => {
1409 LlmMessageContent::RedactedThinking(data)
1410 }
1411 })
1412 .collect();
1413
1414 LlmRequestMessage {
1415 role,
1416 content,
1417 cache: msg.cache,
1418 }
1419 })
1420 .collect();
1421
1422 let tools: Vec<LlmToolDefinition> = request
1423 .tools
1424 .into_iter()
1425 .map(|tool| LlmToolDefinition {
1426 name: tool.name,
1427 description: tool.description,
1428 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1429 })
1430 .collect();
1431
1432 let tool_choice = request.tool_choice.map(|tc| match tc {
1433 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1434 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1435 LanguageModelToolChoice::None => LlmToolChoice::None,
1436 });
1437
1438 LlmCompletionRequest {
1439 messages,
1440 tools,
1441 tool_choice,
1442 stop_sequences: request.stop,
1443 temperature: request.temperature,
1444 thinking_allowed: false,
1445 max_tokens: None,
1446 }
1447}
1448
1449fn convert_completion_event(
1450 event: LlmCompletionEvent,
1451) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1452 match event {
1453 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1454 message_id: String::new(),
1455 }),
1456 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1457 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1458 text: thinking.text,
1459 signature: thinking.signature,
1460 }),
1461 LlmCompletionEvent::RedactedThinking(data) => {
1462 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1463 }
1464 LlmCompletionEvent::ToolUse(tool_use) => {
1465 let raw_input = tool_use.input.clone();
1466 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1467 Ok(LanguageModelCompletionEvent::ToolUse(
1468 LanguageModelToolUse {
1469 id: LanguageModelToolUseId::from(tool_use.id),
1470 name: tool_use.name.into(),
1471 raw_input,
1472 input,
1473 is_input_complete: true,
1474 thought_signature: tool_use.thought_signature,
1475 },
1476 ))
1477 }
1478 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1479 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1480 id: LanguageModelToolUseId::from(error.id),
1481 tool_name: error.tool_name.into(),
1482 raw_input: error.raw_input.into(),
1483 json_parse_error: error.error,
1484 })
1485 }
1486 LlmCompletionEvent::Stop(reason) => {
1487 let stop_reason = match reason {
1488 LlmStopReason::EndTurn => StopReason::EndTurn,
1489 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1490 LlmStopReason::ToolUse => StopReason::ToolUse,
1491 LlmStopReason::Refusal => StopReason::Refusal,
1492 };
1493 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1494 }
1495 LlmCompletionEvent::Usage(usage) => {
1496 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1497 input_tokens: usage.input_tokens,
1498 output_tokens: usage.output_tokens,
1499 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1500 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1501 }))
1502 }
1503 LlmCompletionEvent::ReasoningDetails(json) => {
1504 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1505 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1506 ))
1507 }
1508 }
1509}