1use crate::ExtensionSettings;
2use crate::LEGACY_LLM_EXTENSION_IDS;
3use crate::wasm_host::WasmExtension;
4use collections::HashSet;
5
6use crate::wasm_host::wit::{
7 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
8 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
9 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
10 LlmToolUse,
11};
12use anyhow::{Result, anyhow};
13use credentials_provider::CredentialsProvider;
14use editor::Editor;
15use extension::{LanguageModelAuthConfig, OAuthConfig};
16use futures::future::BoxFuture;
17use futures::stream::BoxStream;
18use futures::{FutureExt, StreamExt};
19use gpui::Focusable;
20use gpui::{
21 AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
22 MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
23};
24use language_model::tool_schema::LanguageModelToolSchemaFormat;
25use language_model::{
26 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
27 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
28 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
29 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
30 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
31};
32use markdown::{Markdown, MarkdownElement, MarkdownStyle};
33use settings::Settings;
34use std::sync::Arc;
35use theme::ThemeSettings;
36use ui::{Label, LabelSize, prelude::*};
37use util::ResultExt as _;
38
/// An extension-based language model provider.
///
/// Bundles the WASM extension handle, the provider metadata it reported, and the
/// shared state entity that tracks authentication and the available models.
pub struct ExtensionLanguageModelProvider {
    /// Handle to the WASM extension that backs this provider.
    pub extension: WasmExtension,
    /// Provider metadata; its `id` and `name` are used to build the provider id and display name.
    pub provider_info: LlmProviderInfo,
    /// Optional icon path, returned unchanged from `icon_path()`.
    icon_path: Option<SharedString>,
    /// Authentication options (env vars, OAuth, credential label), if any were supplied.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared state entity, also handed to the configuration view.
    state: Entity<ExtensionLlmProviderState>,
}
47
/// Mutable state shared between a provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Cached authentication flag; updated by the env-var, keychain and OAuth flows in this file.
    is_authenticated: bool,
    /// Models currently offered by the provider.
    available_models: Vec<LlmModelInfo>,
    /// Set of env var names that are allowed to be read for this provider.
    allowed_env_vars: HashSet<String>,
    /// If authenticated via env var, which one was used.
    env_var_name_used: Option<String>,
}
56
// Allows the state entity to emit unit (`()`) events.
// NOTE(review): nothing in the visible part of this file calls `cx.emit(())`;
// state changes are signalled with `cx.notify()` — confirm whether anything emits.
impl EventEmitter<()> for ExtensionLlmProviderState {}
58
59impl ExtensionLanguageModelProvider {
60 pub fn new(
61 extension: WasmExtension,
62 provider_info: LlmProviderInfo,
63 models: Vec<LlmModelInfo>,
64 is_authenticated: bool,
65 icon_path: Option<SharedString>,
66 auth_config: Option<LanguageModelAuthConfig>,
67 cx: &mut App,
68 ) -> Self {
69 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
70
71 // Build set of allowed env vars for this provider
72 let settings = ExtensionSettings::get_global(cx);
73 let is_legacy_extension =
74 LEGACY_LLM_EXTENSION_IDS.contains(&extension.manifest.id.as_ref());
75
76 let mut allowed_env_vars = HashSet::default();
77 if let Some(env_vars) = auth_config.as_ref().and_then(|c| c.env_vars.as_ref()) {
78 for env_var_name in env_vars {
79 let key = format!("{}:{}", provider_id_string, env_var_name);
80 // For legacy extensions, auto-allow if env var is set (migration will persist this)
81 let env_var_is_set = std::env::var(env_var_name)
82 .map(|v| !v.is_empty())
83 .unwrap_or(false);
84 if settings.allowed_env_var_providers.contains(key.as_str())
85 || (is_legacy_extension && env_var_is_set)
86 {
87 allowed_env_vars.insert(env_var_name.clone());
88 }
89 }
90 }
91
92 // Check if any allowed env var is set
93 let env_var_name_used = allowed_env_vars.iter().find_map(|env_var_name| {
94 if let Ok(value) = std::env::var(env_var_name) {
95 if !value.is_empty() {
96 return Some(env_var_name.clone());
97 }
98 }
99 None
100 });
101
102 let is_authenticated = if env_var_name_used.is_some() {
103 true
104 } else {
105 is_authenticated
106 };
107
108 let state = cx.new(|_| ExtensionLlmProviderState {
109 is_authenticated,
110 available_models: models,
111 allowed_env_vars,
112 env_var_name_used,
113 });
114
115 Self {
116 extension,
117 provider_info,
118 icon_path,
119 auth_config,
120 state,
121 }
122 }
123
124 fn provider_id_string(&self) -> String {
125 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
126 }
127
128 /// The credential key used for storing the API key in the system keychain.
129 fn credential_key(&self) -> String {
130 format!("extension-llm-{}", self.provider_id_string())
131 }
132}
133
134impl LanguageModelProvider for ExtensionLanguageModelProvider {
135 fn id(&self) -> LanguageModelProviderId {
136 LanguageModelProviderId::from(self.provider_id_string())
137 }
138
139 fn name(&self) -> LanguageModelProviderName {
140 LanguageModelProviderName::from(format!("(Extension) {}", self.provider_info.name))
141 }
142
    /// Returns the built-in icon for this provider; always `ZedAssistant`.
    fn icon(&self) -> ui::IconName {
        ui::IconName::ZedAssistant
    }
146
    /// Returns a clone of the icon path supplied at construction, if any.
    fn icon_path(&self) -> Option<SharedString> {
        self.icon_path.clone()
    }
150
151 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
152 let state = self.state.read(cx);
153 state
154 .available_models
155 .iter()
156 .find(|m| m.is_default)
157 .or_else(|| state.available_models.first())
158 .map(|model_info| {
159 Arc::new(ExtensionLanguageModel {
160 extension: self.extension.clone(),
161 model_info: model_info.clone(),
162 provider_id: self.id(),
163 provider_name: self.name(),
164 provider_info: self.provider_info.clone(),
165 }) as Arc<dyn LanguageModel>
166 })
167 }
168
169 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
170 let state = self.state.read(cx);
171 state
172 .available_models
173 .iter()
174 .find(|m| m.is_default_fast)
175 .map(|model_info| {
176 Arc::new(ExtensionLanguageModel {
177 extension: self.extension.clone(),
178 model_info: model_info.clone(),
179 provider_id: self.id(),
180 provider_name: self.name(),
181 provider_info: self.provider_info.clone(),
182 }) as Arc<dyn LanguageModel>
183 })
184 }
185
186 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
187 let state = self.state.read(cx);
188 state
189 .available_models
190 .iter()
191 .map(|model_info| {
192 Arc::new(ExtensionLanguageModel {
193 extension: self.extension.clone(),
194 model_info: model_info.clone(),
195 provider_id: self.id(),
196 provider_name: self.name(),
197 provider_info: self.provider_info.clone(),
198 }) as Arc<dyn LanguageModel>
199 })
200 .collect()
201 }
202
203 fn is_authenticated(&self, cx: &App) -> bool {
204 // First check cached state
205 if self.state.read(cx).is_authenticated {
206 return true;
207 }
208
209 // Also check env var dynamically (in case settings changed after provider creation)
210 if let Some(ref auth_config) = self.auth_config {
211 if let Some(ref env_vars) = auth_config.env_vars {
212 let provider_id_string = self.provider_id_string();
213 let settings = ExtensionSettings::get_global(cx);
214 let is_legacy_extension =
215 LEGACY_LLM_EXTENSION_IDS.contains(&self.extension.manifest.id.as_ref());
216
217 for env_var_name in env_vars {
218 let key = format!("{}:{}", provider_id_string, env_var_name);
219 // For legacy extensions, auto-allow if env var is set
220 let env_var_is_set = std::env::var(env_var_name)
221 .map(|v| !v.is_empty())
222 .unwrap_or(false);
223 if settings.allowed_env_var_providers.contains(key.as_str())
224 || (is_legacy_extension && env_var_is_set)
225 {
226 if let Ok(value) = std::env::var(env_var_name) {
227 if !value.is_empty() {
228 return true;
229 }
230 }
231 }
232 }
233 }
234 }
235
236 false
237 }
238
239 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
240 // Check if already authenticated via is_authenticated
241 if self.is_authenticated(cx) {
242 return Task::ready(Ok(()));
243 }
244
245 // Not authenticated - return error indicating credentials not found
246 Task::ready(Err(AuthenticateError::CredentialsNotFound))
247 }
248
249 fn configuration_view(
250 &self,
251 _target_agent: ConfigurationViewTargetAgent,
252 window: &mut Window,
253 cx: &mut App,
254 ) -> AnyView {
255 let credential_key = self.credential_key();
256 let extension = self.extension.clone();
257 let extension_provider_id = self.provider_info.id.clone();
258 let full_provider_id = self.provider_id_string();
259 let state = self.state.clone();
260 let auth_config = self.auth_config.clone();
261
262 cx.new(|cx| {
263 ExtensionProviderConfigurationView::new(
264 credential_key,
265 extension,
266 extension_provider_id,
267 full_provider_id,
268 auth_config,
269 state,
270 window,
271 cx,
272 )
273 })
274 .into()
275 }
276
277 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
278 let extension = self.extension.clone();
279 let provider_id = self.provider_info.id.clone();
280 let state = self.state.clone();
281 let credential_key = self.credential_key();
282
283 let credentials_provider = <dyn CredentialsProvider>::global(cx);
284
285 cx.spawn(async move |cx| {
286 // Delete from system keychain
287 credentials_provider
288 .delete_credentials(&credential_key, cx)
289 .await
290 .log_err();
291
292 // Call extension's reset_credentials
293 let result = extension
294 .call(|extension, store| {
295 async move {
296 extension
297 .call_llm_provider_reset_credentials(store, &provider_id)
298 .await
299 }
300 .boxed()
301 })
302 .await;
303
304 // Update state
305 cx.update(|cx| {
306 state.update(cx, |state, _| {
307 state.is_authenticated = false;
308 });
309 })?;
310
311 match result {
312 Ok(Ok(Ok(()))) => Ok(()),
313 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
314 Ok(Err(e)) => Err(e),
315 Err(e) => Err(e),
316 }
317 })
318 }
319}
320
321impl LanguageModelProviderState for ExtensionLanguageModelProvider {
322 type ObservableEntity = ExtensionLlmProviderState;
323
324 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
325 Some(self.state.clone())
326 }
327
328 fn subscribe<T: 'static>(
329 &self,
330 cx: &mut Context<T>,
331 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
332 ) -> Option<Subscription> {
333 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
334 }
335}
336
/// Configuration view for extension-based LLM providers.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is stored.
    credential_key: String,
    /// Handle to the WASM extension backing the provider.
    extension: WasmExtension,
    /// Provider id as passed to the extension calls (without the extension prefix).
    extension_provider_id: String,
    /// Fully qualified provider id, used as the prefix of env-var permission keys in settings.
    full_provider_id: String,
    /// Authentication options (env vars, OAuth, credential label), if any.
    auth_config: Option<LanguageModelAuthConfig>,
    /// State entity shared with the provider.
    state: Entity<ExtensionLlmProviderState>,
    /// Settings help text from the extension, once loaded.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor for entering the API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings markdown request has finished.
    loading_settings: bool,
    /// True until the initial credential lookup has finished.
    loading_credentials: bool,
    /// True while a device-flow sign-in is running.
    oauth_in_progress: bool,
    /// Error message from the most recent sign-in attempt, if it failed.
    oauth_error: Option<String>,
    /// User code shown while waiting for device-flow authorization.
    device_user_code: Option<String>,
    /// Keeps the state subscription alive for the lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
