1use crate::ExtensionSettings;
2use crate::wasm_host::WasmExtension;
3
4use crate::wasm_host::wit::{
5 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
6 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
7 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
8 LlmToolUse,
9};
10use anyhow::{Result, anyhow};
11use credentials_provider::CredentialsProvider;
12use editor::Editor;
13use extension::{LanguageModelAuthConfig, OAuthConfig};
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, StreamExt};
17use gpui::Focusable;
18use gpui::{
19 AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
20 MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
21};
22use language_model::tool_schema::LanguageModelToolSchemaFormat;
23use language_model::{
24 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
25 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
27 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
28 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
29};
30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
31use settings::Settings;
32use std::sync::Arc;
33use theme::ThemeSettings;
34use ui::{Label, LabelSize, prelude::*};
35use util::ResultExt as _;
36
/// An extension-based language model provider.
///
/// Wraps a WASM extension that exposes an LLM provider, together with the
/// provider metadata, an optional icon, the extension-declared authentication
/// options, and a state entity shared with the configuration view.
pub struct ExtensionLanguageModelProvider {
    /// The extension implementing this provider.
    pub extension: WasmExtension,
    /// Provider metadata from the extension; its `id` and `name` feed `id()` / `name()`.
    pub provider_info: LlmProviderInfo,
    /// Optional custom icon path returned from `icon_path()`.
    icon_path: Option<SharedString>,
    /// Authentication options declared by the extension (env var, OAuth, credential label).
    auth_config: Option<LanguageModelAuthConfig>,
    /// Authentication and model-list state, also handed to the configuration view.
    state: Entity<ExtensionLlmProviderState>,
}
45
/// Mutable state shared between an extension LLM provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether credentials are currently considered available (keychain, env var or OAuth).
    is_authenticated: bool,
    /// Models reported by the extension; replaced after a successful OAuth sign-in.
    available_models: Vec<LlmModelInfo>,
    /// Whether this provider is listed in `allowed_env_var_providers` in the extension settings.
    env_var_allowed: bool,
    /// Whether the API key is currently being supplied by the allowed environment variable.
    api_key_from_env: bool,
}
52
// Makes the state entity usable with `cx.subscribe`.
// NOTE(review): no `cx.emit(())` call was found for this state; changes are signalled with
// `cx.notify()`, which reaches `cx.observe` listeners rather than `cx.subscribe` listeners.
impl EventEmitter<()> for ExtensionLlmProviderState {}
54
55impl ExtensionLanguageModelProvider {
56 pub fn new(
57 extension: WasmExtension,
58 provider_info: LlmProviderInfo,
59 models: Vec<LlmModelInfo>,
60 is_authenticated: bool,
61 icon_path: Option<SharedString>,
62 auth_config: Option<LanguageModelAuthConfig>,
63 cx: &mut App,
64 ) -> Self {
65 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
66 let env_var_allowed = ExtensionSettings::get_global(cx)
67 .allowed_env_var_providers
68 .contains(provider_id_string.as_str());
69
70 let (is_authenticated, api_key_from_env) =
71 if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
72 let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
73 if let Ok(value) = std::env::var(env_var_name) {
74 if !value.is_empty() {
75 (true, true)
76 } else {
77 (is_authenticated, false)
78 }
79 } else {
80 (is_authenticated, false)
81 }
82 } else {
83 (is_authenticated, false)
84 };
85
86 let state = cx.new(|_| ExtensionLlmProviderState {
87 is_authenticated,
88 available_models: models,
89 env_var_allowed,
90 api_key_from_env,
91 });
92
93 Self {
94 extension,
95 provider_info,
96 icon_path,
97 auth_config,
98 state,
99 }
100 }
101
102 fn provider_id_string(&self) -> String {
103 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
104 }
105
106 /// The credential key used for storing the API key in the system keychain.
107 fn credential_key(&self) -> String {
108 format!("extension-llm-{}", self.provider_id_string())
109 }
110}
111
112impl LanguageModelProvider for ExtensionLanguageModelProvider {
113 fn id(&self) -> LanguageModelProviderId {
114 LanguageModelProviderId::from(self.provider_id_string())
115 }
116
117 fn name(&self) -> LanguageModelProviderName {
118 LanguageModelProviderName::from(format!("(Extension) {}", self.provider_info.name))
119 }
120
121 fn icon(&self) -> ui::IconName {
122 ui::IconName::ZedAssistant
123 }
124
125 fn icon_path(&self) -> Option<SharedString> {
126 self.icon_path.clone()
127 }
128
129 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
130 let state = self.state.read(cx);
131 state
132 .available_models
133 .iter()
134 .find(|m| m.is_default)
135 .or_else(|| state.available_models.first())
136 .map(|model_info| {
137 Arc::new(ExtensionLanguageModel {
138 extension: self.extension.clone(),
139 model_info: model_info.clone(),
140 provider_id: self.id(),
141 provider_name: self.name(),
142 provider_info: self.provider_info.clone(),
143 }) as Arc<dyn LanguageModel>
144 })
145 }
146
147 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
148 let state = self.state.read(cx);
149 state
150 .available_models
151 .iter()
152 .find(|m| m.is_default_fast)
153 .map(|model_info| {
154 Arc::new(ExtensionLanguageModel {
155 extension: self.extension.clone(),
156 model_info: model_info.clone(),
157 provider_id: self.id(),
158 provider_name: self.name(),
159 provider_info: self.provider_info.clone(),
160 }) as Arc<dyn LanguageModel>
161 })
162 }
163
164 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165 let state = self.state.read(cx);
166 state
167 .available_models
168 .iter()
169 .map(|model_info| {
170 Arc::new(ExtensionLanguageModel {
171 extension: self.extension.clone(),
172 model_info: model_info.clone(),
173 provider_id: self.id(),
174 provider_name: self.name(),
175 provider_info: self.provider_info.clone(),
176 }) as Arc<dyn LanguageModel>
177 })
178 .collect()
179 }
180
181 fn is_authenticated(&self, cx: &App) -> bool {
182 // First check cached state
183 if self.state.read(cx).is_authenticated {
184 return true;
185 }
186
187 // Also check env var dynamically (in case migration happened after provider creation)
188 if let Some(ref auth_config) = self.auth_config {
189 if let Some(ref env_var_name) = auth_config.env_var {
190 let provider_id_string = self.provider_id_string();
191 let env_var_allowed = ExtensionSettings::get_global(cx)
192 .allowed_env_var_providers
193 .contains(provider_id_string.as_str());
194
195 if env_var_allowed {
196 if let Ok(value) = std::env::var(env_var_name) {
197 if !value.is_empty() {
198 return true;
199 }
200 }
201 }
202 }
203 }
204
205 false
206 }
207
208 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
209 // Check if already authenticated via is_authenticated
210 if self.is_authenticated(cx) {
211 return Task::ready(Ok(()));
212 }
213
214 // Not authenticated - return error indicating credentials not found
215 Task::ready(Err(AuthenticateError::CredentialsNotFound))
216 }
217
218 fn configuration_view(
219 &self,
220 _target_agent: ConfigurationViewTargetAgent,
221 window: &mut Window,
222 cx: &mut App,
223 ) -> AnyView {
224 let credential_key = self.credential_key();
225 let extension = self.extension.clone();
226 let extension_provider_id = self.provider_info.id.clone();
227 let full_provider_id = self.provider_id_string();
228 let state = self.state.clone();
229 let auth_config = self.auth_config.clone();
230
231 cx.new(|cx| {
232 ExtensionProviderConfigurationView::new(
233 credential_key,
234 extension,
235 extension_provider_id,
236 full_provider_id,
237 auth_config,
238 state,
239 window,
240 cx,
241 )
242 })
243 .into()
244 }
245
246 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
247 let extension = self.extension.clone();
248 let provider_id = self.provider_info.id.clone();
249 let state = self.state.clone();
250 let credential_key = self.credential_key();
251
252 let credentials_provider = <dyn CredentialsProvider>::global(cx);
253
254 cx.spawn(async move |cx| {
255 // Delete from system keychain
256 credentials_provider
257 .delete_credentials(&credential_key, cx)
258 .await
259 .log_err();
260
261 // Call extension's reset_credentials
262 let result = extension
263 .call(|extension, store| {
264 async move {
265 extension
266 .call_llm_provider_reset_credentials(store, &provider_id)
267 .await
268 }
269 .boxed()
270 })
271 .await;
272
273 // Update state
274 cx.update(|cx| {
275 state.update(cx, |state, _| {
276 state.is_authenticated = false;
277 });
278 })?;
279
280 match result {
281 Ok(Ok(Ok(()))) => Ok(()),
282 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
283 Ok(Err(e)) => Err(e),
284 Err(e) => Err(e),
285 }
286 })
287 }
288}
289
290impl LanguageModelProviderState for ExtensionLanguageModelProvider {
291 type ObservableEntity = ExtensionLlmProviderState;
292
293 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
294 Some(self.state.clone())
295 }
296
297 fn subscribe<T: 'static>(
298 &self,
299 cx: &mut Context<T>,
300 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
301 ) -> Option<Subscription> {
302 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
303 }
304}
305
/// Configuration view for extension-based LLM providers.
///
/// Shows, depending on the provider's auth config:
/// - the extension's settings markdown;
/// - an env-var permission checkbox;
/// - OAuth device-flow sign-in and/or an API key input.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is stored.
    credential_key: String,
    /// Extension implementing the provider.
    extension: WasmExtension,
    /// Provider id as known to the extension; passed to extension calls.
    extension_provider_id: String,
    /// `<extension id>:<provider id>`; the entry stored in `allowed_env_var_providers`.
    full_provider_id: String,
    /// Authentication options declared by the extension.
    auth_config: Option<LanguageModelAuthConfig>,
    /// State shared with the provider.
    state: Entity<ExtensionLlmProviderState>,
    /// Settings text returned by the extension, if any.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor used to enter an API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings-markdown request completes.
    loading_settings: bool,
    /// True until stored credentials have been checked.
    loading_credentials: bool,
    /// True while an OAuth device-flow sign-in is running.
    oauth_in_progress: bool,
    /// Last OAuth error, rendered under the sign-in button.
    oauth_error: Option<String>,
    /// Device-flow user code shown while waiting for browser authorization.
    device_user_code: Option<String>,
    /// Keeps the state subscription alive for the lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
