1use crate::ExtensionSettings;
2use crate::wasm_host::WasmExtension;
3
4use crate::wasm_host::wit::{
5 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
6 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
7 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
8 LlmToolUse,
9};
10use anyhow::{Result, anyhow};
11use credentials_provider::CredentialsProvider;
12use editor::Editor;
13use extension::{LanguageModelAuthConfig, OAuthConfig};
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, StreamExt};
17use gpui::Focusable;
18use gpui::{
19 AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
20 MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
21};
22use language_model::tool_schema::LanguageModelToolSchemaFormat;
23use language_model::{
24 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
25 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
27 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
28 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
29};
30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
31use settings::Settings;
32use std::sync::Arc;
33use theme::ThemeSettings;
34use ui::{Label, LabelSize, prelude::*};
35use util::ResultExt as _;
36
/// An extension-based language model provider.
///
/// Adapts an LLM provider exposed by a WASM extension to the
/// `LanguageModelProvider` trait.
pub struct ExtensionLanguageModelProvider {
    /// Handle to the WASM extension; provider calls are dispatched through it.
    pub extension: WasmExtension,
    /// Provider metadata; its `id` and `name` feed the provider ID and display name.
    pub provider_info: LlmProviderInfo,
    /// Optional custom icon path, returned unchanged by `icon_path()`.
    icon_path: Option<SharedString>,
    /// Optional authentication configuration (environment variable, OAuth, credential label).
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared state: authentication status, available models, and env-var permission.
    state: Entity<ExtensionLlmProviderState>,
}
45
/// Mutable state shared between an extension provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Cached authentication status (keychain credentials, OAuth, or environment variable).
    is_authenticated: bool,
    /// Models currently offered by the provider.
    available_models: Vec<LlmModelInfo>,
    /// Whether this provider's ID is listed in the `allowed_env_var_providers` setting.
    env_var_allowed: bool,
    /// Whether a non-empty API key was found in the permitted environment variable.
    api_key_from_env: bool,
}
52
// Lets other entities `cx.subscribe` to the provider state; the event payload is `()`.
// NOTE(review): no `cx.emit(())` call is visible in this part of the file — the state
// mutations shown here only call `cx.notify()`. Confirm subscribers are triggered somewhere.
impl EventEmitter<()> for ExtensionLlmProviderState {}
54
55impl ExtensionLanguageModelProvider {
56 pub fn new(
57 extension: WasmExtension,
58 provider_info: LlmProviderInfo,
59 models: Vec<LlmModelInfo>,
60 is_authenticated: bool,
61 icon_path: Option<SharedString>,
62 auth_config: Option<LanguageModelAuthConfig>,
63 cx: &mut App,
64 ) -> Self {
65 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
66 let env_var_allowed = ExtensionSettings::get_global(cx)
67 .allowed_env_var_providers
68 .contains(provider_id_string.as_str());
69
70 let (is_authenticated, api_key_from_env) =
71 if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
72 let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
73 if let Ok(value) = std::env::var(env_var_name) {
74 if !value.is_empty() {
75 (true, true)
76 } else {
77 (is_authenticated, false)
78 }
79 } else {
80 (is_authenticated, false)
81 }
82 } else {
83 (is_authenticated, false)
84 };
85
86 let state = cx.new(|_| ExtensionLlmProviderState {
87 is_authenticated,
88 available_models: models,
89 env_var_allowed,
90 api_key_from_env,
91 });
92
93 Self {
94 extension,
95 provider_info,
96 icon_path,
97 auth_config,
98 state,
99 }
100 }
101
102 fn provider_id_string(&self) -> String {
103 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
104 }
105
106 /// The credential key used for storing the API key in the system keychain.
107 fn credential_key(&self) -> String {
108 format!("extension-llm-{}", self.provider_id_string())
109 }
110}
111
112impl LanguageModelProvider for ExtensionLanguageModelProvider {
113 fn id(&self) -> LanguageModelProviderId {
114 LanguageModelProviderId::from(self.provider_id_string())
115 }
116
117 fn name(&self) -> LanguageModelProviderName {
118 LanguageModelProviderName::from(format!("(Extension) {}", self.provider_info.name))
119 }
120
121 fn icon(&self) -> ui::IconName {
122 ui::IconName::ZedAssistant
123 }
124
125 fn icon_path(&self) -> Option<SharedString> {
126 self.icon_path.clone()
127 }
128
129 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
130 let state = self.state.read(cx);
131 state
132 .available_models
133 .iter()
134 .find(|m| m.is_default)
135 .or_else(|| state.available_models.first())
136 .map(|model_info| {
137 Arc::new(ExtensionLanguageModel {
138 extension: self.extension.clone(),
139 model_info: model_info.clone(),
140 provider_id: self.id(),
141 provider_name: self.name(),
142 provider_info: self.provider_info.clone(),
143 }) as Arc<dyn LanguageModel>
144 })
145 }
146
147 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
148 let state = self.state.read(cx);
149 state
150 .available_models
151 .iter()
152 .find(|m| m.is_default_fast)
153 .map(|model_info| {
154 Arc::new(ExtensionLanguageModel {
155 extension: self.extension.clone(),
156 model_info: model_info.clone(),
157 provider_id: self.id(),
158 provider_name: self.name(),
159 provider_info: self.provider_info.clone(),
160 }) as Arc<dyn LanguageModel>
161 })
162 }
163
164 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165 let state = self.state.read(cx);
166 state
167 .available_models
168 .iter()
169 .map(|model_info| {
170 Arc::new(ExtensionLanguageModel {
171 extension: self.extension.clone(),
172 model_info: model_info.clone(),
173 provider_id: self.id(),
174 provider_name: self.name(),
175 provider_info: self.provider_info.clone(),
176 }) as Arc<dyn LanguageModel>
177 })
178 .collect()
179 }
180
181 fn is_authenticated(&self, cx: &App) -> bool {
182 // First check cached state
183 if self.state.read(cx).is_authenticated {
184 return true;
185 }
186
187 // Also check env var dynamically (in case migration happened after provider creation)
188 if let Some(ref auth_config) = self.auth_config {
189 if let Some(ref env_var_name) = auth_config.env_var {
190 let provider_id_string = self.provider_id_string();
191 let env_var_allowed = ExtensionSettings::get_global(cx)
192 .allowed_env_var_providers
193 .contains(provider_id_string.as_str());
194
195 if env_var_allowed {
196 if let Ok(value) = std::env::var(env_var_name) {
197 if !value.is_empty() {
198 return true;
199 }
200 }
201 }
202 }
203 }
204
205 false
206 }
207
208 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
209 // Check if already authenticated via is_authenticated
210 if self.is_authenticated(cx) {
211 return Task::ready(Ok(()));
212 }
213
214 // Not authenticated - return error indicating credentials not found
215 Task::ready(Err(AuthenticateError::CredentialsNotFound))
216 }
217
218 fn configuration_view(
219 &self,
220 _target_agent: ConfigurationViewTargetAgent,
221 window: &mut Window,
222 cx: &mut App,
223 ) -> AnyView {
224 let credential_key = self.credential_key();
225 let extension = self.extension.clone();
226 let extension_provider_id = self.provider_info.id.clone();
227 let full_provider_id = self.provider_id_string();
228 let state = self.state.clone();
229 let auth_config = self.auth_config.clone();
230
231 cx.new(|cx| {
232 ExtensionProviderConfigurationView::new(
233 credential_key,
234 extension,
235 extension_provider_id,
236 full_provider_id,
237 auth_config,
238 state,
239 window,
240 cx,
241 )
242 })
243 .into()
244 }
245
246 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
247 let extension = self.extension.clone();
248 let provider_id = self.provider_info.id.clone();
249 let state = self.state.clone();
250 let credential_key = self.credential_key();
251
252 let credentials_provider = <dyn CredentialsProvider>::global(cx);
253
254 cx.spawn(async move |cx| {
255 // Delete from system keychain
256 credentials_provider
257 .delete_credentials(&credential_key, cx)
258 .await
259 .log_err();
260
261 // Call extension's reset_credentials
262 let result = extension
263 .call(|extension, store| {
264 async move {
265 extension
266 .call_llm_provider_reset_credentials(store, &provider_id)
267 .await
268 }
269 .boxed()
270 })
271 .await;
272
273 // Update state
274 cx.update(|cx| {
275 state.update(cx, |state, _| {
276 state.is_authenticated = false;
277 });
278 })?;
279
280 match result {
281 Ok(Ok(Ok(()))) => Ok(()),
282 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
283 Ok(Err(e)) => Err(e),
284 Err(e) => Err(e),
285 }
286 })
287 }
288}
289
/// Exposes the provider's shared state entity so dependents can react to changes.
impl LanguageModelProviderState for ExtensionLanguageModelProvider {
    type ObservableEntity = ExtensionLlmProviderState;