354
355impl ExtensionProviderConfigurationView {
356 fn new(
357 credential_key: String,
358 extension: WasmExtension,
359 extension_provider_id: String,
360 full_provider_id: String,
361 auth_config: Option<LanguageModelAuthConfig>,
362 state: Entity<ExtensionLlmProviderState>,
363 window: &mut Window,
364 cx: &mut Context<Self>,
365 ) -> Self {
366 // Subscribe to state changes
367 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
368 cx.notify();
369 });
370
371 // Create API key editor
372 let api_key_editor = cx.new(|cx| {
373 let mut editor = Editor::single_line(window, cx);
374 editor.set_placeholder_text("Enter API key...", window, cx);
375 editor
376 });
377
378 let mut this = Self {
379 credential_key,
380 extension,
381 extension_provider_id,
382 full_provider_id,
383 auth_config,
384 state,
385 settings_markdown: None,
386 api_key_editor,
387 loading_settings: true,
388 loading_credentials: true,
389 oauth_in_progress: false,
390 oauth_error: None,
391 device_user_code: None,
392 _subscriptions: vec![state_subscription],
393 };
394
395 // Load settings text from extension
396 this.load_settings_text(cx);
397
398 // Load existing credentials
399 this.load_credentials(cx);
400
401 this
402 }
403
404 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
405 let extension = self.extension.clone();
406 let provider_id = self.extension_provider_id.clone();
407
408 cx.spawn(async move |this, cx| {
409 let result = extension
410 .call({
411 let provider_id = provider_id.clone();
412 |ext, store| {
413 async move {
414 ext.call_llm_provider_settings_markdown(store, &provider_id)
415 .await
416 }
417 .boxed()
418 }
419 })
420 .await;
421
422 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
423
424 this.update(cx, |this, cx| {
425 this.loading_settings = false;
426 if let Some(text) = settings_text {
427 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
428 this.settings_markdown = Some(markdown);
429 }
430 cx.notify();
431 })
432 .log_err();
433 })
434 .detach();
435 }
436
437 fn load_credentials(&mut self, cx: &mut Context<Self>) {
438 let credential_key = self.credential_key.clone();
439 let credentials_provider = <dyn CredentialsProvider>::global(cx);
440 let state = self.state.clone();
441
442 // Check if we should use env var (already set in state during provider construction)
443 let using_env_var = self.state.read(cx).env_var_name_used.is_some();
444
445 cx.spawn(async move |this, cx| {
446 // If using env var, we're already authenticated
447 if using_env_var {
448 this.update(cx, |this, cx| {
449 this.loading_credentials = false;
450 cx.notify();
451 })
452 .log_err();
453 return;
454 }
455
456 let credentials = credentials_provider
457 .read_credentials(&credential_key, cx)
458 .await
459 .log_err()
460 .flatten();
461
462 let has_credentials = credentials.is_some();
463
464 // Update authentication state based on stored credentials
465 cx.update(|cx| {
466 state.update(cx, |state, cx| {
467 state.is_authenticated = has_credentials;
468 cx.notify();
469 });
470 })
471 .log_err();
472
473 this.update(cx, |this, cx| {
474 this.loading_credentials = false;
475 cx.notify();
476 })
477 .log_err();
478 })
479 .detach();
480 }
481
482 fn toggle_env_var_permission(&mut self, env_var_name: String, cx: &mut Context<Self>) {
483 let full_provider_id = self.full_provider_id.clone();
484 let settings_key: Arc<str> = format!("{}:{}", full_provider_id, env_var_name).into();
485
486 let state = self.state.clone();
487 let currently_allowed = self.state.read(cx).allowed_env_vars.contains(&env_var_name);
488
489 // Update settings file
490 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
491 let settings_key = settings_key.clone();
492 move |settings, _| {
493 let allowed = settings
494 .extension
495 .allowed_env_var_providers
496 .get_or_insert_with(Vec::new);
497
498 if currently_allowed {
499 allowed.retain(|id| id.as_ref() != settings_key.as_ref());
500 } else {
501 if !allowed
502 .iter()
503 .any(|id| id.as_ref() == settings_key.as_ref())
504 {
505 allowed.push(settings_key.clone());
506 }
507 }
508 }
509 });
510
511 // Update local state
512 let new_allowed = !currently_allowed;
513 let env_var_name_clone = env_var_name.clone();
514
515 state.update(cx, |state, cx| {
516 if new_allowed {
517 state.allowed_env_vars.insert(env_var_name_clone.clone());
518 // Check if this env var is set and update env_var_name_used
519 if let Ok(value) = std::env::var(&env_var_name_clone) {
520 if !value.is_empty() && state.env_var_name_used.is_none() {
521 state.env_var_name_used = Some(env_var_name_clone);
522 state.is_authenticated = true;
523 }
524 }
525 } else {
526 state.allowed_env_vars.remove(&env_var_name_clone);
527 // If this was the env var being used, clear it and find another
528 if state.env_var_name_used.as_ref() == Some(&env_var_name_clone) {
529 state.env_var_name_used = state.allowed_env_vars.iter().find_map(|var| {
530 if let Ok(value) = std::env::var(var) {
531 if !value.is_empty() {
532 return Some(var.clone());
533 }
534 }
535 None
536 });
537 if state.env_var_name_used.is_none() {
538 // No env var auth available, need to check keychain
539 state.is_authenticated = false;
540 }
541 }
542 }
543 cx.notify();
544 });
545
546 // If all env vars are being disabled, reload credentials from keychain
547 if !new_allowed && self.state.read(cx).allowed_env_vars.is_empty() {
548 self.reload_keychain_credentials(cx);
549 }
550
551 cx.notify();
552 }
553
554 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
555 let credential_key = self.credential_key.clone();
556 let credentials_provider = <dyn CredentialsProvider>::global(cx);
557 let state = self.state.clone();
558
559 cx.spawn(async move |_this, cx| {
560 let credentials = credentials_provider
561 .read_credentials(&credential_key, cx)
562 .await
563 .log_err()
564 .flatten();
565
566 let has_credentials = credentials.is_some();
567
568 cx.update(|cx| {
569 state.update(cx, |state, cx| {
570 state.is_authenticated = has_credentials;
571 cx.notify();
572 });
573 })
574 .log_err();
575 })
576 .detach();
577 }
578
579 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
580 let api_key = self.api_key_editor.read(cx).text(cx);
581 if api_key.is_empty() {
582 return;
583 }
584
585 // Clear the editor
586 self.api_key_editor
587 .update(cx, |editor, cx| editor.set_text("", window, cx));
588
589 let credential_key = self.credential_key.clone();
590 let credentials_provider = <dyn CredentialsProvider>::global(cx);
591 let state = self.state.clone();
592
593 cx.spawn(async move |_this, cx| {
594 // Store in system keychain
595 credentials_provider
596 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
597 .await
598 .log_err();
599
600 // Update state to authenticated
601 cx.update(|cx| {
602 state.update(cx, |state, cx| {
603 state.is_authenticated = true;
604 cx.notify();
605 });
606 })
607 .log_err();
608 })
609 .detach();
610 }
611
612 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
613 // Clear the editor
614 self.api_key_editor
615 .update(cx, |editor, cx| editor.set_text("", window, cx));
616
617 let credential_key = self.credential_key.clone();
618 let credentials_provider = <dyn CredentialsProvider>::global(cx);
619 let state = self.state.clone();
620
621 cx.spawn(async move |_this, cx| {
622 // Delete from system keychain
623 credentials_provider
624 .delete_credentials(&credential_key, cx)
625 .await
626 .log_err();
627
628 // Update state to unauthenticated
629 cx.update(|cx| {
630 state.update(cx, |state, cx| {
631 state.is_authenticated = false;
632 cx.notify();
633 });
634 })
635 .log_err();
636 })
637 .detach();
638 }
639
640 fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
641 if self.oauth_in_progress {
642 return;
643 }
644
645 self.oauth_in_progress = true;
646 self.oauth_error = None;
647 self.device_user_code = None;
648 cx.notify();
649
650 let extension = self.extension.clone();
651 let provider_id = self.extension_provider_id.clone();
652 let state = self.state.clone();
653
654 cx.spawn(async move |this, cx| {
655 // Step 1: Start device flow - opens browser and returns user code
656 let start_result = extension
657 .call({
658 let provider_id = provider_id.clone();
659 |ext, store| {
660 async move {
661 ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
662 .await
663 }
664 .boxed()
665 }
666 })
667 .await;
668
669 let user_code = match start_result {
670 Ok(Ok(Ok(code))) => code,
671 Ok(Ok(Err(e))) => {
672 log::error!("Device flow start failed: {}", e);
673 this.update(cx, |this, cx| {
674 this.oauth_in_progress = false;
675 this.oauth_error = Some(e);
676 cx.notify();
677 })
678 .log_err();
679 return;
680 }
681 Ok(Err(e)) | Err(e) => {
682 log::error!("Device flow start error: {}", e);
683 this.update(cx, |this, cx| {
684 this.oauth_in_progress = false;
685 this.oauth_error = Some(e.to_string());
686 cx.notify();
687 })
688 .log_err();
689 return;
690 }
691 };
692
693 // Update UI to show the user code before polling
694 this.update(cx, |this, cx| {
695 this.device_user_code = Some(user_code);
696 cx.notify();
697 })
698 .log_err();
699
700 // Step 2: Poll for authentication completion
701 let poll_result = extension
702 .call({
703 let provider_id = provider_id.clone();
704 |ext, store| {
705 async move {
706 ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
707 .await
708 }
709 .boxed()
710 }
711 })
712 .await;
713
714 let error_message = match poll_result {
715 Ok(Ok(Ok(()))) => {
716 // After successful auth, refresh the models list
717 let models_result = extension
718 .call({
719 let provider_id = provider_id.clone();
720 |ext, store| {
721 async move {
722 ext.call_llm_provider_models(store, &provider_id).await
723 }
724 .boxed()
725 }
726 })
727 .await;
728
729 let new_models: Vec<LlmModelInfo> = match models_result {
730 Ok(Ok(Ok(models))) => models,
731 _ => Vec::new(),
732 };
733
734 cx.update(|cx| {
735 state.update(cx, |state, cx| {
736 state.is_authenticated = true;
737 state.available_models = new_models;
738 cx.notify();
739 });
740 })
741 .log_err();
742 None
743 }
744 Ok(Ok(Err(e))) => {
745 log::error!("Device flow poll failed: {}", e);
746 Some(e)
747 }
748 Ok(Err(e)) | Err(e) => {
749 log::error!("Device flow poll error: {}", e);
750 Some(e.to_string())
751 }
752 };
753
754 this.update(cx, |this, cx| {
755 this.oauth_in_progress = false;
756 this.oauth_error = error_message;
757 this.device_user_code = None;
758 cx.notify();
759 })
760 .log_err();
761 })
762 .detach();
763 }
764
765 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
766 self.state.read(cx).is_authenticated
767 }
768
769 fn has_oauth_config(&self) -> bool {
770 self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
771 }
772
773 fn oauth_config(&self) -> Option<&OAuthConfig> {
774 self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
775 }
776
777 fn has_api_key_config(&self) -> bool {
778 // API key is available if there's a credential_label or no oauth-only config
779 self.auth_config
780 .as_ref()
781 .map(|c| c.credential_label.is_some() || c.oauth.is_none())
782 .unwrap_or(true)
783 }
784}
785
impl gpui::Render for ExtensionProviderConfigurationView {
    /// Renders the configuration UI.
    ///
    /// Layout, top to bottom:
    /// 1. a "Loading..." placeholder while settings or credentials are loading,
    /// 2. the extension's settings markdown, if any,
    /// 3. one checkbox per declared env var, and — when an env var is in use —
    ///    a status row with a disabled "Reset Key" button (returns early),
    /// 4. when authenticated without an env var: a status row with a working
    ///    reset / sign-out button (returns early),
    /// 5. otherwise the available sign-in options (OAuth button and/or API key
    ///    input), followed by an OpenAI-specific hint when the provider id is
    ///    `openai`.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_loading = self.loading_settings || self.loading_credentials;
        let is_authenticated = self.is_authenticated(cx);
        let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone();
        let env_var_name_used = self.state.read(cx).env_var_name_used.clone();
        let has_oauth = self.has_oauth_config();
        let has_api_key = self.has_api_key_config();

