1use crate::ExtensionSettings;
2use crate::wasm_host::WasmExtension;
3
4use crate::wasm_host::wit::{
5 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
6 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
7 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
8 LlmToolUse,
9};
10use anyhow::{Result, anyhow};
11use credentials_provider::CredentialsProvider;
12use editor::Editor;
13use extension::{LanguageModelAuthConfig, OAuthConfig};
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, StreamExt};
17use gpui::Focusable;
18use gpui::{
19 AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
20 MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
21};
22use language_model::tool_schema::LanguageModelToolSchemaFormat;
23use language_model::{
24 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
25 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
27 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
28 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
29};
30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
31use settings::Settings;
32use std::sync::Arc;
33use theme::ThemeSettings;
34use ui::{Label, LabelSize, prelude::*};
35use util::ResultExt as _;
36
37/// An extension-based language model provider.
38pub struct ExtensionLanguageModelProvider {
39 pub extension: WasmExtension,
40 pub provider_info: LlmProviderInfo,
41 icon_path: Option<SharedString>,
42 auth_config: Option<LanguageModelAuthConfig>,
43 state: Entity<ExtensionLlmProviderState>,
44}
45
46pub struct ExtensionLlmProviderState {
47 is_authenticated: bool,
48 available_models: Vec<LlmModelInfo>,
49 env_var_allowed: bool,
50 api_key_from_env: bool,
51}
52
// Makes the state entity usable with `cx.subscribe`. NOTE(review): in the code reviewed here,
// state changes are published with `cx.notify()` and nothing calls `cx.emit(())`.
impl EventEmitter<()> for ExtensionLlmProviderState {}
54
55impl ExtensionLanguageModelProvider {
56 pub fn new(
57 extension: WasmExtension,
58 provider_info: LlmProviderInfo,
59 models: Vec<LlmModelInfo>,
60 is_authenticated: bool,
61 icon_path: Option<SharedString>,
62 auth_config: Option<LanguageModelAuthConfig>,
63 cx: &mut App,
64 ) -> Self {
65 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
66 let env_var_allowed = ExtensionSettings::get_global(cx)
67 .allowed_env_var_providers
68 .contains(provider_id_string.as_str());
69
70 let (is_authenticated, api_key_from_env) =
71 if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
72 let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
73 if let Ok(value) = std::env::var(env_var_name) {
74 if !value.is_empty() {
75 (true, true)
76 } else {
77 (is_authenticated, false)
78 }
79 } else {
80 (is_authenticated, false)
81 }
82 } else {
83 (is_authenticated, false)
84 };
85
86 let state = cx.new(|_| ExtensionLlmProviderState {
87 is_authenticated,
88 available_models: models,
89 env_var_allowed,
90 api_key_from_env,
91 });
92
93 Self {
94 extension,
95 provider_info,
96 icon_path,
97 auth_config,
98 state,
99 }
100 }
101
102 fn provider_id_string(&self) -> String {
103 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
104 }
105
106 /// The credential key used for storing the API key in the system keychain.
107 fn credential_key(&self) -> String {
108 format!("extension-llm-{}", self.provider_id_string())
109 }
110}
111
112impl LanguageModelProvider for ExtensionLanguageModelProvider {
113 fn id(&self) -> LanguageModelProviderId {
114 LanguageModelProviderId::from(self.provider_id_string())
115 }
116
117 fn name(&self) -> LanguageModelProviderName {
118 LanguageModelProviderName::from(format!("(Extension) {}", self.provider_info.name))
119 }
120
    /// Fallback built-in icon; a custom icon, when configured, comes from `icon_path`.
    fn icon(&self) -> ui::IconName {
        ui::IconName::ZedAssistant
    }
124
125 fn icon_path(&self) -> Option<SharedString> {
126 self.icon_path.clone()
127 }
128
129 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
130 let state = self.state.read(cx);
131 state
132 .available_models
133 .iter()
134 .find(|m| m.is_default)
135 .or_else(|| state.available_models.first())
136 .map(|model_info| {
137 Arc::new(ExtensionLanguageModel {
138 extension: self.extension.clone(),
139 model_info: model_info.clone(),
140 provider_id: self.id(),
141 provider_name: self.name(),
142 provider_info: self.provider_info.clone(),
143 }) as Arc<dyn LanguageModel>
144 })
145 }
146
147 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
148 let state = self.state.read(cx);
149 state
150 .available_models
151 .iter()
152 .find(|m| m.is_default_fast)
153 .map(|model_info| {
154 Arc::new(ExtensionLanguageModel {
155 extension: self.extension.clone(),
156 model_info: model_info.clone(),
157 provider_id: self.id(),
158 provider_name: self.name(),
159 provider_info: self.provider_info.clone(),
160 }) as Arc<dyn LanguageModel>
161 })
162 }
163
164 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165 let state = self.state.read(cx);
166 state
167 .available_models
168 .iter()
169 .map(|model_info| {
170 Arc::new(ExtensionLanguageModel {
171 extension: self.extension.clone(),
172 model_info: model_info.clone(),
173 provider_id: self.id(),
174 provider_name: self.name(),
175 provider_info: self.provider_info.clone(),
176 }) as Arc<dyn LanguageModel>
177 })
178 .collect()
179 }
180
181 fn is_authenticated(&self, cx: &App) -> bool {
182 // First check cached state
183 if self.state.read(cx).is_authenticated {
184 return true;
185 }
186
187 // Also check env var dynamically (in case migration happened after provider creation)
188 if let Some(ref auth_config) = self.auth_config {
189 if let Some(ref env_var_name) = auth_config.env_var {
190 let provider_id_string = self.provider_id_string();
191 let env_var_allowed = ExtensionSettings::get_global(cx)
192 .allowed_env_var_providers
193 .contains(provider_id_string.as_str());
194
195 if env_var_allowed {
196 if let Ok(value) = std::env::var(env_var_name) {
197 if !value.is_empty() {
198 return true;
199 }
200 }
201 }
202 }
203 }
204
205 false
206 }
207
208 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
209 let extension = self.extension.clone();
210 let provider_id = self.provider_info.id.clone();
211 let state = self.state.clone();
212
213 cx.spawn(async move |cx| {
214 let result = extension
215 .call({
216 let provider_id = provider_id.clone();
217 |extension, store| {
218 async move {
219 extension
220 .call_llm_provider_authenticate(store, &provider_id)
221 .await
222 }
223 .boxed()
224 }
225 })
226 .await;
227
228 match result {
229 Ok(Ok(Ok(()))) => {
230 // After successful auth, refresh the models list
231 let models_result = extension
232 .call({
233 let provider_id = provider_id.clone();
234 |ext, store| {
235 async move {
236 ext.call_llm_provider_models(store, &provider_id).await
237 }
238 .boxed()
239 }
240 })
241 .await;
242
243 let new_models: Vec<LlmModelInfo> = match models_result {
244 Ok(Ok(Ok(models))) => models,
245 _ => Vec::new(),
246 };
247
248 cx.update(|cx| {
249 state.update(cx, |state, cx| {
250 state.is_authenticated = true;
251 state.available_models = new_models;
252 cx.notify();
253 });
254 })?;
255 Ok(())
256 }
257 Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
258 Ok(Err(e)) => Err(AuthenticateError::Other(e)),
259 Err(e) => Err(AuthenticateError::Other(e)),
260 }
261 })
262 }
263
264 fn configuration_view(
265 &self,
266 _target_agent: ConfigurationViewTargetAgent,
267 window: &mut Window,
268 cx: &mut App,
269 ) -> AnyView {
270 let credential_key = self.credential_key();
271 let extension = self.extension.clone();
272 let extension_provider_id = self.provider_info.id.clone();
273 let full_provider_id = self.provider_id_string();
274 let state = self.state.clone();
275 let auth_config = self.auth_config.clone();
276
277 cx.new(|cx| {
278 ExtensionProviderConfigurationView::new(
279 credential_key,
280 extension,
281 extension_provider_id,
282 full_provider_id,
283 auth_config,
284 state,
285 window,
286 cx,
287 )
288 })
289 .into()
290 }
291
292 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
293 let extension = self.extension.clone();
294 let provider_id = self.provider_info.id.clone();
295 let state = self.state.clone();
296 let credential_key = self.credential_key();
297
298 let credentials_provider = <dyn CredentialsProvider>::global(cx);
299
300 cx.spawn(async move |cx| {
301 // Delete from system keychain
302 credentials_provider
303 .delete_credentials(&credential_key, cx)
304 .await
305 .log_err();
306
307 // Call extension's reset_credentials
308 let result = extension
309 .call(|extension, store| {
310 async move {
311 extension
312 .call_llm_provider_reset_credentials(store, &provider_id)
313 .await
314 }
315 .boxed()
316 })
317 .await;
318
319 // Update state
320 cx.update(|cx| {
321 state.update(cx, |state, _| {
322 state.is_authenticated = false;
323 });
324 })?;
325
326 match result {
327 Ok(Ok(Ok(()))) => Ok(()),
328 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
329 Ok(Err(e)) => Err(e),
330 Err(e) => Err(e),
331 }
332 })
333 }
334}
335
336impl LanguageModelProviderState for ExtensionLanguageModelProvider {
337 type ObservableEntity = ExtensionLlmProviderState;
338
339 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
340 Some(self.state.clone())
341 }
342
343 fn subscribe<T: 'static>(
344 &self,
345 cx: &mut Context<T>,
346 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
347 ) -> Option<Subscription> {
348 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
349 }
350}
351
352/// Configuration view for extension-based LLM providers.
353struct ExtensionProviderConfigurationView {
354 credential_key: String,
355 extension: WasmExtension,
356 extension_provider_id: String,
357 full_provider_id: String,
358 auth_config: Option<LanguageModelAuthConfig>,
359 state: Entity<ExtensionLlmProviderState>,
360 settings_markdown: Option<Entity<Markdown>>,
361 api_key_editor: Entity<Editor>,
362 loading_settings: bool,
363 loading_credentials: bool,
364 oauth_in_progress: bool,
365 oauth_error: Option<String>,
366 device_user_code: Option<String>,
367 _subscriptions: Vec<Subscription>,
368}
369
370impl ExtensionProviderConfigurationView {
371 fn new(
372 credential_key: String,
373 extension: WasmExtension,
374 extension_provider_id: String,
375 full_provider_id: String,
376 auth_config: Option<LanguageModelAuthConfig>,
377 state: Entity<ExtensionLlmProviderState>,
378 window: &mut Window,
379 cx: &mut Context<Self>,
380 ) -> Self {
381 // Subscribe to state changes
382 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
383 cx.notify();
384 });
385
386 // Create API key editor
387 let api_key_editor = cx.new(|cx| {
388 let mut editor = Editor::single_line(window, cx);
389 editor.set_placeholder_text("Enter API key...", window, cx);
390 editor
391 });
392
393 let mut this = Self {
394 credential_key,
395 extension,
396 extension_provider_id,
397 full_provider_id,
398 auth_config,
399 state,
400 settings_markdown: None,
401 api_key_editor,
402 loading_settings: true,
403 loading_credentials: true,
404 oauth_in_progress: false,
405 oauth_error: None,
406 device_user_code: None,
407 _subscriptions: vec![state_subscription],
408 };
409
410 // Load settings text from extension
411 this.load_settings_text(cx);
412
413 // Load existing credentials
414 this.load_credentials(cx);
415
416 this
417 }
418
419 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
420 let extension = self.extension.clone();
421 let provider_id = self.extension_provider_id.clone();
422
423 cx.spawn(async move |this, cx| {
424 let result = extension
425 .call({
426 let provider_id = provider_id.clone();
427 |ext, store| {
428 async move {
429 ext.call_llm_provider_settings_markdown(store, &provider_id)
430 .await
431 }
432 .boxed()
433 }
434 })
435 .await;
436
437 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
438
439 this.update(cx, |this, cx| {
440 this.loading_settings = false;
441 if let Some(text) = settings_text {
442 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
443 this.settings_markdown = Some(markdown);
444 }
445 cx.notify();
446 })
447 .log_err();
448 })
449 .detach();
450 }
451
452 fn load_credentials(&mut self, cx: &mut Context<Self>) {
453 let credential_key = self.credential_key.clone();
454 let credentials_provider = <dyn CredentialsProvider>::global(cx);
455 let state = self.state.clone();
456
457 // Check if we should use env var (already set in state during provider construction)
458 let api_key_from_env = self.state.read(cx).api_key_from_env;
459
460 cx.spawn(async move |this, cx| {
461 // If using env var, we're already authenticated
462 if api_key_from_env {
463 this.update(cx, |this, cx| {
464 this.loading_credentials = false;
465 cx.notify();
466 })
467 .log_err();
468 return;
469 }
470
471 let credentials = credentials_provider
472 .read_credentials(&credential_key, cx)
473 .await
474 .log_err()
475 .flatten();
476
477 let has_credentials = credentials.is_some();
478
479 // Update authentication state based on stored credentials
480 let _ = cx.update(|cx| {
481 state.update(cx, |state, cx| {
482 state.is_authenticated = has_credentials;
483 cx.notify();
484 });
485 });
486
487 this.update(cx, |this, cx| {
488 this.loading_credentials = false;
489 cx.notify();
490 })
491 .log_err();
492 })
493 .detach();
494 }
495
496 fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
497 let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
498 let env_var_name = match &self.auth_config {
499 Some(config) => config.env_var.clone(),
500 None => return,
501 };
502
503 let state = self.state.clone();
504 let currently_allowed = self.state.read(cx).env_var_allowed;
505
506 // Update settings file
507 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
508 let providers = settings
509 .extension
510 .allowed_env_var_providers
511 .get_or_insert_with(Vec::new);
512
513 if currently_allowed {
514 providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
515 } else {
516 if !providers
517 .iter()
518 .any(|id| id.as_ref() == full_provider_id.as_ref())
519 {
520 providers.push(full_provider_id.clone());
521 }
522 }
523 });
524
525 // Update local state
526 let new_allowed = !currently_allowed;
527 let new_from_env = if new_allowed {
528 if let Some(var_name) = &env_var_name {
529 if let Ok(value) = std::env::var(var_name) {
530 !value.is_empty()
531 } else {
532 false
533 }
534 } else {
535 false
536 }
537 } else {
538 false
539 };
540
541 state.update(cx, |state, cx| {
542 state.env_var_allowed = new_allowed;
543 state.api_key_from_env = new_from_env;
544 if new_from_env {
545 state.is_authenticated = true;
546 }
547 cx.notify();
548 });
549
550 // If env var is being enabled, clear any stored keychain credentials
551 // so there's only one source of truth for the API key
552 if new_allowed {
553 let credential_key = self.credential_key.clone();
554 let credentials_provider = <dyn CredentialsProvider>::global(cx);
555 cx.spawn(async move |_this, cx| {
556 credentials_provider
557 .delete_credentials(&credential_key, cx)
558 .await
559 .log_err();
560 })
561 .detach();
562 }
563
564 // If env var is being disabled, reload credentials from keychain
565 if !new_allowed {
566 self.reload_keychain_credentials(cx);
567 }
568
569 cx.notify();
570 }
571
572 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
573 let credential_key = self.credential_key.clone();
574 let credentials_provider = <dyn CredentialsProvider>::global(cx);
575 let state = self.state.clone();
576
577 cx.spawn(async move |_this, cx| {
578 let credentials = credentials_provider
579 .read_credentials(&credential_key, cx)
580 .await
581 .log_err()
582 .flatten();
583
584 let has_credentials = credentials.is_some();
585
586 let _ = cx.update(|cx| {
587 state.update(cx, |state, cx| {
588 state.is_authenticated = has_credentials;
589 cx.notify();
590 });
591 });
592 })
593 .detach();
594 }
595
596 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
597 let api_key = self.api_key_editor.read(cx).text(cx);
598 if api_key.is_empty() {
599 return;
600 }
601
602 // Clear the editor
603 self.api_key_editor
604 .update(cx, |editor, cx| editor.set_text("", window, cx));
605
606 let credential_key = self.credential_key.clone();
607 let credentials_provider = <dyn CredentialsProvider>::global(cx);
608 let state = self.state.clone();
609
610 cx.spawn(async move |_this, cx| {
611 // Store in system keychain
612 credentials_provider
613 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
614 .await
615 .log_err();
616
617 // Update state to authenticated
618 let _ = cx.update(|cx| {
619 state.update(cx, |state, cx| {
620 state.is_authenticated = true;
621 cx.notify();
622 });
623 });
624 })
625 .detach();
626 }
627
628 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
629 // Clear the editor
630 self.api_key_editor
631 .update(cx, |editor, cx| editor.set_text("", window, cx));
632
633 let credential_key = self.credential_key.clone();
634 let credentials_provider = <dyn CredentialsProvider>::global(cx);
635 let state = self.state.clone();
636
637 cx.spawn(async move |_this, cx| {
638 // Delete from system keychain
639 credentials_provider
640 .delete_credentials(&credential_key, cx)
641 .await
642 .log_err();
643
644 // Update state to unauthenticated
645 let _ = cx.update(|cx| {
646 state.update(cx, |state, cx| {
647 state.is_authenticated = false;
648 cx.notify();
649 });
650 });
651 })
652 .detach();
653 }
654
655 fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
656 if self.oauth_in_progress {
657 return;
658 }
659
660 self.oauth_in_progress = true;
661 self.oauth_error = None;
662 self.device_user_code = None;
663 cx.notify();
664
665 let extension = self.extension.clone();
666 let provider_id = self.extension_provider_id.clone();
667 let state = self.state.clone();
668
669 cx.spawn(async move |this, cx| {
670 // Step 1: Start device flow - opens browser and returns user code
671 let start_result = extension
672 .call({
673 let provider_id = provider_id.clone();
674 |ext, store| {
675 async move {
676 ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
677 .await
678 }
679 .boxed()
680 }
681 })
682 .await;
683
684 let user_code = match start_result {
685 Ok(Ok(Ok(code))) => code,
686 Ok(Ok(Err(e))) => {
687 log::error!("Device flow start failed: {}", e);
688 this.update(cx, |this, cx| {
689 this.oauth_in_progress = false;
690 this.oauth_error = Some(e);
691 cx.notify();
692 })
693 .log_err();
694 return;
695 }
696 Ok(Err(e)) | Err(e) => {
697 log::error!("Device flow start error: {}", e);
698 this.update(cx, |this, cx| {
699 this.oauth_in_progress = false;
700 this.oauth_error = Some(e.to_string());
701 cx.notify();
702 })
703 .log_err();
704 return;
705 }
706 };
707
708 // Update UI to show the user code before polling
709 this.update(cx, |this, cx| {
710 this.device_user_code = Some(user_code);
711 cx.notify();
712 })
713 .log_err();
714
715 // Step 2: Poll for authentication completion
716 let poll_result = extension
717 .call({
718 let provider_id = provider_id.clone();
719 |ext, store| {
720 async move {
721 ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
722 .await
723 }
724 .boxed()
725 }
726 })
727 .await;
728
729 let error_message = match poll_result {
730 Ok(Ok(Ok(()))) => {
731 // After successful auth, refresh the models list
732 let models_result = extension
733 .call({
734 let provider_id = provider_id.clone();
735 |ext, store| {
736 async move {
737 ext.call_llm_provider_models(store, &provider_id).await
738 }
739 .boxed()
740 }
741 })
742 .await;
743
744 let new_models: Vec<LlmModelInfo> = match models_result {
745 Ok(Ok(Ok(models))) => models,
746 _ => Vec::new(),
747 };
748
749 let _ = cx.update(|cx| {
750 state.update(cx, |state, cx| {
751 state.is_authenticated = true;
752 state.available_models = new_models;
753 cx.notify();
754 });
755 });
756 None
757 }
758 Ok(Ok(Err(e))) => {
759 log::error!("Device flow poll failed: {}", e);
760 Some(e)
761 }
762 Ok(Err(e)) | Err(e) => {
763 log::error!("Device flow poll error: {}", e);
764 Some(e.to_string())
765 }
766 };
767
768 this.update(cx, |this, cx| {
769 this.oauth_in_progress = false;
770 this.oauth_error = error_message;
771 this.device_user_code = None;
772 cx.notify();
773 })
774 .log_err();
775 })
776 .detach();
777 }
778
779 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
780 self.state.read(cx).is_authenticated
781 }
782
783 fn has_oauth_config(&self) -> bool {
784 self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
785 }
786
787 fn oauth_config(&self) -> Option<&OAuthConfig> {
788 self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
789 }
790
791 fn has_api_key_config(&self) -> bool {
792 // API key is available if there's a credential_label or no oauth-only config
793 self.auth_config
794 .as_ref()
795 .map(|c| c.credential_label.is_some() || c.oauth.is_none())
796 .unwrap_or(true)
797 }
798}
799
800impl gpui::Render for ExtensionProviderConfigurationView {
801 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
802 let is_loading = self.loading_settings || self.loading_credentials;
803 let is_authenticated = self.is_authenticated(cx);
804 let env_var_allowed = self.state.read(cx).env_var_allowed;
805 let api_key_from_env = self.state.read(cx).api_key_from_env;
806 let has_oauth = self.has_oauth_config();
807 let has_api_key = self.has_api_key_config();
808
809 if is_loading {
810 return v_flex()
811 .gap_2()
812 .child(Label::new("Loading...").color(Color::Muted))
813 .into_any_element();
814 }
815
816 let mut content = v_flex().gap_4().size_full();
817
818 // Render settings markdown if available
819 if let Some(markdown) = &self.settings_markdown {
820 let style = settings_markdown_style(_window, cx);
821 content = content.child(
822 div()
823 .p_2()
824 .rounded_md()
825 .bg(cx.theme().colors().surface_background)
826 .child(MarkdownElement::new(markdown.clone(), style)),
827 );
828 }
829
830 // Render env var checkbox if the extension specifies an env var
831 if let Some(auth_config) = &self.auth_config {
832 if let Some(env_var_name) = &auth_config.env_var {
833 let env_var_name = env_var_name.clone();
834 let checkbox_label =
835 format!("Read API key from {} environment variable", env_var_name);
836
837 content = content.child(
838 h_flex()
839 .gap_2()
840 .child(
841 ui::Checkbox::new("env-var-permission", env_var_allowed.into())
842 .on_click(cx.listener(|this, _, _window, cx| {
843 this.toggle_env_var_permission(cx);
844 })),
845 )
846 .child(Label::new(checkbox_label).size(LabelSize::Small)),
847 );
848
849 // Show status if env var is allowed
850 if env_var_allowed {
851 if api_key_from_env {
852 content = content.child(
853 h_flex()
854 .gap_2()
855 .child(
856 ui::Icon::new(ui::IconName::Check)
857 .color(Color::Success)
858 .size(ui::IconSize::Small),
859 )
860 .child(
861 Label::new(format!("API key loaded from {}", env_var_name))
862 .color(Color::Success),
863 ),
864 );
865 return content.into_any_element();
866 } else {
867 content = content.child(
868 h_flex()
869 .gap_2()
870 .child(
871 ui::Icon::new(ui::IconName::Warning)
872 .color(Color::Warning)
873 .size(ui::IconSize::Small),
874 )
875 .child(
876 Label::new(format!(
877 "{} is not set or empty. You can set it and restart Zed, or use another authentication method below.",
878 env_var_name
879 ))
880 .color(Color::Warning)
881 .size(LabelSize::Small),
882 ),
883 );
884 }
885 }
886 }
887 }
888
889 // If authenticated, show success state with sign out option
890 if is_authenticated && !api_key_from_env {
891 let reset_label = if has_oauth && !has_api_key {
892 "Sign Out"
893 } else {
894 "Reset Credentials"
895 };
896
897 let status_label = if has_oauth && !has_api_key {
898 "Signed in"
899 } else {
900 "Authenticated"
901 };
902
903 content = content.child(
904 v_flex()
905 .gap_2()
906 .child(
907 h_flex()
908 .gap_2()
909 .child(
910 ui::Icon::new(ui::IconName::Check)
911 .color(Color::Success)
912 .size(ui::IconSize::Small),
913 )
914 .child(Label::new(status_label).color(Color::Success)),
915 )
916 .child(
917 ui::Button::new("reset-credentials", reset_label)
918 .style(ui::ButtonStyle::Subtle)
919 .on_click(cx.listener(|this, _, window, cx| {
920 this.reset_api_key(window, cx);
921 })),
922 ),
923 );
924
925 return content.into_any_element();
926 }
927
928 // Not authenticated - show available auth options
929 if !api_key_from_env {
930 // Render OAuth sign-in button if configured
931 if has_oauth {
932 let oauth_config = self.oauth_config();
933 let button_label = oauth_config
934 .and_then(|c| c.sign_in_button_label.clone())
935 .unwrap_or_else(|| "Sign In".to_string());
936 let button_icon = oauth_config
937 .and_then(|c| c.sign_in_button_icon.as_ref())
938 .and_then(|icon_name| match icon_name.as_str() {
939 "github" => Some(ui::IconName::Github),
940 _ => None,
941 });
942
943 let oauth_in_progress = self.oauth_in_progress;
944
945 let oauth_error = self.oauth_error.clone();
946
947 content = content.child(
948 v_flex()
949 .gap_2()
950 .child({
951 let mut button = ui::Button::new("oauth-sign-in", button_label)
952 .full_width()
953 .style(ui::ButtonStyle::Outlined)
954 .disabled(oauth_in_progress)
955 .on_click(cx.listener(|this, _, _window, cx| {
956 this.start_oauth_sign_in(cx);
957 }));
958 if let Some(icon) = button_icon {
959 button = button
960 .icon(icon)
961 .icon_position(ui::IconPosition::Start)
962 .icon_size(ui::IconSize::Small)
963 .icon_color(Color::Muted);
964 }
965 button
966 })
967 .when(oauth_in_progress, |this| {
968 let user_code = self.device_user_code.clone();
969 this.child(
970 v_flex()
971 .gap_1()
972 .when_some(user_code, |this, code| {
973 let copied = cx
974 .read_from_clipboard()
975 .map(|item| item.text().as_ref() == Some(&code))
976 .unwrap_or(false);
977 let code_for_click = code.clone();
978 this.child(
979 h_flex()
980 .gap_1()
981 .child(
982 Label::new("Enter code:")
983 .size(LabelSize::Small)
984 .color(Color::Muted),
985 )
986 .child(
987 h_flex()
988 .gap_1()
989 .px_1()
990 .border_1()
991 .border_color(cx.theme().colors().border)
992 .rounded_sm()
993 .cursor_pointer()
994 .on_mouse_down(
995 MouseButton::Left,
996 move |_, window, cx| {
997 cx.write_to_clipboard(
998 ClipboardItem::new_string(
999 code_for_click.clone(),
1000 ),
1001 );
1002 window.refresh();
1003 },
1004 )
1005 .child(
1006 Label::new(code)
1007 .size(LabelSize::Small)
1008 .color(Color::Accent),
1009 )
1010 .child(
1011 ui::Icon::new(if copied {
1012 ui::IconName::Check
1013 } else {
1014 ui::IconName::Copy
1015 })
1016 .size(ui::IconSize::Small)
1017 .color(if copied {
1018 Color::Success
1019 } else {
1020 Color::Muted
1021 }),
1022 ),
1023 ),
1024 )
1025 })
1026 .child(
1027 Label::new("Waiting for authorization in browser...")
1028 .size(LabelSize::Small)
1029 .color(Color::Muted),
1030 ),
1031 )
1032 })
1033 .when_some(oauth_error, |this, error| {
1034 this.child(
1035 v_flex()
1036 .gap_1()
1037 .child(
1038 h_flex()
1039 .gap_2()
1040 .child(
1041 ui::Icon::new(ui::IconName::Warning)
1042 .color(Color::Error)
1043 .size(ui::IconSize::Small),
1044 )
1045 .child(
1046 Label::new("Authentication failed")
1047 .color(Color::Error)
1048 .size(LabelSize::Small),
1049 ),
1050 )
1051 .child(
1052 div().pl_6().child(
1053 Label::new(error)
1054 .color(Color::Error)
1055 .size(LabelSize::Small),
1056 ),
1057 ),
1058 )
1059 }),
1060 );
1061 }
1062
1063 // Render API key input if configured (and we have both options, show a separator)
1064 if has_api_key {
1065 if has_oauth {
1066 content = content.child(
1067 h_flex()
1068 .gap_2()
1069 .items_center()
1070 .child(div().h_px().flex_1().bg(cx.theme().colors().border))
1071 .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
1072 .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
1073 );
1074 }
1075
1076 let credential_label = self
1077 .auth_config
1078 .as_ref()
1079 .and_then(|c| c.credential_label.clone())
1080 .unwrap_or_else(|| "API Key".to_string());
1081
1082 content = content.child(
1083 v_flex()
1084 .gap_2()
1085 .on_action(cx.listener(Self::save_api_key))
1086 .child(
1087 Label::new(credential_label)
1088 .size(LabelSize::Small)
1089 .color(Color::Muted),
1090 )
1091 .child(self.api_key_editor.clone())
1092 .child(
1093 Label::new("Enter your API key and press Enter to save")
1094 .size(LabelSize::Small)
1095 .color(Color::Muted),
1096 ),
1097 );
1098 }
1099 }
1100
1101 content.into_any_element()
1102 }
1103}
1104
1105impl Focusable for ExtensionProviderConfigurationView {
1106 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1107 self.api_key_editor.focus_handle(cx)
1108 }
1109}
1110
1111fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1112 let theme_settings = ThemeSettings::get_global(cx);
1113 let colors = cx.theme().colors();
1114 let mut text_style = window.text_style();
1115 text_style.refine(&TextStyleRefinement {
1116 font_family: Some(theme_settings.ui_font.family.clone()),
1117 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1118 font_features: Some(theme_settings.ui_font.features.clone()),
1119 color: Some(colors.text),
1120 ..Default::default()
1121 });
1122
1123 MarkdownStyle {
1124 base_text_style: text_style,
1125 selection_background_color: colors.element_selection_background,
1126 inline_code: TextStyleRefinement {
1127 background_color: Some(colors.editor_background),
1128 ..Default::default()
1129 },
1130 link: TextStyleRefinement {
1131 color: Some(colors.text_accent),
1132 underline: Some(UnderlineStyle {
1133 color: Some(colors.text_accent.opacity(0.5)),
1134 thickness: px(1.),
1135 ..Default::default()
1136 }),
1137 ..Default::default()
1138 },
1139 syntax: cx.theme().syntax().clone(),
1140 ..Default::default()
1141 }
1142}
1143
1144/// An extension-based language model.
1145pub struct ExtensionLanguageModel {
1146 extension: WasmExtension,
1147 model_info: LlmModelInfo,
1148 provider_id: LanguageModelProviderId,
1149 provider_name: LanguageModelProviderName,
1150 provider_info: LlmProviderInfo,
1151}
1152
1153impl LanguageModel for ExtensionLanguageModel {
1154 fn id(&self) -> LanguageModelId {
1155 LanguageModelId::from(self.model_info.id.clone())
1156 }
1157
1158 fn name(&self) -> LanguageModelName {
1159 LanguageModelName::from(self.model_info.name.clone())
1160 }
1161
    /// Id of the provider this model belongs to.
    fn provider_id(&self) -> LanguageModelProviderId {
        self.provider_id.clone()
    }
1165
    /// Display name of the provider this model belongs to.
    fn provider_name(&self) -> LanguageModelProviderName {
        self.provider_name.clone()
    }
1169
1170 fn telemetry_id(&self) -> String {
1171 format!("extension-{}", self.model_info.id)
1172 }
1173
1174 fn supports_images(&self) -> bool {
1175 self.model_info.capabilities.supports_images
1176 }
1177
1178 fn supports_tools(&self) -> bool {
1179 self.model_info.capabilities.supports_tools
1180 }
1181
1182 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1183 match choice {
1184 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1185 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1186 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1187 }
1188 }
1189
1190 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1191 match self.model_info.capabilities.tool_input_format {
1192 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1193 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1194 }
1195 }
1196
1197 fn max_token_count(&self) -> u64 {
1198 self.model_info.max_token_count
1199 }
1200
1201 fn max_output_tokens(&self) -> Option<u64> {
1202 self.model_info.max_output_tokens
1203 }
1204
1205 fn count_tokens(
1206 &self,
1207 request: LanguageModelRequest,
1208 cx: &App,
1209 ) -> BoxFuture<'static, Result<u64>> {
1210 let extension = self.extension.clone();
1211 let provider_id = self.provider_info.id.clone();
1212 let model_id = self.model_info.id.clone();
1213
1214 let wit_request = convert_request_to_wit(request);
1215
1216 cx.background_spawn(async move {
1217 extension
1218 .call({
1219 let provider_id = provider_id.clone();
1220 let model_id = model_id.clone();
1221 let wit_request = wit_request.clone();
1222 |ext, store| {
1223 async move {
1224 let count = ext
1225 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1226 .await?
1227 .map_err(|e| anyhow!("{}", e))?;
1228 Ok(count)
1229 }
1230 .boxed()
1231 }
1232 })
1233 .await?
1234 })
1235 .boxed()
1236 }
1237
1238 fn stream_completion(
1239 &self,
1240 request: LanguageModelRequest,
1241 _cx: &AsyncApp,
1242 ) -> BoxFuture<
1243 'static,
1244 Result<
1245 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1246 LanguageModelCompletionError,
1247 >,
1248 > {
1249 let extension = self.extension.clone();
1250 let provider_id = self.provider_info.id.clone();
1251 let model_id = self.model_info.id.clone();
1252
1253 let wit_request = convert_request_to_wit(request);
1254
1255 async move {
1256 // Start the stream
1257 let stream_id_result = extension
1258 .call({
1259 let provider_id = provider_id.clone();
1260 let model_id = model_id.clone();
1261 let wit_request = wit_request.clone();
1262 |ext, store| {
1263 async move {
1264 let id = ext
1265 .call_llm_stream_completion_start(
1266 store,
1267 &provider_id,
1268 &model_id,
1269 &wit_request,
1270 )
1271 .await?
1272 .map_err(|e| anyhow!("{}", e))?;
1273 Ok(id)
1274 }
1275 .boxed()
1276 }
1277 })
1278 .await;
1279
1280 let stream_id = stream_id_result
1281 .map_err(LanguageModelCompletionError::Other)?
1282 .map_err(LanguageModelCompletionError::Other)?;
1283
1284 // Create a stream that polls for events
1285 let stream = futures::stream::unfold(
1286 (extension.clone(), stream_id, false),
1287 move |(extension, stream_id, done)| async move {
1288 if done {
1289 return None;
1290 }
1291
1292 let result = extension
1293 .call({
1294 let stream_id = stream_id.clone();
1295 |ext, store| {
1296 async move {
1297 let event = ext
1298 .call_llm_stream_completion_next(store, &stream_id)
1299 .await?
1300 .map_err(|e| anyhow!("{}", e))?;
1301 Ok(event)
1302 }
1303 .boxed()
1304 }
1305 })
1306 .await
1307 .and_then(|inner| inner);
1308
1309 match result {
1310 Ok(Some(event)) => {
1311 let converted = convert_completion_event(event);
1312 let is_done =
1313 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1314 Some((converted, (extension, stream_id, is_done)))
1315 }
1316 Ok(None) => {
1317 // Stream complete, close it
1318 let _ = extension
1319 .call({
1320 let stream_id = stream_id.clone();
1321 |ext, store| {
1322 async move {
1323 ext.call_llm_stream_completion_close(store, &stream_id)
1324 .await?;
1325 Ok::<(), anyhow::Error>(())
1326 }
1327 .boxed()
1328 }
1329 })
1330 .await;
1331 None
1332 }
1333 Err(e) => Some((
1334 Err(LanguageModelCompletionError::Other(e)),
1335 (extension, stream_id, true),
1336 )),
1337 }
1338 },
1339 );
1340
1341 Ok(stream.boxed())
1342 }
1343 .boxed()
1344 }
1345
1346 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1347 // Extensions can implement this via llm_cache_configuration
1348 None
1349 }
1350}
1351
1352fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1353 use language_model::{MessageContent, Role};
1354
1355 let messages: Vec<LlmRequestMessage> = request
1356 .messages
1357 .into_iter()
1358 .map(|msg| {
1359 let role = match msg.role {
1360 Role::User => LlmMessageRole::User,
1361 Role::Assistant => LlmMessageRole::Assistant,
1362 Role::System => LlmMessageRole::System,
1363 };
1364
1365 let content: Vec<LlmMessageContent> = msg
1366 .content
1367 .into_iter()
1368 .map(|c| match c {
1369 MessageContent::Text(text) => LlmMessageContent::Text(text),
1370 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1371 source: image.source.to_string(),
1372 width: Some(image.size.width.0 as u32),
1373 height: Some(image.size.height.0 as u32),
1374 }),
1375 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1376 id: tool_use.id.to_string(),
1377 name: tool_use.name.to_string(),
1378 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1379 thought_signature: tool_use.thought_signature,
1380 }),
1381 MessageContent::ToolResult(tool_result) => {
1382 let content = match tool_result.content {
1383 language_model::LanguageModelToolResultContent::Text(text) => {
1384 LlmToolResultContent::Text(text.to_string())
1385 }
1386 language_model::LanguageModelToolResultContent::Image(image) => {
1387 LlmToolResultContent::Image(LlmImageData {
1388 source: image.source.to_string(),
1389 width: Some(image.size.width.0 as u32),
1390 height: Some(image.size.height.0 as u32),
1391 })
1392 }
1393 };
1394 LlmMessageContent::ToolResult(LlmToolResult {
1395 tool_use_id: tool_result.tool_use_id.to_string(),
1396 tool_name: tool_result.tool_name.to_string(),
1397 is_error: tool_result.is_error,
1398 content,
1399 })
1400 }
1401 MessageContent::Thinking { text, signature } => {
1402 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1403 }
1404 MessageContent::RedactedThinking(data) => {
1405 LlmMessageContent::RedactedThinking(data)
1406 }
1407 })
1408 .collect();
1409
1410 LlmRequestMessage {
1411 role,
1412 content,
1413 cache: msg.cache,
1414 }
1415 })
1416 .collect();
1417
1418 let tools: Vec<LlmToolDefinition> = request
1419 .tools
1420 .into_iter()
1421 .map(|tool| LlmToolDefinition {
1422 name: tool.name,
1423 description: tool.description,
1424 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1425 })
1426 .collect();
1427
1428 let tool_choice = request.tool_choice.map(|tc| match tc {
1429 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1430 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1431 LanguageModelToolChoice::None => LlmToolChoice::None,
1432 });
1433
1434 LlmCompletionRequest {
1435 messages,
1436 tools,
1437 tool_choice,
1438 stop_sequences: request.stop,
1439 temperature: request.temperature,
1440 thinking_allowed: false,
1441 max_tokens: None,
1442 }
1443}
1444
1445fn convert_completion_event(
1446 event: LlmCompletionEvent,
1447) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1448 match event {
1449 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1450 message_id: String::new(),
1451 }),
1452 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1453 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1454 text: thinking.text,
1455 signature: thinking.signature,
1456 }),
1457 LlmCompletionEvent::RedactedThinking(data) => {
1458 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1459 }
1460 LlmCompletionEvent::ToolUse(tool_use) => {
1461 let raw_input = tool_use.input.clone();
1462 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1463 Ok(LanguageModelCompletionEvent::ToolUse(
1464 LanguageModelToolUse {
1465 id: LanguageModelToolUseId::from(tool_use.id),
1466 name: tool_use.name.into(),
1467 raw_input,
1468 input,
1469 is_input_complete: true,
1470 thought_signature: tool_use.thought_signature,
1471 },
1472 ))
1473 }
1474 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1475 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1476 id: LanguageModelToolUseId::from(error.id),
1477 tool_name: error.tool_name.into(),
1478 raw_input: error.raw_input.into(),
1479 json_parse_error: error.error,
1480 })
1481 }
1482 LlmCompletionEvent::Stop(reason) => {
1483 let stop_reason = match reason {
1484 LlmStopReason::EndTurn => StopReason::EndTurn,
1485 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1486 LlmStopReason::ToolUse => StopReason::ToolUse,
1487 LlmStopReason::Refusal => StopReason::Refusal,
1488 };
1489 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1490 }
1491 LlmCompletionEvent::Usage(usage) => {
1492 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1493 input_tokens: usage.input_tokens,
1494 output_tokens: usage.output_tokens,
1495 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1496 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1497 }))
1498 }
1499 LlmCompletionEvent::ReasoningDetails(json) => {
1500 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1501 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1502 ))
1503 }
1504 }
1505}