    /// Returns a handle to the shared state entity (always `Some`).
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }

    /// Registers `callback` as an event subscriber on the state entity.
    ///
    /// NOTE(review): this uses `cx.subscribe`, while the state mutations visible in this
    /// file call `cx.notify()` and never `cx.emit(())`. Confirm that subscribers are
    /// actually triggered, or whether `cx.observe` was intended.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut Context<T>,
        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
    ) -> Option<Subscription> {
        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
    }
}
305
/// Configuration view for extension-based LLM providers.
struct ExtensionProviderConfigurationView {
    /// Key under which the API key is read from, written to, and deleted from the credentials provider.
    credential_key: String,
    /// Handle to the WASM extension; used for settings text, OAuth, and model listing calls.
    extension: WasmExtension,
    /// Provider ID as known to the extension; passed to extension calls.
    extension_provider_id: String,
    /// Fully qualified provider ID; used as the entry in `allowed_env_var_providers`.
    full_provider_id: String,
    /// Optional authentication configuration (environment variable, OAuth, credential label).
    auth_config: Option<LanguageModelAuthConfig>,
    /// Provider state shared with `ExtensionLanguageModelProvider`.
    state: Entity<ExtensionLlmProviderState>,
    /// Markdown entity built from the extension's settings text, once loaded.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor used to enter the API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings text request has completed.
    loading_settings: bool,
    /// True until the initial credentials check has completed.
    loading_credentials: bool,
    /// True while an OAuth device-flow sign-in is running.
    oauth_in_progress: bool,
    /// Error message from the most recent failed OAuth attempt, if any.
    oauth_error: Option<String>,
    /// User code returned by the device flow, shown while waiting for authorization.
    device_user_code: Option<String>,
    /// Subscriptions held by the view.
    _subscriptions: Vec<Subscription>,
}
323
324impl ExtensionProviderConfigurationView {
325 fn new(
326 credential_key: String,
327 extension: WasmExtension,
328 extension_provider_id: String,
329 full_provider_id: String,
330 auth_config: Option<LanguageModelAuthConfig>,
331 state: Entity<ExtensionLlmProviderState>,
332 window: &mut Window,
333 cx: &mut Context<Self>,
334 ) -> Self {
335 // Subscribe to state changes
336 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
337 cx.notify();
338 });
339
340 // Create API key editor
341 let api_key_editor = cx.new(|cx| {
342 let mut editor = Editor::single_line(window, cx);
343 editor.set_placeholder_text("Enter API key...", window, cx);
344 editor
345 });
346
347 let mut this = Self {
348 credential_key,
349 extension,
350 extension_provider_id,
351 full_provider_id,
352 auth_config,
353 state,
354 settings_markdown: None,
355 api_key_editor,
356 loading_settings: true,
357 loading_credentials: true,
358 oauth_in_progress: false,
359 oauth_error: None,
360 device_user_code: None,
361 _subscriptions: vec![state_subscription],
362 };
363
364 // Load settings text from extension
365 this.load_settings_text(cx);
366
367 // Load existing credentials
368 this.load_credentials(cx);
369
370 this
371 }
372
373 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
374 let extension = self.extension.clone();
375 let provider_id = self.extension_provider_id.clone();
376
377 cx.spawn(async move |this, cx| {
378 let result = extension
379 .call({
380 let provider_id = provider_id.clone();
381 |ext, store| {
382 async move {
383 ext.call_llm_provider_settings_markdown(store, &provider_id)
384 .await
385 }
386 .boxed()
387 }
388 })
389 .await;
390
391 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
392
393 this.update(cx, |this, cx| {
394 this.loading_settings = false;
395 if let Some(text) = settings_text {
396 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
397 this.settings_markdown = Some(markdown);
398 }
399 cx.notify();
400 })
401 .log_err();
402 })
403 .detach();
404 }
405
406 fn load_credentials(&mut self, cx: &mut Context<Self>) {
407 let credential_key = self.credential_key.clone();
408 let credentials_provider = <dyn CredentialsProvider>::global(cx);
409 let state = self.state.clone();
410
411 // Check if we should use env var (already set in state during provider construction)
412 let api_key_from_env = self.state.read(cx).api_key_from_env;
413
414 cx.spawn(async move |this, cx| {
415 // If using env var, we're already authenticated
416 if api_key_from_env {
417 this.update(cx, |this, cx| {
418 this.loading_credentials = false;
419 cx.notify();
420 })
421 .log_err();
422 return;
423 }
424
425 let credentials = credentials_provider
426 .read_credentials(&credential_key, cx)
427 .await
428 .log_err()
429 .flatten();
430
431 let has_credentials = credentials.is_some();
432
433 // Update authentication state based on stored credentials
434 cx.update(|cx| {
435 state.update(cx, |state, cx| {
436 state.is_authenticated = has_credentials;
437 cx.notify();
438 });
439 })
440 .log_err();
441
442 this.update(cx, |this, cx| {
443 this.loading_credentials = false;
444 cx.notify();
445 })
446 .log_err();
447 })
448 .detach();
449 }
450
451 fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
452 let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
453 let env_var_name = match &self.auth_config {
454 Some(config) => config.env_var.clone(),
455 None => return,
456 };
457
458 let state = self.state.clone();
459 let currently_allowed = self.state.read(cx).env_var_allowed;
460
461 // Update settings file
462 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
463 let providers = settings
464 .extension
465 .allowed_env_var_providers
466 .get_or_insert_with(Vec::new);
467
468 if currently_allowed {
469 providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
470 } else {
471 if !providers
472 .iter()
473 .any(|id| id.as_ref() == full_provider_id.as_ref())
474 {
475 providers.push(full_provider_id.clone());
476 }
477 }
478 });
479
480 // Update local state
481 let new_allowed = !currently_allowed;
482 let new_from_env = if new_allowed {
483 if let Some(var_name) = &env_var_name {
484 if let Ok(value) = std::env::var(var_name) {
485 !value.is_empty()
486 } else {
487 false
488 }
489 } else {
490 false
491 }
492 } else {
493 false
494 };
495
496 state.update(cx, |state, cx| {
497 state.env_var_allowed = new_allowed;
498 state.api_key_from_env = new_from_env;
499 if new_from_env {
500 state.is_authenticated = true;
501 }
502 cx.notify();
503 });
504
505 // If env var is being enabled, clear any stored keychain credentials
506 // so there's only one source of truth for the API key
507 if new_allowed {
508 let credential_key = self.credential_key.clone();
509 let credentials_provider = <dyn CredentialsProvider>::global(cx);
510 cx.spawn(async move |_this, cx| {
511 credentials_provider
512 .delete_credentials(&credential_key, cx)
513 .await
514 .log_err();
515 })
516 .detach();
517 }
518
519 // If env var is being disabled, reload credentials from keychain
520 if !new_allowed {
521 self.reload_keychain_credentials(cx);
522 }
523
524 cx.notify();
525 }
526
527 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
528 let credential_key = self.credential_key.clone();
529 let credentials_provider = <dyn CredentialsProvider>::global(cx);
530 let state = self.state.clone();
531
532 cx.spawn(async move |_this, cx| {
533 let credentials = credentials_provider
534 .read_credentials(&credential_key, cx)
535 .await
536 .log_err()
537 .flatten();
538
539 let has_credentials = credentials.is_some();
540
541 cx.update(|cx| {
542 state.update(cx, |state, cx| {
543 state.is_authenticated = has_credentials;
544 cx.notify();
545 });
546 })
547 .log_err();
548 })
549 .detach();
550 }
551
552 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
553 let api_key = self.api_key_editor.read(cx).text(cx);
554 if api_key.is_empty() {
555 return;
556 }
557
558 // Clear the editor
559 self.api_key_editor
560 .update(cx, |editor, cx| editor.set_text("", window, cx));
561
562 let credential_key = self.credential_key.clone();
563 let credentials_provider = <dyn CredentialsProvider>::global(cx);
564 let state = self.state.clone();
565
566 cx.spawn(async move |_this, cx| {
567 // Store in system keychain
568 credentials_provider
569 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
570 .await
571 .log_err();
572
573 // Update state to authenticated
574 cx.update(|cx| {
575 state.update(cx, |state, cx| {
576 state.is_authenticated = true;
577 cx.notify();
578 });
579 })
580 .log_err();
581 })
582 .detach();
583 }
584
585 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
586 // Clear the editor
587 self.api_key_editor
588 .update(cx, |editor, cx| editor.set_text("", window, cx));
589
590 let credential_key = self.credential_key.clone();
591 let credentials_provider = <dyn CredentialsProvider>::global(cx);
592 let state = self.state.clone();
593
594 cx.spawn(async move |_this, cx| {
595 // Delete from system keychain
596 credentials_provider
597 .delete_credentials(&credential_key, cx)
598 .await
599 .log_err();
600
601 // Update state to unauthenticated
602 cx.update(|cx| {
603 state.update(cx, |state, cx| {
604 state.is_authenticated = false;
605 cx.notify();
606 });
607 })
608 .log_err();
609 })
610 .detach();
611 }
612
613 fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
614 if self.oauth_in_progress {
615 return;
616 }
617
618 self.oauth_in_progress = true;
619 self.oauth_error = None;
620 self.device_user_code = None;
621 cx.notify();
622
623 let extension = self.extension.clone();
624 let provider_id = self.extension_provider_id.clone();
625 let state = self.state.clone();
626
627 cx.spawn(async move |this, cx| {
628 // Step 1: Start device flow - opens browser and returns user code
629 let start_result = extension
630 .call({
631 let provider_id = provider_id.clone();
632 |ext, store| {
633 async move {
634 ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
635 .await
636 }
637 .boxed()
638 }
639 })
640 .await;
641
642 let user_code = match start_result {
643 Ok(Ok(Ok(code))) => code,
644 Ok(Ok(Err(e))) => {
645 log::error!("Device flow start failed: {}", e);
646 this.update(cx, |this, cx| {
647 this.oauth_in_progress = false;
648 this.oauth_error = Some(e);
649 cx.notify();
650 })
651 .log_err();
652 return;
653 }
654 Ok(Err(e)) | Err(e) => {
655 log::error!("Device flow start error: {}", e);
656 this.update(cx, |this, cx| {
657 this.oauth_in_progress = false;
658 this.oauth_error = Some(e.to_string());
659 cx.notify();
660 })
661 .log_err();
662 return;
663 }
664 };
665
666 // Update UI to show the user code before polling
667 this.update(cx, |this, cx| {
668 this.device_user_code = Some(user_code);
669 cx.notify();
670 })
671 .log_err();
672
673 // Step 2: Poll for authentication completion
674 let poll_result = extension
675 .call({
676 let provider_id = provider_id.clone();
677 |ext, store| {
678 async move {
679 ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
680 .await
681 }
682 .boxed()
683 }
684 })
685 .await;
686
687 let error_message = match poll_result {
688 Ok(Ok(Ok(()))) => {
689 // After successful auth, refresh the models list
690 let models_result = extension
691 .call({
692 let provider_id = provider_id.clone();
693 |ext, store| {
694 async move {
695 ext.call_llm_provider_models(store, &provider_id).await
696 }
697 .boxed()
698 }
699 })
700 .await;
701
702 let new_models: Vec<LlmModelInfo> = match models_result {
703 Ok(Ok(Ok(models))) => models,
704 _ => Vec::new(),
705 };
706
707 cx.update(|cx| {
708 state.update(cx, |state, cx| {
709 state.is_authenticated = true;
710 state.available_models = new_models;
711 cx.notify();
712 });
713 })
714 .log_err();
715 None
716 }
717 Ok(Ok(Err(e))) => {
718 log::error!("Device flow poll failed: {}", e);
719 Some(e)
720 }
721 Ok(Err(e)) | Err(e) => {
722 log::error!("Device flow poll error: {}", e);
723 Some(e.to_string())
724 }
725 };
726
727 this.update(cx, |this, cx| {
728 this.oauth_in_progress = false;
729 this.oauth_error = error_message;
730 this.device_user_code = None;
731 cx.notify();
732 })
733 .log_err();
734 })
735 .detach();
736 }
737
738 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
739 self.state.read(cx).is_authenticated
740 }
741
742 fn has_oauth_config(&self) -> bool {
743 self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
744 }
745
746 fn oauth_config(&self) -> Option<&OAuthConfig> {
747 self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
748 }
749
750 fn has_api_key_config(&self) -> bool {
751 // API key is available if there's a credential_label or no oauth-only config
752 self.auth_config
753 .as_ref()
754 .map(|c| c.credential_label.is_some() || c.oauth.is_none())
755 .unwrap_or(true)
756 }
757}
758
/// Renders the provider configuration UI.
///
/// Layout, top to bottom: a loading placeholder while data is being fetched; otherwise the
/// extension's settings markdown (if any), the env-var permission checkbox with its status,
/// and then either the authenticated status with a reset button or the available sign-in
/// options (OAuth button and/or API key editor).
impl gpui::Render for ExtensionProviderConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // Snapshot everything the layout decisions depend on.
        let is_loading = self.loading_settings || self.loading_credentials;
        let is_authenticated = self.is_authenticated(cx);
        let env_var_allowed = self.state.read(cx).env_var_allowed;
        let api_key_from_env = self.state.read(cx).api_key_from_env;
        let has_oauth = self.has_oauth_config();
        let has_api_key = self.has_api_key_config();

        // Until both the settings text and the credentials check are done, show a placeholder.
        if is_loading {
            return v_flex()
                .gap_2()
                .child(Label::new("Loading...").color(Color::Muted))
                .into_any_element();
        }

        let mut content = v_flex().gap_4().size_full();

        // Render settings markdown if available
        if let Some(markdown) = &self.settings_markdown {
            let style = settings_markdown_style(_window, cx);
            content = content.child(
                div()
                    .p_2()
                    .rounded_md()
                    .bg(cx.theme().colors().surface_background)
                    .child(MarkdownElement::new(markdown.clone(), style)),
            );
        }

        // Render env var checkbox if the extension specifies an env var
        if let Some(auth_config) = &self.auth_config {
            if let Some(env_var_name) = &auth_config.env_var {
                let env_var_name = env_var_name.clone();
                let checkbox_label =
                    format!("Read API key from {} environment variable", env_var_name);

                content = content.child(
                    h_flex()
                        .gap_2()
                        .child(
                            ui::Checkbox::new("env-var-permission", env_var_allowed.into())
                                .on_click(cx.listener(|this, _, _window, cx| {
                                    this.toggle_env_var_permission(cx);
                                })),
                        )
                        .child(Label::new(checkbox_label).size(LabelSize::Small)),
                );

                // Show status if env var is allowed
                if env_var_allowed {
                    if api_key_from_env {
                        content = content.child(
                            h_flex()
                                .gap_2()
                                .child(
                                    ui::Icon::new(ui::IconName::Check)
                                        .color(Color::Success)
                                        .size(ui::IconSize::Small),
                                )
                                .child(
                                    Label::new(format!("API key loaded from {}", env_var_name))
                                        .color(Color::Success),
                                ),
                        );
                        // The key comes from the environment: no other auth options are shown.
                        return content.into_any_element();
                    } else {
                        content = content.child(
                            h_flex()
                                .gap_2()
                                .child(
                                    ui::Icon::new(ui::IconName::Warning)
                                        .color(Color::Warning)
                                        .size(ui::IconSize::Small),
                                )
                                .child(
                                    Label::new(format!(
                                        "{} is not set or empty. You can set it and restart Zed, or use another authentication method below.",
                                        env_var_name
                                    ))
                                    .color(Color::Warning)
                                    .size(LabelSize::Small),
                                ),
                        );
                    }
                }
            }
        }

        // If authenticated, show success state with sign out option
        if is_authenticated && !api_key_from_env {
            // OAuth-only providers get "Sign Out"/"Signed in" wording; everything else
            // gets the generic credentials wording.
            let reset_label = if has_oauth && !has_api_key {
                "Sign Out"
            } else {
                "Reset Credentials"
            };

            let status_label = if has_oauth && !has_api_key {
                "Signed in"
            } else {
                "Authenticated"
            };

            content = content.child(
                v_flex()
                    .gap_2()
                    .child(
                        h_flex()
                            .gap_2()
                            .child(
                                ui::Icon::new(ui::IconName::Check)
                                    .color(Color::Success)
                                    .size(ui::IconSize::Small),
                            )
                            .child(Label::new(status_label).color(Color::Success)),
                    )
                    .child(
                        ui::Button::new("reset-credentials", reset_label)
                            .style(ui::ButtonStyle::Subtle)
                            .on_click(cx.listener(|this, _, window, cx| {
                                this.reset_api_key(window, cx);
                            })),
                    ),
            );

            return content.into_any_element();
        }

        // Not authenticated - show available auth options
        if !api_key_from_env {
            // Render OAuth sign-in button if configured
            if has_oauth {
                let oauth_config = self.oauth_config();
                let button_label = oauth_config
                    .and_then(|c| c.sign_in_button_label.clone())
                    .unwrap_or_else(|| "Sign In".to_string());
                // Only the "github" icon name is recognized; anything else means no icon.
                let button_icon = oauth_config
                    .and_then(|c| c.sign_in_button_icon.as_ref())
                    .and_then(|icon_name| match icon_name.as_str() {
                        "github" => Some(ui::IconName::Github),
                        _ => None,
                    });

                let oauth_in_progress = self.oauth_in_progress;

                let oauth_error = self.oauth_error.clone();

                content = content.child(
                    v_flex()
                        .gap_2()
                        .child({
                            // The button is disabled while a sign-in attempt is running.
                            let mut button = ui::Button::new("oauth-sign-in", button_label)
                                .full_width()
                                .style(ui::ButtonStyle::Outlined)
                                .disabled(oauth_in_progress)
                                .on_click(cx.listener(|this, _, _window, cx| {
                                    this.start_oauth_sign_in(cx);
                                }));
                            if let Some(icon) = button_icon {
                                button = button
                                    .icon(icon)
                                    .icon_position(ui::IconPosition::Start)
                                    .icon_size(ui::IconSize::Small)
                                    .icon_color(Color::Muted);
                            }
                            button
                        })
                        .when(oauth_in_progress, |this| {
                            let user_code = self.device_user_code.clone();
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .when_some(user_code, |this, code| {
                                        // Show a check mark when the clipboard already holds the code.
                                        let copied = cx
                                            .read_from_clipboard()
                                            .map(|item| item.text().as_ref() == Some(&code))
                                            .unwrap_or(false);
                                        let code_for_click = code.clone();
                                        this.child(
                                            h_flex()
                                                .gap_1()
                                                .child(
                                                    Label::new("Enter code:")
                                                        .size(LabelSize::Small)
                                                        .color(Color::Muted),
                                                )
                                                .child(
                                                    h_flex()
                                                        .gap_1()
                                                        .px_1()
                                                        .border_1()
                                                        .border_color(cx.theme().colors().border)
                                                        .rounded_sm()
                                                        .cursor_pointer()
                                                        // Clicking the code copies it to the clipboard.
                                                        .on_mouse_down(
                                                            MouseButton::Left,
                                                            move |_, window, cx| {
                                                                cx.write_to_clipboard(
                                                                    ClipboardItem::new_string(
                                                                        code_for_click.clone(),
                                                                    ),
                                                                );
                                                                window.refresh();
                                                            },
                                                        )
                                                        .child(
                                                            Label::new(code)
                                                                .size(LabelSize::Small)
                                                                .color(Color::Accent),
                                                        )
                                                        .child(
                                                            ui::Icon::new(if copied {
                                                                ui::IconName::Check
                                                            } else {
                                                                ui::IconName::Copy
                                                            })
                                                            .size(ui::IconSize::Small)
                                                            .color(if copied {
                                                                Color::Success
                                                            } else {
                                                                Color::Muted
                                                            }),
                                                        ),
                                                ),
                                        )
                                    })
                                    .child(
                                        Label::new("Waiting for authorization in browser...")
                                            .size(LabelSize::Small)
                                            .color(Color::Muted),
                                    ),
                            )
                        })
                        // Show the message from the last failed sign-in attempt, if any.
                        .when_some(oauth_error, |this, error| {
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .child(
                                        h_flex()
                                            .gap_2()
                                            .child(
                                                ui::Icon::new(ui::IconName::Warning)
                                                    .color(Color::Error)
                                                    .size(ui::IconSize::Small),
                                            )
                                            .child(
                                                Label::new("Authentication failed")
                                                    .color(Color::Error)
                                                    .size(LabelSize::Small),
                                            ),
                                    )
                                    .child(
                                        div().pl_6().child(
                                            Label::new(error)
                                                .color(Color::Error)
                                                .size(LabelSize::Small),
                                        ),
                                    ),
                            )
                        }),
                );
            }

            // Render API key input if configured (and we have both options, show a separator)
            if has_api_key {
                if has_oauth {
                    content = content.child(
                        h_flex()
                            .gap_2()
                            .items_center()
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border))
                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
                    );
                }

                // Fall back to a generic label when the extension does not provide one.
                let credential_label = self
                    .auth_config
                    .as_ref()
                    .and_then(|c| c.credential_label.clone())
                    .unwrap_or_else(|| "API Key".to_string());

                content = content.child(
                    v_flex()
                        .gap_2()
                        // The `menu::Confirm` action saves the key.
                        .on_action(cx.listener(Self::save_api_key))
                        .child(
                            Label::new(credential_label)
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        )
                        .child(self.api_key_editor.clone())
                        .child(
                            Label::new("Enter your API key and press Enter to save")
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        ),
                );
            }
        }

        content.into_any_element()
    }
}
1063
1064impl Focusable for ExtensionProviderConfigurationView {
1065 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1066 self.api_key_editor.focus_handle(cx)
1067 }
1068}
1069
1070fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1071 let theme_settings = ThemeSettings::get_global(cx);
1072 let colors = cx.theme().colors();
1073 let mut text_style = window.text_style();
1074 text_style.refine(&TextStyleRefinement {
1075 font_family: Some(theme_settings.ui_font.family.clone()),
1076 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1077 font_features: Some(theme_settings.ui_font.features.clone()),
1078 color: Some(colors.text),
1079 ..Default::default()
1080 });
1081
1082 MarkdownStyle {
1083 base_text_style: text_style,
1084 selection_background_color: colors.element_selection_background,
1085 inline_code: TextStyleRefinement {
1086 background_color: Some(colors.editor_background),
1087 ..Default::default()
1088 },
1089 link: TextStyleRefinement {
1090 color: Some(colors.text_accent),
1091 underline: Some(UnderlineStyle {
1092 color: Some(colors.text_accent.opacity(0.5)),
1093 thickness: px(1.),
1094 ..Default::default()
1095 }),
1096 ..Default::default()
1097 },
1098 syntax: cx.theme().syntax().clone(),
1099 ..Default::default()
1100 }
1101}
1102
/// An extension-based language model.
pub struct ExtensionLanguageModel {
    /// Handle to the WASM extension; model calls are dispatched through it.
    extension: WasmExtension,
    /// Model metadata (ID, name, capabilities, token limits).
    model_info: LlmModelInfo,
    /// ID of the provider this model belongs to.
    provider_id: LanguageModelProviderId,
    /// Display name of the provider this model belongs to.
    provider_name: LanguageModelProviderName,
    /// Provider metadata; its `id` is passed to extension calls.
    provider_info: LlmProviderInfo,
}
1111
1112impl LanguageModel for ExtensionLanguageModel {
1113 fn id(&self) -> LanguageModelId {
1114 LanguageModelId::from(self.model_info.id.clone())
1115 }
1116
1117 fn name(&self) -> LanguageModelName {
1118 LanguageModelName::from(self.model_info.name.clone())
1119 }
1120
1121 fn provider_id(&self) -> LanguageModelProviderId {
1122 self.provider_id.clone()
1123 }
1124
1125 fn provider_name(&self) -> LanguageModelProviderName {
1126 self.provider_name.clone()
1127 }
1128
1129 fn telemetry_id(&self) -> String {
1130 format!("extension-{}", self.model_info.id)
1131 }
1132
1133 fn supports_images(&self) -> bool {
1134 self.model_info.capabilities.supports_images
1135 }
1136
1137 fn supports_tools(&self) -> bool {
1138 self.model_info.capabilities.supports_tools
1139 }
1140
1141 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1142 match choice {
1143 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1144 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1145 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1146 }
1147 }
1148
1149 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1150 match self.model_info.capabilities.tool_input_format {
1151 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1152 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1153 }
1154 }
1155
1156 fn max_token_count(&self) -> u64 {
1157 self.model_info.max_token_count
1158 }
1159
1160 fn max_output_tokens(&self) -> Option<u64> {
1161 self.model_info.max_output_tokens
1162 }
1163
    /// Asks the extension to count the tokens in `request`.
    ///
    /// The request is converted to its WIT form up front, then the extension's
    /// `call_llm_count_tokens` is invoked from a background task. An error reported by
    /// the extension is converted into an `anyhow` error.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let extension = self.extension.clone();
        let provider_id = self.provider_info.id.clone();
        let model_id = self.model_info.id.clone();

        let wit_request = convert_request_to_wit(request);

        cx.background_spawn(async move {
            extension
                .call({
                    // Clones are moved into the closure handed to the extension call.
                    let provider_id = provider_id.clone();
                    let model_id = model_id.clone();
                    let wit_request = wit_request.clone();
                    |ext, store| {
                        async move {
                            // First `?` propagates the host-call error; `map_err` converts
                            // the extension-reported error.
                            let count = ext
                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
                                .await?
                                .map_err(|e| anyhow!("{}", e))?;
                            Ok(count)
                        }
                        .boxed()
                    }
                })
                .await?
        })
        .boxed()
    }