        if is_loading {
            return v_flex()
                .gap_2()
                .child(Label::new("Loading...").color(Color::Muted))
                .into_any_element();
        }

        let mut content = v_flex().gap_4().size_full();

        // Render settings markdown if available
        if let Some(markdown) = &self.settings_markdown {
            // Note: `_window` is used here despite its underscore prefix.
            let style = settings_markdown_style(_window, cx);
            content = content.child(MarkdownElement::new(markdown.clone(), style));
        }

        // Render env var checkboxes - one for each env var the extension declares
        if let Some(auth_config) = &self.auth_config {
            if let Some(env_vars) = &auth_config.env_vars {
                for env_var_name in env_vars {
                    let is_allowed = allowed_env_vars.contains(env_var_name);
                    let checkbox_label =
                        format!("Read API key from {} environment variable", env_var_name);
                    let env_var_for_click = env_var_name.clone();

                    content = content.child(
                        h_flex()
                            .gap_2()
                            .child(
                                ui::Checkbox::new(
                                    SharedString::from(format!("env-var-{}", env_var_name)),
                                    is_allowed.into(),
                                )
                                .on_click(cx.listener(
                                    move |this, _, _window, cx| {
                                        this.toggle_env_var_permission(
                                            env_var_for_click.clone(),
                                            cx,
                                        );
                                    },
                                )),
                            )
                            .child(Label::new(checkbox_label).size(LabelSize::Small)),
                    );
                }

                // Show status if any env var is being used
                if let Some(used_var) = &env_var_name_used {
                    let tooltip_label = format!(
                        "To reset this API key, unset the {} environment variable.",
                        used_var
                    );
                    content = content.child(
                        h_flex()
                            .mt_0p5()
                            .p_1()
                            .justify_between()
                            .rounded_md()
                            .border_1()
                            .border_color(cx.theme().colors().border)
                            .bg(cx.theme().colors().background)
                            .child(
                                h_flex()
                                    .flex_1()
                                    .min_w_0()
                                    .gap_1()
                                    .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
                                    .child(
                                        Label::new(format!(
                                            "API key set in {} environment variable",
                                            used_var
                                        ))
                                        .truncate(),
                                    ),
                            )
                            .child(
                                // The key lives in the environment, so the button is
                                // disabled and only explains how to reset it.
                                ui::Button::new("reset-key", "Reset Key")
                                    .label_size(LabelSize::Small)
                                    .icon(ui::IconName::Undo)
                                    .icon_size(ui::IconSize::Small)
                                    .icon_color(Color::Muted)
                                    .icon_position(ui::IconPosition::Start)
                                    .disabled(true)
                                    .tooltip(ui::Tooltip::text(tooltip_label)),
                            ),
                    );
                    // Early return: no other auth options (nor the OpenAI hint
                    // below) are shown while an env var is in use.
                    return content.into_any_element();
                }
            }
        }