323
324impl ExtensionProviderConfigurationView {
325 fn new(
326 credential_key: String,
327 extension: WasmExtension,
328 extension_provider_id: String,
329 full_provider_id: String,
330 auth_config: Option<LanguageModelAuthConfig>,
331 state: Entity<ExtensionLlmProviderState>,
332 window: &mut Window,
333 cx: &mut Context<Self>,
334 ) -> Self {
335 // Subscribe to state changes
336 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
337 cx.notify();
338 });
339
340 // Create API key editor
341 let api_key_editor = cx.new(|cx| {
342 let mut editor = Editor::single_line(window, cx);
343 editor.set_placeholder_text("Enter API key...", window, cx);
344 editor
345 });
346
347 let mut this = Self {
348 credential_key,
349 extension,
350 extension_provider_id,
351 full_provider_id,
352 auth_config,
353 state,
354 settings_markdown: None,
355 api_key_editor,
356 loading_settings: true,
357 loading_credentials: true,
358 oauth_in_progress: false,
359 oauth_error: None,
360 device_user_code: None,
361 _subscriptions: vec![state_subscription],
362 };
363
364 // Load settings text from extension
365 this.load_settings_text(cx);
366
367 // Load existing credentials
368 this.load_credentials(cx);
369
370 this
371 }
372
373 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
374 let extension = self.extension.clone();
375 let provider_id = self.extension_provider_id.clone();
376
377 cx.spawn(async move |this, cx| {
378 let result = extension
379 .call({
380 let provider_id = provider_id.clone();
381 |ext, store| {
382 async move {
383 ext.call_llm_provider_settings_markdown(store, &provider_id)
384 .await
385 }
386 .boxed()
387 }
388 })
389 .await;
390
391 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
392
393 this.update(cx, |this, cx| {
394 this.loading_settings = false;
395 if let Some(text) = settings_text {
396 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
397 this.settings_markdown = Some(markdown);
398 }
399 cx.notify();
400 })
401 .log_err();
402 })
403 .detach();
404 }
405
406 fn load_credentials(&mut self, cx: &mut Context<Self>) {
407 let credential_key = self.credential_key.clone();
408 let credentials_provider = <dyn CredentialsProvider>::global(cx);
409 let state = self.state.clone();
410
411 // Check if we should use env var (already set in state during provider construction)
412 let api_key_from_env = self.state.read(cx).api_key_from_env;
413
414 cx.spawn(async move |this, cx| {
415 // If using env var, we're already authenticated
416 if api_key_from_env {
417 this.update(cx, |this, cx| {
418 this.loading_credentials = false;
419 cx.notify();
420 })
421 .log_err();
422 return;
423 }
424
425 let credentials = credentials_provider
426 .read_credentials(&credential_key, cx)
427 .await
428 .log_err()
429 .flatten();
430
431 // Treat empty credentials as not authenticated (used as migration marker)
432 let has_credentials = credentials
433 .map(|(_, password)| !password.is_empty())
434 .unwrap_or(false);
435
436 // Update authentication state based on stored credentials
437 cx.update(|cx| {
438 state.update(cx, |state, cx| {
439 state.is_authenticated = has_credentials;
440 cx.notify();
441 });
442 })
443 .log_err();
444
445 this.update(cx, |this, cx| {
446 this.loading_credentials = false;
447 cx.notify();
448 })
449 .log_err();
450 })
451 .detach();
452 }
453
454 fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
455 let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
456 let env_var_name = match &self.auth_config {
457 Some(config) => config.env_var.clone(),
458 None => return,
459 };
460
461 let state = self.state.clone();
462 let currently_allowed = self.state.read(cx).env_var_allowed;
463
464 // Update settings file
465 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
466 let providers = settings
467 .extension
468 .allowed_env_var_providers
469 .get_or_insert_with(Vec::new);
470
471 if currently_allowed {
472 providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
473 } else {
474 if !providers
475 .iter()
476 .any(|id| id.as_ref() == full_provider_id.as_ref())
477 {
478 providers.push(full_provider_id.clone());
479 }
480 }
481 });
482
483 // Update local state
484 let new_allowed = !currently_allowed;
485 let new_from_env = if new_allowed {
486 if let Some(var_name) = &env_var_name {
487 if let Ok(value) = std::env::var(var_name) {
488 !value.is_empty()
489 } else {
490 false
491 }
492 } else {
493 false
494 }
495 } else {
496 false
497 };
498
499 state.update(cx, |state, cx| {
500 state.env_var_allowed = new_allowed;
501 state.api_key_from_env = new_from_env;
502 if new_from_env {
503 state.is_authenticated = true;
504 }
505 cx.notify();
506 });
507
508 // If env var is being enabled, clear any stored keychain credentials
509 // so there's only one source of truth for the API key
510 if new_allowed {
511 let credential_key = self.credential_key.clone();
512 let credentials_provider = <dyn CredentialsProvider>::global(cx);
513 cx.spawn(async move |_this, cx| {
514 credentials_provider
515 .delete_credentials(&credential_key, cx)
516 .await
517 .log_err();
518 })
519 .detach();
520 }
521
522 // If env var is being disabled, reload credentials from keychain
523 if !new_allowed {
524 self.reload_keychain_credentials(cx);
525 }
526
527 cx.notify();
528 }
529
530 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
531 let credential_key = self.credential_key.clone();
532 let credentials_provider = <dyn CredentialsProvider>::global(cx);
533 let state = self.state.clone();
534
535 cx.spawn(async move |_this, cx| {
536 let credentials = credentials_provider
537 .read_credentials(&credential_key, cx)
538 .await
539 .log_err()
540 .flatten();
541
542 // Treat empty credentials as not authenticated (used as migration marker)
543 let has_credentials = credentials
544 .map(|(_, password)| !password.is_empty())
545 .unwrap_or(false);
546
547 cx.update(|cx| {
548 state.update(cx, |state, cx| {
549 state.is_authenticated = has_credentials;
550 cx.notify();
551 });
552 })
553 .log_err();
554 })
555 .detach();
556 }
557
558 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
559 let api_key = self.api_key_editor.read(cx).text(cx);
560 if api_key.is_empty() {
561 return;
562 }
563
564 // Clear the editor
565 self.api_key_editor
566 .update(cx, |editor, cx| editor.set_text("", window, cx));
567
568 let credential_key = self.credential_key.clone();
569 let credentials_provider = <dyn CredentialsProvider>::global(cx);
570 let state = self.state.clone();
571
572 cx.spawn(async move |_this, cx| {
573 // Store in system keychain
574 credentials_provider
575 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
576 .await
577 .log_err();
578
579 // Update state to authenticated
580 cx.update(|cx| {
581 state.update(cx, |state, cx| {
582 state.is_authenticated = true;
583 cx.notify();
584 });
585 })
586 .log_err();
587 })
588 .detach();
589 }
590
591 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
592 // Clear the editor
593 self.api_key_editor
594 .update(cx, |editor, cx| editor.set_text("", window, cx));
595
596 let credential_key = self.credential_key.clone();
597 let credentials_provider = <dyn CredentialsProvider>::global(cx);
598 let state = self.state.clone();
599
600 cx.spawn(async move |_this, cx| {
601 // Delete from system keychain
602 credentials_provider
603 .delete_credentials(&credential_key, cx)
604 .await
605 .log_err();
606
607 // Update state to unauthenticated
608 cx.update(|cx| {
609 state.update(cx, |state, cx| {
610 state.is_authenticated = false;
611 cx.notify();
612 });
613 })
614 .log_err();
615 })
616 .detach();
617 }
618
619 fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
620 if self.oauth_in_progress {
621 return;
622 }
623
624 self.oauth_in_progress = true;
625 self.oauth_error = None;
626 self.device_user_code = None;
627 cx.notify();
628
629 let extension = self.extension.clone();
630 let provider_id = self.extension_provider_id.clone();
631 let state = self.state.clone();
632
633 cx.spawn(async move |this, cx| {
634 // Step 1: Start device flow - opens browser and returns user code
635 let start_result = extension
636 .call({
637 let provider_id = provider_id.clone();
638 |ext, store| {
639 async move {
640 ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
641 .await
642 }
643 .boxed()
644 }
645 })
646 .await;
647
648 let user_code = match start_result {
649 Ok(Ok(Ok(code))) => code,
650 Ok(Ok(Err(e))) => {
651 log::error!("Device flow start failed: {}", e);
652 this.update(cx, |this, cx| {
653 this.oauth_in_progress = false;
654 this.oauth_error = Some(e);
655 cx.notify();
656 })
657 .log_err();
658 return;
659 }
660 Ok(Err(e)) | Err(e) => {
661 log::error!("Device flow start error: {}", e);
662 this.update(cx, |this, cx| {
663 this.oauth_in_progress = false;
664 this.oauth_error = Some(e.to_string());
665 cx.notify();
666 })
667 .log_err();
668 return;
669 }
670 };
671
672 // Update UI to show the user code before polling
673 this.update(cx, |this, cx| {
674 this.device_user_code = Some(user_code);
675 cx.notify();
676 })
677 .log_err();
678
679 // Step 2: Poll for authentication completion
680 let poll_result = extension
681 .call({
682 let provider_id = provider_id.clone();
683 |ext, store| {
684 async move {
685 ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
686 .await
687 }
688 .boxed()
689 }
690 })
691 .await;
692
693 let error_message = match poll_result {
694 Ok(Ok(Ok(()))) => {
695 // After successful auth, refresh the models list
696 let models_result = extension
697 .call({
698 let provider_id = provider_id.clone();
699 |ext, store| {
700 async move {
701 ext.call_llm_provider_models(store, &provider_id).await
702 }
703 .boxed()
704 }
705 })
706 .await;
707
708 let new_models: Vec<LlmModelInfo> = match models_result {
709 Ok(Ok(Ok(models))) => models,
710 _ => Vec::new(),
711 };
712
713 cx.update(|cx| {
714 state.update(cx, |state, cx| {
715 state.is_authenticated = true;
716 state.available_models = new_models;
717 cx.notify();
718 });
719 })
720 .log_err();
721 None
722 }
723 Ok(Ok(Err(e))) => {
724 log::error!("Device flow poll failed: {}", e);
725 Some(e)
726 }
727 Ok(Err(e)) | Err(e) => {
728 log::error!("Device flow poll error: {}", e);
729 Some(e.to_string())
730 }
731 };
732
733 this.update(cx, |this, cx| {
734 this.oauth_in_progress = false;
735 this.oauth_error = error_message;
736 this.device_user_code = None;
737 cx.notify();
738 })
739 .log_err();
740 })
741 .detach();
742 }
743
744 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
745 self.state.read(cx).is_authenticated
746 }
747
748 fn has_oauth_config(&self) -> bool {
749 self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
750 }
751
752 fn oauth_config(&self) -> Option<&OAuthConfig> {
753 self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
754 }
755
756 fn has_api_key_config(&self) -> bool {
757 // API key is available if there's a credential_label or no oauth-only config
758 self.auth_config
759 .as_ref()
760 .map(|c| c.credential_label.is_some() || c.oauth.is_none())
761 .unwrap_or(true)
762 }
763}
764
765impl gpui::Render for ExtensionProviderConfigurationView {
766 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
767 let is_loading = self.loading_settings || self.loading_credentials;
768 let is_authenticated = self.is_authenticated(cx);
769 let env_var_allowed = self.state.read(cx).env_var_allowed;
770 let api_key_from_env = self.state.read(cx).api_key_from_env;
771 let has_oauth = self.has_oauth_config();
772 let has_api_key = self.has_api_key_config();
773
774 if is_loading {
775 return v_flex()
776 .gap_2()
777 .child(Label::new("Loading...").color(Color::Muted))
778 .into_any_element();
779 }
780
781 let mut content = v_flex().gap_4().size_full();
782
783 // Render settings markdown if available
784 if let Some(markdown) = &self.settings_markdown {
785 let style = settings_markdown_style(_window, cx);
786 content = content.child(
787 div()
788 .p_2()
789 .rounded_md()
790 .bg(cx.theme().colors().surface_background)
791 .child(MarkdownElement::new(markdown.clone(), style)),
792 );
793 }
794
795 // Render env var checkbox if the extension specifies an env var
796 if let Some(auth_config) = &self.auth_config {
797 if let Some(env_var_name) = &auth_config.env_var {
798 let env_var_name = env_var_name.clone();
799 let checkbox_label =
800 format!("Read API key from {} environment variable", env_var_name);
801
802 content = content.child(
803 h_flex()
804 .gap_2()
805 .child(
806 ui::Checkbox::new("env-var-permission", env_var_allowed.into())
807 .on_click(cx.listener(|this, _, _window, cx| {
808 this.toggle_env_var_permission(cx);
809 })),
810 )
811 .child(Label::new(checkbox_label).size(LabelSize::Small)),
812 );
813
814 // Show status if env var is allowed
815 if env_var_allowed {
816 if api_key_from_env {
817 content = content.child(
818 h_flex()
819 .gap_2()
820 .child(
821 ui::Icon::new(ui::IconName::Check)
822 .color(Color::Success)
823 .size(ui::IconSize::Small),
824 )
825 .child(
826 Label::new(format!("API key loaded from {}", env_var_name))
827 .color(Color::Success),
828 ),
829 );
830 return content.into_any_element();
831 } else {
832 content = content.child(
833 h_flex()
834 .gap_2()
835 .child(
836 ui::Icon::new(ui::IconName::Warning)
837 .color(Color::Warning)
838 .size(ui::IconSize::Small),
839 )
840 .child(
841 Label::new(format!(
842 "{} is not set or empty. You can set it and restart Zed, or use another authentication method below.",
843 env_var_name
844 ))
845 .color(Color::Warning)
846 .size(LabelSize::Small),
847 ),
848 );
849 }
850 }
851 }
852 }
853
854 // If authenticated, show success state with sign out option
855 if is_authenticated && !api_key_from_env {
856 let reset_label = if has_oauth && !has_api_key {
857 "Sign Out"
858 } else {
859 "Reset Credentials"
860 };
861
862 let status_label = if has_oauth && !has_api_key {
863 "Signed in"
864 } else {
865 "Authenticated"
866 };
867
868 content = content.child(
869 v_flex()
870 .gap_2()
871 .child(
872 h_flex()
873 .gap_2()
874 .child(
875 ui::Icon::new(ui::IconName::Check)
876 .color(Color::Success)
877 .size(ui::IconSize::Small),
878 )
879 .child(Label::new(status_label).color(Color::Success)),
880 )
881 .child(
882 ui::Button::new("reset-credentials", reset_label)
883 .style(ui::ButtonStyle::Subtle)
884 .on_click(cx.listener(|this, _, window, cx| {
885 this.reset_api_key(window, cx);
886 })),
887 ),
888 );
889
890 return content.into_any_element();
891 }
892
893 // Not authenticated - show available auth options
894 if !api_key_from_env {
895 // Render OAuth sign-in button if configured
896 if has_oauth {
897 let oauth_config = self.oauth_config();
898 let button_label = oauth_config
899 .and_then(|c| c.sign_in_button_label.clone())
900 .unwrap_or_else(|| "Sign In".to_string());
901 let button_icon = oauth_config
902 .and_then(|c| c.sign_in_button_icon.as_ref())
903 .and_then(|icon_name| match icon_name.as_str() {
904 "github" => Some(ui::IconName::Github),
905 _ => None,
906 });
907
908 let oauth_in_progress = self.oauth_in_progress;
909
910 let oauth_error = self.oauth_error.clone();
911
912 content = content.child(
913 v_flex()
914 .gap_2()
915 .child({
916 let mut button = ui::Button::new("oauth-sign-in", button_label)
917 .full_width()
918 .style(ui::ButtonStyle::Outlined)
919 .disabled(oauth_in_progress)
920 .on_click(cx.listener(|this, _, _window, cx| {
921 this.start_oauth_sign_in(cx);
922 }));
923 if let Some(icon) = button_icon {
924 button = button
925 .icon(icon)
926 .icon_position(ui::IconPosition::Start)
927 .icon_size(ui::IconSize::Small)
928 .icon_color(Color::Muted);
929 }
930 button
931 })
932 .when(oauth_in_progress, |this| {
933 let user_code = self.device_user_code.clone();
934 this.child(
935 v_flex()
936 .gap_1()
937 .when_some(user_code, |this, code| {
938 let copied = cx
939 .read_from_clipboard()
940 .map(|item| item.text().as_ref() == Some(&code))
941 .unwrap_or(false);
942 let code_for_click = code.clone();
943 this.child(
944 h_flex()
945 .gap_1()
946 .child(
947 Label::new("Enter code:")
948 .size(LabelSize::Small)
949 .color(Color::Muted),
950 )
951 .child(
952 h_flex()
953 .gap_1()
954 .px_1()
955 .border_1()
956 .border_color(cx.theme().colors().border)
957 .rounded_sm()
958 .cursor_pointer()
959 .on_mouse_down(
960 MouseButton::Left,
961 move |_, window, cx| {
962 cx.write_to_clipboard(
963 ClipboardItem::new_string(
964 code_for_click.clone(),
965 ),
966 );
967 window.refresh();
968 },
969 )
970 .child(
971 Label::new(code)
972 .size(LabelSize::Small)
973 .color(Color::Accent),
974 )
975 .child(
976 ui::Icon::new(if copied {
977 ui::IconName::Check
978 } else {
979 ui::IconName::Copy
980 })
981 .size(ui::IconSize::Small)
982 .color(if copied {
983 Color::Success
984 } else {
985 Color::Muted
986 }),
987 ),
988 ),
989 )
990 })
991 .child(
992 Label::new("Waiting for authorization in browser...")
993 .size(LabelSize::Small)
994 .color(Color::Muted),
995 ),
996 )
997 })
998 .when_some(oauth_error, |this, error| {
999 this.child(
1000 v_flex()
1001 .gap_1()
1002 .child(
1003 h_flex()
1004 .gap_2()
1005 .child(
1006 ui::Icon::new(ui::IconName::Warning)
1007 .color(Color::Error)
1008 .size(ui::IconSize::Small),
1009 )
1010 .child(
1011 Label::new("Authentication failed")
1012 .color(Color::Error)
1013 .size(LabelSize::Small),
1014 ),
1015 )
1016 .child(
1017 div().pl_6().child(
1018 Label::new(error)
1019 .color(Color::Error)
1020 .size(LabelSize::Small),
1021 ),
1022 ),
1023 )
1024 }),
1025 );
1026 }
1027
1028 // Render API key input if configured (and we have both options, show a separator)
1029 if has_api_key {
1030 if has_oauth {
1031 content = content.child(
1032 h_flex()
1033 .gap_2()
1034 .items_center()
1035 .child(div().h_px().flex_1().bg(cx.theme().colors().border))
1036 .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
1037 .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
1038 );
1039 }
1040
1041 let credential_label = self
1042 .auth_config
1043 .as_ref()
1044 .and_then(|c| c.credential_label.clone())
1045 .unwrap_or_else(|| "API Key".to_string());
1046
1047 content = content.child(
1048 v_flex()
1049 .gap_2()
1050 .on_action(cx.listener(Self::save_api_key))
1051 .child(
1052 Label::new(credential_label)
1053 .size(LabelSize::Small)
1054 .color(Color::Muted),
1055 )
1056 .child(self.api_key_editor.clone())
1057 .child(
1058 Label::new("Enter your API key and press Enter to save")
1059 .size(LabelSize::Small)
1060 .color(Color::Muted),
1061 ),
1062 );
1063 }
1064 }
1065
1066 content.into_any_element()
1067 }
1068}
1069
1070impl Focusable for ExtensionProviderConfigurationView {
1071 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1072 self.api_key_editor.focus_handle(cx)
1073 }
1074}
1075
1076fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1077 let theme_settings = ThemeSettings::get_global(cx);
1078 let colors = cx.theme().colors();
1079 let mut text_style = window.text_style();
1080 text_style.refine(&TextStyleRefinement {
1081 font_family: Some(theme_settings.ui_font.family.clone()),
1082 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1083 font_features: Some(theme_settings.ui_font.features.clone()),
1084 color: Some(colors.text),
1085 ..Default::default()
1086 });
1087
1088 MarkdownStyle {
1089 base_text_style: text_style,
1090 selection_background_color: colors.element_selection_background,
1091 inline_code: TextStyleRefinement {
1092 background_color: Some(colors.editor_background),
1093 ..Default::default()
1094 },
1095 link: TextStyleRefinement {
1096 color: Some(colors.text_accent),
1097 underline: Some(UnderlineStyle {
1098 color: Some(colors.text_accent.opacity(0.5)),
1099 thickness: px(1.),
1100 ..Default::default()
1101 }),
1102 ..Default::default()
1103 },
1104 syntax: cx.theme().syntax().clone(),
1105 ..Default::default()
1106 }
1107}
1108
/// An extension-based language model.
///
/// Built by the provider from each `LlmModelInfo` it exposes. Requests are forwarded to
/// the extension using the provider's and the model's ids.
pub struct ExtensionLanguageModel {
    /// Extension that serves requests for this model.
    extension: WasmExtension,
    /// Model metadata and capabilities reported by the extension.
    model_info: LlmModelInfo,
    /// Id of the owning provider (`<extension id>:<provider id>`).
    provider_id: LanguageModelProviderId,
    /// Display name of the owning provider.
    provider_name: LanguageModelProviderName,
    /// Metadata of the owning provider; its `id` is passed to extension calls.
    provider_info: LlmProviderInfo,
}
1117
1118impl LanguageModel for ExtensionLanguageModel {
1119 fn id(&self) -> LanguageModelId {
1120 LanguageModelId::from(self.model_info.id.clone())
1121 }
1122
1123 fn name(&self) -> LanguageModelName {
1124 LanguageModelName::from(self.model_info.name.clone())
1125 }
1126
1127 fn provider_id(&self) -> LanguageModelProviderId {
1128 self.provider_id.clone()
1129 }
1130
1131 fn provider_name(&self) -> LanguageModelProviderName {
1132 self.provider_name.clone()
1133 }
1134
1135 fn telemetry_id(&self) -> String {
1136 format!("extension-{}", self.model_info.id)
1137 }
1138
1139 fn supports_images(&self) -> bool {
1140 self.model_info.capabilities.supports_images
1141 }
1142
1143 fn supports_tools(&self) -> bool {
1144 self.model_info.capabilities.supports_tools
1145 }
1146
1147 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1148 match choice {
1149 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1150 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1151 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1152 }
1153 }
1154
1155 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1156 match self.model_info.capabilities.tool_input_format {
1157 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1158 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1159 }
1160 }
1161
1162 fn max_token_count(&self) -> u64 {
1163 self.model_info.max_token_count
1164 }
1165
1166 fn max_output_tokens(&self) -> Option<u64> {
1167 self.model_info.max_output_tokens
1168 }
1169
1170 fn count_tokens(
1171 &self,
1172 request: LanguageModelRequest,
1173 cx: &App,
1174 ) -> BoxFuture<'static, Result<u64>> {
1175 let extension = self.extension.clone();
1176 let provider_id = self.provider_info.id.clone();
1177 let model_id = self.model_info.id.clone();
1178
1179 let wit_request = convert_request_to_wit(request);
1180
1181 cx.background_spawn(async move {
1182 extension
1183 .call({
1184 let provider_id = provider_id.clone();
1185 let model_id = model_id.clone();
1186 let wit_request = wit_request.clone();
1187 |ext, store| {
1188 async move {
1189 let count = ext
1190 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1191 .await?
1192 .map_err(|e| anyhow!("{}", e))?;
1193 Ok(count)
1194 }
1195 .boxed()
1196 }
1197 })
1198 .await?
1199 })
1200 .boxed()
1201 }
1202
1203 fn stream_completion(
1204 &self,
1205 request: LanguageModelRequest,
1206 _cx: &AsyncApp,
1207 ) -> BoxFuture<
1208 'static,
1209 Result<
1210 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1211 LanguageModelCompletionError,
1212 >,
1213 > {
1214 let extension = self.extension.clone();
1215 let provider_id = self.provider_info.id.clone();
1216 let model_id = self.model_info.id.clone();
1217
1218 let wit_request = convert_request_to_wit(request);
1219
1220 async move {
1221 // Start the stream
1222 let stream_id_result = extension
1223 .call({
1224 let provider_id = provider_id.clone();
1225 let model_id = model_id.clone();
1226 let wit_request = wit_request.clone();
1227 |ext, store| {
1228 async move {
1229 let id = ext
1230 .call_llm_stream_completion_start(
1231 store,
1232 &provider_id,
1233 &model_id,
1234 &wit_request,
1235 )
1236 .await?
1237 .map_err(|e| anyhow!("{}", e))?;
1238 Ok(id)
1239 }
1240 .boxed()
1241 }
1242 })
1243 .await;
1244
1245 let stream_id = stream_id_result
1246 .map_err(LanguageModelCompletionError::Other)?
1247 .map_err(LanguageModelCompletionError::Other)?;
1248
1249 // Create a stream that polls for events
1250 let stream = futures::stream::unfold(
1251 (extension.clone(), stream_id, false),
1252 move |(extension, stream_id, done)| async move {
1253 if done {
1254 return None;
1255 }
1256
1257 let result = extension
1258 .call({
1259 let stream_id = stream_id.clone();
1260 |ext, store| {
1261 async move {
1262 let event = ext
1263 .call_llm_stream_completion_next(store, &stream_id)
1264 .await?
1265 .map_err(|e| anyhow!("{}", e))?;
1266 Ok(event)
1267 }
1268 .boxed()
1269 }
1270 })
1271 .await
1272 .and_then(|inner| inner);
1273
1274 match result {
1275 Ok(Some(event)) => {
1276 let converted = convert_completion_event(event);
1277 let is_done =
1278 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1279 Some((converted, (extension, stream_id, is_done)))
1280 }
1281 Ok(None) => {
1282 // Stream complete, close it
1283 let _ = extension
1284 .call({
1285 let stream_id = stream_id.clone();
1286 |ext, store| {
1287 async move {
1288 ext.call_llm_stream_completion_close(store, &stream_id)
1289 .await?;
1290 Ok::<(), anyhow::Error>(())
1291 }
1292 .boxed()
1293 }
1294 })
1295 .await;
1296 None
1297 }
1298 Err(e) => Some((
1299 Err(LanguageModelCompletionError::Other(e)),
1300 (extension, stream_id, true),
1301 )),
1302 }
1303 },
1304 );
1305
1306 Ok(stream.boxed())
1307 }
1308 .boxed()
1309 }
1310
1311 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1312 // Extensions can implement this via llm_cache_configuration
1313 None
1314 }
1315}
1316
1317fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1318 use language_model::{MessageContent, Role};
1319
1320 let messages: Vec<LlmRequestMessage> = request
1321 .messages
1322 .into_iter()
1323 .map(|msg| {
1324 let role = match msg.role {
1325 Role::User => LlmMessageRole::User,
1326 Role::Assistant => LlmMessageRole::Assistant,
1327 Role::System => LlmMessageRole::System,
1328 };
1329
1330 let content: Vec<LlmMessageContent> = msg
1331 .content
1332 .into_iter()
1333 .map(|c| match c {
1334 MessageContent::Text(text) => LlmMessageContent::Text(text),
1335 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1336 source: image.source.to_string(),
1337 width: Some(image.size.width.0 as u32),
1338 height: Some(image.size.height.0 as u32),
1339 }),
1340 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1341 id: tool_use.id.to_string(),
1342 name: tool_use.name.to_string(),
1343 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1344 thought_signature: tool_use.thought_signature,
1345 }),
1346 MessageContent::ToolResult(tool_result) => {
1347 let content = match tool_result.content {
1348 language_model::LanguageModelToolResultContent::Text(text) => {
1349 LlmToolResultContent::Text(text.to_string())
1350 }
1351 language_model::LanguageModelToolResultContent::Image(image) => {
1352 LlmToolResultContent::Image(LlmImageData {
1353 source: image.source.to_string(),
1354 width: Some(image.size.width.0 as u32),
1355 height: Some(image.size.height.0 as u32),
1356 })
1357 }
1358 };
1359 LlmMessageContent::ToolResult(LlmToolResult {
1360 tool_use_id: tool_result.tool_use_id.to_string(),
1361 tool_name: tool_result.tool_name.to_string(),
1362 is_error: tool_result.is_error,
1363 content,
1364 })
1365 }
1366 MessageContent::Thinking { text, signature } => {
1367 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1368 }
1369 MessageContent::RedactedThinking(data) => {
1370 LlmMessageContent::RedactedThinking(data)
1371 }
1372 })
1373 .collect();
1374
1375 LlmRequestMessage {
1376 role,
1377 content,
1378 cache: msg.cache,
1379 }
1380 })
1381 .collect();
1382
1383 let tools: Vec<LlmToolDefinition> = request
1384 .tools
1385 .into_iter()
1386 .map(|tool| LlmToolDefinition {
1387 name: tool.name,
1388 description: tool.description,
1389 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1390 })
1391 .collect();
1392
1393 let tool_choice = request.tool_choice.map(|tc| match tc {
1394 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1395 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1396 LanguageModelToolChoice::None => LlmToolChoice::None,
1397 });
1398
1399 LlmCompletionRequest {
1400 messages,
1401 tools,
1402 tool_choice,
1403 stop_sequences: request.stop,
1404 temperature: request.temperature,
1405 thinking_allowed: false,
1406 max_tokens: None,
1407 }
1408}
1409
1410fn convert_completion_event(
1411 event: LlmCompletionEvent,
1412) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1413 match event {
1414 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1415 message_id: String::new(),
1416 }),
1417 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1418 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1419 text: thinking.text,
1420 signature: thinking.signature,
1421 }),
1422 LlmCompletionEvent::RedactedThinking(data) => {
1423 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1424 }
1425 LlmCompletionEvent::ToolUse(tool_use) => {
1426 let raw_input = tool_use.input.clone();
1427 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1428 Ok(LanguageModelCompletionEvent::ToolUse(
1429 LanguageModelToolUse {
1430 id: LanguageModelToolUseId::from(tool_use.id),
1431 name: tool_use.name.into(),
1432 raw_input,
1433 input,
1434 is_input_complete: true,
1435 thought_signature: tool_use.thought_signature,
1436 },
1437 ))
1438 }
1439 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1440 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1441 id: LanguageModelToolUseId::from(error.id),
1442 tool_name: error.tool_name.into(),
1443 raw_input: error.raw_input.into(),
1444 json_parse_error: error.error,
1445 })
1446 }
1447 LlmCompletionEvent::Stop(reason) => {
1448 let stop_reason = match reason {
1449 LlmStopReason::EndTurn => StopReason::EndTurn,
1450 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1451 LlmStopReason::ToolUse => StopReason::ToolUse,
1452 LlmStopReason::Refusal => StopReason::Refusal,
1453 };
1454 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1455 }
1456 LlmCompletionEvent::Usage(usage) => {
1457 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1458 input_tokens: usage.input_tokens,
1459 output_tokens: usage.output_tokens,
1460 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1461 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1462 }))
1463 }
1464 LlmCompletionEvent::ReasoningDetails(json) => {
1465 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1466 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1467 ))
1468 }
1469 }
1470}