1196
1197 fn stream_completion(
1198 &self,
1199 request: LanguageModelRequest,
1200 _cx: &AsyncApp,
1201 ) -> BoxFuture<
1202 'static,
1203 Result<
1204 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1205 LanguageModelCompletionError,
1206 >,
1207 > {
1208 let extension = self.extension.clone();
1209 let provider_id = self.provider_info.id.clone();
1210 let model_id = self.model_info.id.clone();
1211
1212 let wit_request = convert_request_to_wit(request);
1213
1214 async move {
1215 // Start the stream
1216 let stream_id_result = extension
1217 .call({
1218 let provider_id = provider_id.clone();
1219 let model_id = model_id.clone();
1220 let wit_request = wit_request.clone();
1221 |ext, store| {
1222 async move {
1223 let id = ext
1224 .call_llm_stream_completion_start(
1225 store,
1226 &provider_id,
1227 &model_id,
1228 &wit_request,
1229 )
1230 .await?
1231 .map_err(|e| anyhow!("{}", e))?;
1232 Ok(id)
1233 }
1234 .boxed()
1235 }
1236 })
1237 .await;
1238
1239 let stream_id = stream_id_result
1240 .map_err(LanguageModelCompletionError::Other)?
1241 .map_err(LanguageModelCompletionError::Other)?;
1242
1243 // Create a stream that polls for events
1244 let stream = futures::stream::unfold(
1245 (extension.clone(), stream_id, false),
1246 move |(extension, stream_id, done)| async move {
1247 if done {
1248 return None;
1249 }
1250
1251 let result = extension
1252 .call({
1253 let stream_id = stream_id.clone();
1254 |ext, store| {
1255 async move {
1256 let event = ext
1257 .call_llm_stream_completion_next(store, &stream_id)
1258 .await?
1259 .map_err(|e| anyhow!("{}", e))?;
1260 Ok(event)
1261 }
1262 .boxed()
1263 }
1264 })
1265 .await
1266 .and_then(|inner| inner);
1267
1268 match result {
1269 Ok(Some(event)) => {
1270 let converted = convert_completion_event(event);
1271 let is_done =
1272 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1273 Some((converted, (extension, stream_id, is_done)))
1274 }
1275 Ok(None) => {
1276 // Stream complete, close it
1277 let _ = extension
1278 .call({
1279 let stream_id = stream_id.clone();
1280 |ext, store| {
1281 async move {
1282 ext.call_llm_stream_completion_close(store, &stream_id)
1283 .await?;
1284 Ok::<(), anyhow::Error>(())
1285 }
1286 .boxed()
1287 }
1288 })
1289 .await;
1290 None
1291 }
1292 Err(e) => Some((
1293 Err(LanguageModelCompletionError::Other(e)),
1294 (extension, stream_id, true),
1295 )),
1296 }
1297 },
1298 );
1299
1300 Ok(stream.boxed())
1301 }
1302 .boxed()
1303 }
1304
1305 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1306 // Extensions can implement this via llm_cache_configuration
1307 None
1308 }
1309}
1310
1311fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1312 use language_model::{MessageContent, Role};
1313
1314 let messages: Vec<LlmRequestMessage> = request
1315 .messages
1316 .into_iter()
1317 .map(|msg| {
1318 let role = match msg.role {
1319 Role::User => LlmMessageRole::User,
1320 Role::Assistant => LlmMessageRole::Assistant,
1321 Role::System => LlmMessageRole::System,
1322 };
1323
1324 let content: Vec<LlmMessageContent> = msg
1325 .content
1326 .into_iter()
1327 .map(|c| match c {
1328 MessageContent::Text(text) => LlmMessageContent::Text(text),
1329 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1330 source: image.source.to_string(),
1331 width: Some(image.size.width.0 as u32),
1332 height: Some(image.size.height.0 as u32),
1333 }),
1334 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1335 id: tool_use.id.to_string(),
1336 name: tool_use.name.to_string(),
1337 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1338 thought_signature: tool_use.thought_signature,
1339 }),
1340 MessageContent::ToolResult(tool_result) => {
1341 let content = match tool_result.content {
1342 language_model::LanguageModelToolResultContent::Text(text) => {
1343 LlmToolResultContent::Text(text.to_string())
1344 }
1345 language_model::LanguageModelToolResultContent::Image(image) => {
1346 LlmToolResultContent::Image(LlmImageData {
1347 source: image.source.to_string(),
1348 width: Some(image.size.width.0 as u32),
1349 height: Some(image.size.height.0 as u32),
1350 })
1351 }
1352 };
1353 LlmMessageContent::ToolResult(LlmToolResult {
1354 tool_use_id: tool_result.tool_use_id.to_string(),
1355 tool_name: tool_result.tool_name.to_string(),
1356 is_error: tool_result.is_error,
1357 content,
1358 })
1359 }
1360 MessageContent::Thinking { text, signature } => {
1361 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1362 }
1363 MessageContent::RedactedThinking(data) => {
1364 LlmMessageContent::RedactedThinking(data)
1365 }
1366 })
1367 .collect();
1368
1369 LlmRequestMessage {
1370 role,
1371 content,
1372 cache: msg.cache,
1373 }
1374 })
1375 .collect();
1376
1377 let tools: Vec<LlmToolDefinition> = request
1378 .tools
1379 .into_iter()
1380 .map(|tool| LlmToolDefinition {
1381 name: tool.name,
1382 description: tool.description,
1383 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1384 })
1385 .collect();
1386
1387 let tool_choice = request.tool_choice.map(|tc| match tc {
1388 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1389 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1390 LanguageModelToolChoice::None => LlmToolChoice::None,
1391 });
1392
1393 LlmCompletionRequest {
1394 messages,
1395 tools,
1396 tool_choice,
1397 stop_sequences: request.stop,
1398 temperature: request.temperature,
1399 thinking_allowed: false,
1400 max_tokens: None,
1401 }
1402}
1403
1404fn convert_completion_event(
1405 event: LlmCompletionEvent,
1406) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1407 match event {
1408 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1409 message_id: String::new(),
1410 }),
1411 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1412 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1413 text: thinking.text,
1414 signature: thinking.signature,
1415 }),
1416 LlmCompletionEvent::RedactedThinking(data) => {
1417 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1418 }
1419 LlmCompletionEvent::ToolUse(tool_use) => {
1420 let raw_input = tool_use.input.clone();
1421 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1422 Ok(LanguageModelCompletionEvent::ToolUse(
1423 LanguageModelToolUse {
1424 id: LanguageModelToolUseId::from(tool_use.id),
1425 name: tool_use.name.into(),
1426 raw_input,
1427 input,
1428 is_input_complete: true,
1429 thought_signature: tool_use.thought_signature,
1430 },
1431 ))
1432 }
1433 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1434 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1435 id: LanguageModelToolUseId::from(error.id),
1436 tool_name: error.tool_name.into(),
1437 raw_input: error.raw_input.into(),
1438 json_parse_error: error.error,
1439 })
1440 }
1441 LlmCompletionEvent::Stop(reason) => {
1442 let stop_reason = match reason {
1443 LlmStopReason::EndTurn => StopReason::EndTurn,
1444 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1445 LlmStopReason::ToolUse => StopReason::ToolUse,
1446 LlmStopReason::Refusal => StopReason::Refusal,
1447 };
1448 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1449 }
1450 LlmCompletionEvent::Usage(usage) => {
1451 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1452 input_tokens: usage.input_tokens,
1453 output_tokens: usage.output_tokens,
1454 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1455 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1456 }))
1457 }
1458 LlmCompletionEvent::ReasoningDetails(json) => {
1459 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1460 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1461 ))
1462 }
1463 }
1464}