        // If authenticated, show success state with sign out option
        if is_authenticated && env_var_name_used.is_none() {
            // OAuth-only providers get sign-in wording; everything else is
            // described in terms of an API key.
            let reset_label = if has_oauth && !has_api_key {
                "Sign Out"
            } else {
                "Reset Key"
            };

            let status_label = if has_oauth && !has_api_key {
                "Signed in"
            } else {
                "API key configured"
            };

            content = content.child(
                h_flex()
                    .mt_0p5()
                    .p_1()
                    .justify_between()
                    .rounded_md()
                    .border_1()
                    .border_color(cx.theme().colors().border)
                    .bg(cx.theme().colors().background)
                    .child(
                        h_flex()
                            .flex_1()
                            .min_w_0()
                            .gap_1()
                            .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
                            .child(Label::new(status_label).truncate()),
                    )
                    .child(
                        ui::Button::new("reset-key", reset_label)
                            .label_size(LabelSize::Small)
                            .icon(ui::IconName::Undo)
                            .icon_size(ui::IconSize::Small)
                            .icon_color(Color::Muted)
                            .icon_position(ui::IconPosition::Start)
                            .on_click(cx.listener(|this, _, window, cx| {
                                this.reset_api_key(window, cx);
                            })),
                    ),
            );

            return content.into_any_element();
        }

        // Not authenticated - show available auth options
        if env_var_name_used.is_none() {
            // Render OAuth sign-in button if configured
            if has_oauth {
                let oauth_config = self.oauth_config();
                let button_label = oauth_config
                    .and_then(|c| c.sign_in_button_label.clone())
                    .unwrap_or_else(|| "Sign In".to_string());
                // Only the "github" icon name is recognised; anything else
                // renders the button without an icon.
                let button_icon = oauth_config
                    .and_then(|c| c.sign_in_button_icon.as_ref())
                    .and_then(|icon_name| match icon_name.as_str() {
                        "github" => Some(ui::IconName::Github),
                        _ => None,
                    });

                let oauth_in_progress = self.oauth_in_progress;

                let oauth_error = self.oauth_error.clone();

                content = content.child(
                    v_flex()
                        .gap_2()
                        .child({
                            let mut button = ui::Button::new("oauth-sign-in", button_label)
                                .full_width()
                                .style(ui::ButtonStyle::Outlined)
                                .disabled(oauth_in_progress)
                                .on_click(cx.listener(|this, _, _window, cx| {
                                    this.start_oauth_sign_in(cx);
                                }));
                            if let Some(icon) = button_icon {
                                button = button
                                    .icon(icon)
                                    .icon_position(ui::IconPosition::Start)
                                    .icon_size(ui::IconSize::Small)
                                    .icon_color(Color::Muted);
                            }
                            button
                        })
                        .when(oauth_in_progress, |this| {
                            let user_code = self.device_user_code.clone();
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .when_some(user_code, |this, code| {
                                        // The code counts as "copied" when the clipboard
                                        // currently holds exactly this code.
                                        let copied = cx
                                            .read_from_clipboard()
                                            .map(|item| item.text().as_ref() == Some(&code))
                                            .unwrap_or(false);
                                        let code_for_click = code.clone();
                                        this.child(
                                            h_flex()
                                                .gap_1()
                                                .child(
                                                    Label::new("Enter code:")
                                                        .size(LabelSize::Small)
                                                        .color(Color::Muted),
                                                )
                                                .child(
                                                    h_flex()
                                                        .gap_1()
                                                        .px_1()
                                                        .border_1()
                                                        .border_color(cx.theme().colors().border)
                                                        .rounded_sm()
                                                        .cursor_pointer()
                                                        // Clicking copies the code and refreshes
                                                        // the window so the icon can update.
                                                        .on_mouse_down(
                                                            MouseButton::Left,
                                                            move |_, window, cx| {
                                                                cx.write_to_clipboard(
                                                                    ClipboardItem::new_string(
                                                                        code_for_click.clone(),
                                                                    ),
                                                                );
                                                                window.refresh();
                                                            },
                                                        )
                                                        .child(
                                                            Label::new(code)
                                                                .size(LabelSize::Small)
                                                                .color(Color::Accent),
                                                        )
                                                        .child(
                                                            ui::Icon::new(if copied {
                                                                ui::IconName::Check
                                                            } else {
                                                                ui::IconName::Copy
                                                            })
                                                            .size(ui::IconSize::Small)
                                                            .color(if copied {
                                                                Color::Success
                                                            } else {
                                                                Color::Muted
                                                            }),
                                                        ),
                                                ),
                                        )
                                    })
                                    .child(
                                        Label::new("Waiting for authorization in browser...")
                                            .size(LabelSize::Small)
                                            .color(Color::Muted),
                                    ),
                            )
                        })
                        .when_some(oauth_error, |this, error| {
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .child(
                                        h_flex()
                                            .gap_2()
                                            .child(
                                                ui::Icon::new(ui::IconName::Warning)
                                                    .color(Color::Error)
                                                    .size(ui::IconSize::Small),
                                            )
                                            .child(
                                                Label::new("Authentication failed")
                                                    .color(Color::Error)
                                                    .size(LabelSize::Small),
                                            ),
                                    )
                                    .child(
                                        div().pl_6().child(
                                            Label::new(error)
                                                .color(Color::Error)
                                                .size(LabelSize::Small),
                                        ),
                                    ),
                            )
                        }),
                );
            }

            // Render API key input if configured (and we have both options, show a separator)
            if has_api_key {
                if has_oauth {
                    content = content.child(
                        h_flex()
                            .gap_2()
                            .items_center()
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border))
                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
                    );
                }

                let credential_label = self
                    .auth_config
                    .as_ref()
                    .and_then(|c| c.credential_label.clone())
                    .unwrap_or_else(|| "API Key".to_string());

                content = content.child(
                    v_flex()
                        .gap_2()
                        // `menu::Confirm` dispatched within this container saves the key.
                        .on_action(cx.listener(Self::save_api_key))
                        .child(
                            Label::new(credential_label)
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        )
                        .child(self.api_key_editor.clone())
                        .child(
                            Label::new("Enter your API key and press Enter to save")
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        ),
                );
            }
        }

        // Show OpenAI-compatible models notification for OpenAI extension
        if self.extension_provider_id == "openai" {
            content = content.child(
                h_flex()
                    .gap_1()
                    .child(
                        ui::Icon::new(ui::IconName::Info)
                            .size(ui::IconSize::Small)
                            .color(Color::Muted),
                    )
                    .child(
                        Label::new("Zed also supports OpenAI-compatible models.")
                            .size(LabelSize::Small)
                            .color(Color::Muted),
                    )
                    .child(
                        ui::Button::new("learn-more", "Learn More")
                            .style(ui::ButtonStyle::Subtle)
                            .label_size(LabelSize::Small)
                            .icon(ui::IconName::ArrowUpRight)
                            .icon_size(ui::IconSize::Small)
                            .icon_color(Color::Muted)
                            .icon_position(ui::IconPosition::End)
                            .on_click(|_, _, cx| {
                                cx.open_url("https://zed.dev/docs/configuring-llm-providers#openai-compatible-providers");
                            }),
                    ),
            );
        }

        content.into_any_element()
    }
}
1138
impl Focusable for ExtensionProviderConfigurationView {
    /// Delegates focus to the API key editor, so focusing the view focuses the input.
    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
        self.api_key_editor.focus_handle(cx)
    }
}
1144
1145fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1146 let theme_settings = ThemeSettings::get_global(cx);
1147 let colors = cx.theme().colors();
1148 let mut text_style = window.text_style();
1149 text_style.refine(&TextStyleRefinement {
1150 font_family: Some(theme_settings.ui_font.family.clone()),
1151 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1152 font_features: Some(theme_settings.ui_font.features.clone()),
1153 color: Some(colors.text),
1154 ..Default::default()
1155 });
1156
1157 MarkdownStyle {
1158 base_text_style: text_style,
1159 selection_background_color: colors.element_selection_background,
1160 inline_code: TextStyleRefinement {
1161 background_color: Some(colors.editor_background),
1162 ..Default::default()
1163 },
1164 link: TextStyleRefinement {
1165 color: Some(colors.text_accent),
1166 underline: Some(UnderlineStyle {
1167 color: Some(colors.text_accent.opacity(0.5)),
1168 thickness: px(1.),
1169 ..Default::default()
1170 }),
1171 ..Default::default()
1172 },
1173 syntax: cx.theme().syntax().clone(),
1174 ..Default::default()
1175 }
1176}
1177
/// An extension-based language model.
///
/// Implements [`LanguageModel`] by forwarding token counting and streaming
/// completion requests to the owning WASM extension.
pub struct ExtensionLanguageModel {
    /// Handle used to invoke the extension's LLM entry points.
    extension: WasmExtension,
    /// Model metadata reported by the extension (id, name, capabilities, token limits).
    model_info: LlmModelInfo,
    /// Provider id returned from [`LanguageModel::provider_id`].
    provider_id: LanguageModelProviderId,
    /// Provider name returned from [`LanguageModel::provider_name`].
    provider_name: LanguageModelProviderName,
    /// Provider metadata from the extension; its `id` is what gets passed back
    /// to the extension when making calls.
    provider_info: LlmProviderInfo,
}
1186
1187impl LanguageModel for ExtensionLanguageModel {
1188 fn id(&self) -> LanguageModelId {
1189 LanguageModelId::from(self.model_info.id.clone())
1190 }
1191
1192 fn name(&self) -> LanguageModelName {
1193 LanguageModelName::from(self.model_info.name.clone())
1194 }
1195
1196 fn provider_id(&self) -> LanguageModelProviderId {
1197 self.provider_id.clone()
1198 }
1199
1200 fn provider_name(&self) -> LanguageModelProviderName {
1201 self.provider_name.clone()
1202 }
1203
1204 fn telemetry_id(&self) -> String {
1205 format!("extension-{}", self.model_info.id)
1206 }
1207
1208 fn supports_images(&self) -> bool {
1209 self.model_info.capabilities.supports_images
1210 }
1211
1212 fn supports_tools(&self) -> bool {
1213 self.model_info.capabilities.supports_tools
1214 }
1215
1216 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1217 match choice {
1218 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1219 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1220 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1221 }
1222 }
1223
1224 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1225 match self.model_info.capabilities.tool_input_format {
1226 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1227 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1228 }
1229 }
1230
1231 fn max_token_count(&self) -> u64 {
1232 self.model_info.max_token_count
1233 }
1234
1235 fn max_output_tokens(&self) -> Option<u64> {
1236 self.model_info.max_output_tokens
1237 }
1238
1239 fn count_tokens(
1240 &self,
1241 request: LanguageModelRequest,
1242 cx: &App,
1243 ) -> BoxFuture<'static, Result<u64>> {
1244 let extension = self.extension.clone();
1245 let provider_id = self.provider_info.id.clone();
1246 let model_id = self.model_info.id.clone();
1247
1248 let wit_request = convert_request_to_wit(request);
1249
1250 cx.background_spawn(async move {
1251 extension
1252 .call({
1253 let provider_id = provider_id.clone();
1254 let model_id = model_id.clone();
1255 let wit_request = wit_request.clone();
1256 |ext, store| {
1257 async move {
1258 let count = ext
1259 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1260 .await?
1261 .map_err(|e| anyhow!("{}", e))?;
1262 Ok(count)
1263 }
1264 .boxed()
1265 }
1266 })
1267 .await?
1268 })
1269 .boxed()
1270 }
1271
1272 fn stream_completion(
1273 &self,
1274 request: LanguageModelRequest,
1275 _cx: &AsyncApp,
1276 ) -> BoxFuture<
1277 'static,
1278 Result<
1279 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1280 LanguageModelCompletionError,
1281 >,
1282 > {
1283 let extension = self.extension.clone();
1284 let provider_id = self.provider_info.id.clone();
1285 let model_id = self.model_info.id.clone();
1286
1287 let wit_request = convert_request_to_wit(request);
1288
1289 async move {
1290 // Start the stream
1291 let stream_id_result = extension
1292 .call({
1293 let provider_id = provider_id.clone();
1294 let model_id = model_id.clone();
1295 let wit_request = wit_request.clone();
1296 |ext, store| {
1297 async move {
1298 let id = ext
1299 .call_llm_stream_completion_start(
1300 store,
1301 &provider_id,
1302 &model_id,
1303 &wit_request,
1304 )
1305 .await?
1306 .map_err(|e| anyhow!("{}", e))?;
1307 Ok(id)
1308 }
1309 .boxed()
1310 }
1311 })
1312 .await;
1313
1314 let stream_id = stream_id_result
1315 .map_err(LanguageModelCompletionError::Other)?
1316 .map_err(LanguageModelCompletionError::Other)?;
1317
1318 // Create a stream that polls for events
1319 let stream = futures::stream::unfold(
1320 (extension.clone(), stream_id, false),
1321 move |(extension, stream_id, done)| async move {
1322 if done {
1323 return None;
1324 }
1325
1326 let result = extension
1327 .call({
1328 let stream_id = stream_id.clone();
1329 |ext, store| {
1330 async move {
1331 let event = ext
1332 .call_llm_stream_completion_next(store, &stream_id)
1333 .await?
1334 .map_err(|e| anyhow!("{}", e))?;
1335 Ok(event)
1336 }
1337 .boxed()
1338 }
1339 })
1340 .await
1341 .and_then(|inner| inner);
1342
1343 match result {
1344 Ok(Some(event)) => {
1345 let converted = convert_completion_event(event);
1346 let is_done =
1347 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1348 Some((converted, (extension, stream_id, is_done)))
1349 }
1350 Ok(None) => {
1351 // Stream complete, close it
1352 let _ = extension
1353 .call({
1354 let stream_id = stream_id.clone();
1355 |ext, store| {
1356 async move {
1357 ext.call_llm_stream_completion_close(store, &stream_id)
1358 .await?;
1359 Ok::<(), anyhow::Error>(())
1360 }
1361 .boxed()
1362 }
1363 })
1364 .await;
1365 None
1366 }
1367 Err(e) => Some((
1368 Err(LanguageModelCompletionError::Other(e)),
1369 (extension, stream_id, true),
1370 )),
1371 }
1372 },
1373 );
1374
1375 Ok(stream.boxed())
1376 }
1377 .boxed()
1378 }
1379
1380 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1381 // Extensions can implement this via llm_cache_configuration
1382 None
1383 }
1384}
1385
1386fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1387 use language_model::{MessageContent, Role};
1388
1389 let messages: Vec<LlmRequestMessage> = request
1390 .messages
1391 .into_iter()
1392 .map(|msg| {
1393 let role = match msg.role {
1394 Role::User => LlmMessageRole::User,
1395 Role::Assistant => LlmMessageRole::Assistant,
1396 Role::System => LlmMessageRole::System,
1397 };
1398
1399 let content: Vec<LlmMessageContent> = msg
1400 .content
1401 .into_iter()
1402 .map(|c| match c {
1403 MessageContent::Text(text) => LlmMessageContent::Text(text),
1404 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1405 source: image.source.to_string(),
1406 width: Some(image.size.width.0 as u32),
1407 height: Some(image.size.height.0 as u32),
1408 }),
1409 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1410 id: tool_use.id.to_string(),
1411 name: tool_use.name.to_string(),
1412 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1413 thought_signature: tool_use.thought_signature,
1414 }),
1415 MessageContent::ToolResult(tool_result) => {
1416 let content = match tool_result.content {
1417 language_model::LanguageModelToolResultContent::Text(text) => {
1418 LlmToolResultContent::Text(text.to_string())
1419 }
1420 language_model::LanguageModelToolResultContent::Image(image) => {
1421 LlmToolResultContent::Image(LlmImageData {
1422 source: image.source.to_string(),
1423 width: Some(image.size.width.0 as u32),
1424 height: Some(image.size.height.0 as u32),
1425 })
1426 }
1427 };
1428 LlmMessageContent::ToolResult(LlmToolResult {
1429 tool_use_id: tool_result.tool_use_id.to_string(),
1430 tool_name: tool_result.tool_name.to_string(),
1431 is_error: tool_result.is_error,
1432 content,
1433 })
1434 }
1435 MessageContent::Thinking { text, signature } => {
1436 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1437 }
1438 MessageContent::RedactedThinking(data) => {
1439 LlmMessageContent::RedactedThinking(data)
1440 }
1441 })
1442 .collect();
1443
1444 LlmRequestMessage {
1445 role,
1446 content,
1447 cache: msg.cache,
1448 }
1449 })
1450 .collect();
1451
1452 let tools: Vec<LlmToolDefinition> = request
1453 .tools
1454 .into_iter()
1455 .map(|tool| LlmToolDefinition {
1456 name: tool.name,
1457 description: tool.description,
1458 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1459 })
1460 .collect();
1461
1462 let tool_choice = request.tool_choice.map(|tc| match tc {
1463 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1464 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1465 LanguageModelToolChoice::None => LlmToolChoice::None,
1466 });
1467
1468 LlmCompletionRequest {
1469 messages,
1470 tools,
1471 tool_choice,
1472 stop_sequences: request.stop,
1473 temperature: request.temperature,
1474 thinking_allowed: false,
1475 max_tokens: None,
1476 }
1477}
1478
1479fn convert_completion_event(
1480 event: LlmCompletionEvent,
1481) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1482 match event {
1483 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1484 message_id: String::new(),
1485 }),
1486 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1487 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1488 text: thinking.text,
1489 signature: thinking.signature,
1490 }),
1491 LlmCompletionEvent::RedactedThinking(data) => {
1492 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1493 }
1494 LlmCompletionEvent::ToolUse(tool_use) => {
1495 let raw_input = tool_use.input.clone();
1496 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1497 Ok(LanguageModelCompletionEvent::ToolUse(
1498 LanguageModelToolUse {
1499 id: LanguageModelToolUseId::from(tool_use.id),
1500 name: tool_use.name.into(),
1501 raw_input,
1502 input,
1503 is_input_complete: true,
1504 thought_signature: tool_use.thought_signature,
1505 },
1506 ))
1507 }
1508 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1509 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1510 id: LanguageModelToolUseId::from(error.id),
1511 tool_name: error.tool_name.into(),
1512 raw_input: error.raw_input.into(),
1513 json_parse_error: error.error,
1514 })
1515 }
1516 LlmCompletionEvent::Stop(reason) => {
1517 let stop_reason = match reason {
1518 LlmStopReason::EndTurn => StopReason::EndTurn,
1519 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1520 LlmStopReason::ToolUse => StopReason::ToolUse,
1521 LlmStopReason::Refusal => StopReason::Refusal,
1522 };
1523 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1524 }
1525 LlmCompletionEvent::Usage(usage) => {
1526 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1527 input_tokens: usage.input_tokens,
1528 output_tokens: usage.output_tokens,
1529 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1530 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1531 }))
1532 }
1533 LlmCompletionEvent::ReasoningDetails(json) => {
1534 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1535 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1536 ))
1537 }
1538 }
1539}