1use crate::ExtensionSettings;
2use crate::LEGACY_LLM_EXTENSION_IDS;
3use crate::wasm_host::WasmExtension;
4use crate::wasm_host::wit::LlmDeviceFlowPromptInfo;
5use collections::HashSet;
6
7use crate::wasm_host::wit::{
8 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
9 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
10 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
11 LlmToolUse,
12};
13use anyhow::{Result, anyhow};
14use credentials_provider::CredentialsProvider;
15use extension::{LanguageModelAuthConfig, OAuthConfig};
16use futures::future::BoxFuture;
17use futures::stream::BoxStream;
18use futures::{FutureExt, StreamExt};
19
20use gpui::{
21 AnyView, App, AsyncApp, ClipboardItem, DismissEvent, Entity, EventEmitter, FocusHandle,
22 Focusable, MouseDownEvent, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window,
23 WindowBounds, WindowOptions, point, prelude::*, px,
24};
25use language_model::tool_schema::LanguageModelToolSchemaFormat;
26use language_model::{
27 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
28 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
29 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
30 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
31 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
32};
33use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
34use settings::Settings;
35use std::sync::Arc;
36use theme::ThemeSettings;
37use ui::{
38 ButtonLike, ButtonLink, Checkbox, ConfiguredApiCard, SpinnerLabel, ToggleState, Vector,
39 VectorName, prelude::*,
40};
41use ui_input::InputField;
42use util::ResultExt as _;
43use workspace::Workspace;
44use workspace::oauth_device_flow_modal::{
45 OAuthDeviceFlowModal, OAuthDeviceFlowModalConfig, OAuthDeviceFlowState, OAuthDeviceFlowStatus,
46};
47
/// An extension-based language model provider.
///
/// Wraps a single LLM provider declared by a WASM extension and exposes it
/// through the `LanguageModelProvider` trait. Authentication status and the
/// model list live in the shared `state` entity, which is also handed to the
/// provider's configuration views.
pub struct ExtensionLanguageModelProvider {
    /// The extension that implements this provider.
    pub extension: WasmExtension,
    /// Provider metadata from the extension; its `id` and `name` are used to
    /// build the provider id/name.
    pub provider_info: LlmProviderInfo,
    /// Optional custom icon, returned from `icon_path()`.
    icon_path: Option<SharedString>,
    /// Declared authentication options (env vars, OAuth, credential label), if any.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared authentication/model state; exposed as the observable entity.
    state: Entity<ExtensionLlmProviderState>,
}
56
/// Mutable state shared between an `ExtensionLanguageModelProvider` and its
/// configuration views.
pub struct ExtensionLlmProviderState {
    /// Whether credentials are currently available. Seeded from the
    /// constructor argument and then updated by the env-var, keychain and
    /// OAuth flows.
    is_authenticated: bool,
    /// Models offered by the provider; replaced after a successful OAuth sign-in.
    available_models: Vec<LlmModelInfo>,
    /// Set of env var names that are allowed to be read for this provider.
    allowed_env_vars: HashSet<String>,
    /// If authenticated via env var, which one was used.
    env_var_name_used: Option<String>,
}
65
// Allows other entities to `cx.subscribe` to `()` events emitted by this state entity.
impl EventEmitter<()> for ExtensionLlmProviderState {}
67
68impl ExtensionLanguageModelProvider {
69 pub fn new(
70 extension: WasmExtension,
71 provider_info: LlmProviderInfo,
72 models: Vec<LlmModelInfo>,
73 is_authenticated: bool,
74 icon_path: Option<SharedString>,
75 auth_config: Option<LanguageModelAuthConfig>,
76 cx: &mut App,
77 ) -> Self {
78 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
79
80 // Build set of allowed env vars for this provider
81 let settings = ExtensionSettings::get_global(cx);
82 let is_legacy_extension =
83 LEGACY_LLM_EXTENSION_IDS.contains(&extension.manifest.id.as_ref());
84
85 let mut allowed_env_vars = HashSet::default();
86 if let Some(env_vars) = auth_config.as_ref().and_then(|c| c.env_vars.as_ref()) {
87 for env_var_name in env_vars {
88 let key = format!("{}:{}", provider_id_string, env_var_name);
89 // For legacy extensions, auto-allow if env var is set (migration will persist this)
90 let env_var_is_set = std::env::var(env_var_name)
91 .map(|v| !v.is_empty())
92 .unwrap_or(false);
93 if settings.allowed_env_var_providers.contains(key.as_str())
94 || (is_legacy_extension && env_var_is_set)
95 {
96 allowed_env_vars.insert(env_var_name.clone());
97 }
98 }
99 }
100
101 // Check if any allowed env var is set
102 let env_var_name_used = allowed_env_vars.iter().find_map(|env_var_name| {
103 if let Ok(value) = std::env::var(env_var_name) {
104 if !value.is_empty() {
105 return Some(env_var_name.clone());
106 }
107 }
108 None
109 });
110
111 let is_authenticated = if env_var_name_used.is_some() {
112 true
113 } else {
114 is_authenticated
115 };
116
117 let state = cx.new(|_| ExtensionLlmProviderState {
118 is_authenticated,
119 available_models: models,
120 allowed_env_vars,
121 env_var_name_used,
122 });
123
124 Self {
125 extension,
126 provider_info,
127 icon_path,
128 auth_config,
129 state,
130 }
131 }
132
133 fn provider_id_string(&self) -> String {
134 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
135 }
136
137 /// The credential key used for storing the API key in the system keychain.
138 fn credential_key(&self) -> String {
139 format!("extension-llm-{}", self.provider_id_string())
140 }
141}
142
143impl LanguageModelProvider for ExtensionLanguageModelProvider {
144 fn id(&self) -> LanguageModelProviderId {
145 LanguageModelProviderId::from(self.provider_id_string())
146 }
147
148 fn name(&self) -> LanguageModelProviderName {
149 LanguageModelProviderName::from(self.provider_info.name.clone())
150 }
151
152 fn icon(&self) -> ui::IconName {
153 ui::IconName::ZedAssistant
154 }
155
156 fn icon_path(&self) -> Option<SharedString> {
157 self.icon_path.clone()
158 }
159
160 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
161 let state = self.state.read(cx);
162 state
163 .available_models
164 .iter()
165 .find(|m| m.is_default)
166 .or_else(|| state.available_models.first())
167 .map(|model_info| {
168 Arc::new(ExtensionLanguageModel {
169 extension: self.extension.clone(),
170 model_info: model_info.clone(),
171 provider_id: self.id(),
172 provider_name: self.name(),
173 provider_info: self.provider_info.clone(),
174 }) as Arc<dyn LanguageModel>
175 })
176 }
177
178 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
179 let state = self.state.read(cx);
180 state
181 .available_models
182 .iter()
183 .find(|m| m.is_default_fast)
184 .map(|model_info| {
185 Arc::new(ExtensionLanguageModel {
186 extension: self.extension.clone(),
187 model_info: model_info.clone(),
188 provider_id: self.id(),
189 provider_name: self.name(),
190 provider_info: self.provider_info.clone(),
191 }) as Arc<dyn LanguageModel>
192 })
193 }
194
195 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
196 let state = self.state.read(cx);
197 state
198 .available_models
199 .iter()
200 .map(|model_info| {
201 Arc::new(ExtensionLanguageModel {
202 extension: self.extension.clone(),
203 model_info: model_info.clone(),
204 provider_id: self.id(),
205 provider_name: self.name(),
206 provider_info: self.provider_info.clone(),
207 }) as Arc<dyn LanguageModel>
208 })
209 .collect()
210 }
211
212 fn is_authenticated(&self, cx: &App) -> bool {
213 // First check cached state
214 if self.state.read(cx).is_authenticated {
215 return true;
216 }
217
218 // Also check env var dynamically (in case settings changed after provider creation)
219 if let Some(ref auth_config) = self.auth_config {
220 if let Some(ref env_vars) = auth_config.env_vars {
221 let provider_id_string = self.provider_id_string();
222 let settings = ExtensionSettings::get_global(cx);
223 let is_legacy_extension =
224 LEGACY_LLM_EXTENSION_IDS.contains(&self.extension.manifest.id.as_ref());
225
226 for env_var_name in env_vars {
227 let key = format!("{}:{}", provider_id_string, env_var_name);
228 // For legacy extensions, auto-allow if env var is set
229 let env_var_is_set = std::env::var(env_var_name)
230 .map(|v| !v.is_empty())
231 .unwrap_or(false);
232 if settings.allowed_env_var_providers.contains(key.as_str())
233 || (is_legacy_extension && env_var_is_set)
234 {
235 if let Ok(value) = std::env::var(env_var_name) {
236 if !value.is_empty() {
237 return true;
238 }
239 }
240 }
241 }
242 }
243 }
244
245 false
246 }
247
248 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
249 // Check if already authenticated via is_authenticated
250 if self.is_authenticated(cx) {
251 return Task::ready(Ok(()));
252 }
253
254 // Not authenticated - return error indicating credentials not found
255 Task::ready(Err(AuthenticateError::CredentialsNotFound))
256 }
257
258 fn configuration_view(
259 &self,
260 target_agent: ConfigurationViewTargetAgent,
261 window: &mut Window,
262 cx: &mut App,
263 ) -> AnyView {
264 let credential_key = self.credential_key();
265 let extension = self.extension.clone();
266 let extension_provider_id = self.provider_info.id.clone();
267 let full_provider_id = self.provider_id_string();
268 let state = self.state.clone();
269 let auth_config = self.auth_config.clone();
270
271 let icon_path = self.icon_path.clone();
272 cx.new(|cx| {
273 ExtensionProviderConfigurationView::new(
274 credential_key,
275 extension,
276 extension_provider_id,
277 full_provider_id,
278 auth_config,
279 state,
280 icon_path,
281 target_agent,
282 window,
283 cx,
284 )
285 })
286 .into()
287 }
288
289 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
290 let extension = self.extension.clone();
291 let provider_id = self.provider_info.id.clone();
292 let state = self.state.clone();
293 let credential_key = self.credential_key();
294
295 let credentials_provider = <dyn CredentialsProvider>::global(cx);
296
297 cx.spawn(async move |cx| {
298 // Delete from system keychain
299 credentials_provider
300 .delete_credentials(&credential_key, cx)
301 .await
302 .log_err();
303
304 // Call extension's reset_credentials
305 let result = extension
306 .call(|extension, store| {
307 async move {
308 extension
309 .call_llm_provider_reset_credentials(store, &provider_id)
310 .await
311 }
312 .boxed()
313 })
314 .await;
315
316 // Update state
317 cx.update(|cx| {
318 state.update(cx, |state, _| {
319 state.is_authenticated = false;
320 });
321 })?;
322
323 match result {
324 Ok(Ok(Ok(()))) => Ok(()),
325 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
326 Ok(Err(e)) => Err(e),
327 Err(e) => Err(e),
328 }
329 })
330 }
331}
332
333impl LanguageModelProviderState for ExtensionLanguageModelProvider {
334 type ObservableEntity = ExtensionLlmProviderState;
335
336 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
337 Some(self.state.clone())
338 }
339
340 fn subscribe<T: 'static>(
341 &self,
342 cx: &mut Context<T>,
343 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
344 ) -> Option<Subscription> {
345 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
346 }
347}
348
/// Configuration view for extension-based LLM providers.
///
/// Renders the extension's settings markdown, env-var permission checkboxes,
/// API-key entry and/or OAuth sign-in, and writes results into the shared
/// provider state.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is stored.
    credential_key: String,
    /// The extension implementing the provider.
    extension: WasmExtension,
    /// Provider id as known to the extension (without the extension-id prefix).
    extension_provider_id: String,
    /// `"<extension_id>:<provider_id>"`, used to build env-var permission keys in settings.
    full_provider_id: String,
    /// Declared authentication options, if any.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Provider state shared with `ExtensionLanguageModelProvider`.
    state: Entity<ExtensionLlmProviderState>,
    /// Optional markdown from the extension, rendered before the auth controls.
    settings_markdown: Option<Entity<Markdown>>,
    /// Input for entering an API key.
    api_key_editor: Entity<InputField>,
    /// True until the settings markdown request finishes.
    loading_settings: bool,
    /// True until the stored-credential check finishes.
    loading_credentials: bool,
    /// True while an OAuth device-flow sign-in is running.
    oauth_in_progress: bool,
    /// Last OAuth error message, if any.
    oauth_error: Option<String>,
    /// Icon passed to the OAuth device-flow prompt.
    icon_path: Option<SharedString>,
    /// Which surface the view is rendered for (e.g. edit prediction).
    target_agent: ConfigurationViewTargetAgent,
    /// Keeps the state-change subscription alive for the view's lifetime.
    _subscriptions: Vec<Subscription>,
}
367
368impl ExtensionProviderConfigurationView {
369 fn new(
370 credential_key: String,
371 extension: WasmExtension,
372 extension_provider_id: String,
373 full_provider_id: String,
374 auth_config: Option<LanguageModelAuthConfig>,
375 state: Entity<ExtensionLlmProviderState>,
376 icon_path: Option<SharedString>,
377 target_agent: ConfigurationViewTargetAgent,
378 window: &mut Window,
379 cx: &mut Context<Self>,
380 ) -> Self {
381 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
382 cx.notify();
383 });
384
385 let credential_label = auth_config
386 .as_ref()
387 .and_then(|c| c.credential_label.clone())
388 .unwrap_or_else(|| "API Key".to_string());
389
390 let api_key_editor = cx.new(|cx| {
391 InputField::new(window, cx, "Enter API key and hit enter").label(credential_label)
392 });
393
394 let mut this = Self {
395 credential_key,
396 extension,
397 extension_provider_id,
398 full_provider_id,
399 auth_config,
400 state,
401 settings_markdown: None,
402 api_key_editor,
403 loading_settings: true,
404 loading_credentials: true,
405 oauth_in_progress: false,
406 oauth_error: None,
407 icon_path,
408 target_agent,
409 _subscriptions: vec![state_subscription],
410 };
411
412 this.load_settings_text(cx);
413 this.load_credentials(cx);
414 this
415 }
416
417 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
418 let extension = self.extension.clone();
419 let provider_id = self.extension_provider_id.clone();
420
421 cx.spawn(async move |this, cx| {
422 let result = extension
423 .call({
424 let provider_id = provider_id.clone();
425 |ext, store| {
426 async move {
427 ext.call_llm_provider_settings_markdown(store, &provider_id)
428 .await
429 }
430 .boxed()
431 }
432 })
433 .await;
434
435 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
436
437 this.update(cx, |this, cx| {
438 this.loading_settings = false;
439 if let Some(text) = settings_text {
440 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
441 this.settings_markdown = Some(markdown);
442 }
443 cx.notify();
444 })
445 .log_err();
446 })
447 .detach();
448 }
449
450 fn load_credentials(&mut self, cx: &mut Context<Self>) {
451 let credential_key = self.credential_key.clone();
452 let credentials_provider = <dyn CredentialsProvider>::global(cx);
453 let state = self.state.clone();
454
455 // Check if we should use env var (already set in state during provider construction)
456 let using_env_var = self.state.read(cx).env_var_name_used.is_some();
457
458 cx.spawn(async move |this, cx| {
459 // If using env var, we're already authenticated
460 if using_env_var {
461 this.update(cx, |this, cx| {
462 this.loading_credentials = false;
463 cx.notify();
464 })
465 .log_err();
466 return;
467 }
468
469 let credentials = credentials_provider
470 .read_credentials(&credential_key, cx)
471 .await
472 .log_err()
473 .flatten();
474
475 let has_credentials = credentials.is_some();
476
477 // Update authentication state based on stored credentials
478 cx.update(|cx| {
479 state.update(cx, |state, cx| {
480 state.is_authenticated = has_credentials;
481 cx.notify();
482 });
483 })
484 .log_err();
485
486 this.update(cx, |this, cx| {
487 this.loading_credentials = false;
488 cx.notify();
489 })
490 .log_err();
491 })
492 .detach();
493 }
494
495 fn toggle_env_var_permission(&mut self, env_var_name: String, cx: &mut Context<Self>) {
496 let full_provider_id = self.full_provider_id.clone();
497 let settings_key: Arc<str> = format!("{}:{}", full_provider_id, env_var_name).into();
498
499 let state = self.state.clone();
500 let currently_allowed = self.state.read(cx).allowed_env_vars.contains(&env_var_name);
501
502 // Update settings file
503 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
504 move |settings, _| {
505 let allowed = settings
506 .extension
507 .allowed_env_var_providers
508 .get_or_insert_with(Vec::new);
509
510 if currently_allowed {
511 allowed.retain(|id| id.as_ref() != settings_key.as_ref());
512 } else {
513 if !allowed
514 .iter()
515 .any(|id| id.as_ref() == settings_key.as_ref())
516 {
517 allowed.push(settings_key.clone());
518 }
519 }
520 }
521 });
522
523 // Update local state
524 let new_allowed = !currently_allowed;
525
526 state.update(cx, |state, cx| {
527 if new_allowed {
528 state.allowed_env_vars.insert(env_var_name.clone());
529 // Check if this env var is set and update env_var_name_used
530 if let Ok(value) = std::env::var(&env_var_name) {
531 if !value.is_empty() && state.env_var_name_used.is_none() {
532 state.env_var_name_used = Some(env_var_name.clone());
533 state.is_authenticated = true;
534 }
535 }
536 } else {
537 state.allowed_env_vars.remove(&env_var_name);
538 // If this was the env var being used, clear it and find another
539 if state.env_var_name_used.as_ref() == Some(&env_var_name) {
540 state.env_var_name_used = state.allowed_env_vars.iter().find_map(|var| {
541 if let Ok(value) = std::env::var(var) {
542 if !value.is_empty() {
543 return Some(var.clone());
544 }
545 }
546 None
547 });
548 if state.env_var_name_used.is_none() {
549 // No env var auth available, need to check keychain
550 state.is_authenticated = false;
551 }
552 }
553 }
554 cx.notify();
555 });
556
557 // If all env vars are being disabled, reload credentials from keychain
558 if !new_allowed && self.state.read(cx).allowed_env_vars.is_empty() {
559 self.reload_keychain_credentials(cx);
560 }
561
562 cx.notify();
563 }
564
565 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
566 let credential_key = self.credential_key.clone();
567 let credentials_provider = <dyn CredentialsProvider>::global(cx);
568 let state = self.state.clone();
569
570 cx.spawn(async move |_this, cx| {
571 let credentials = credentials_provider
572 .read_credentials(&credential_key, cx)
573 .await
574 .log_err()
575 .flatten();
576
577 let has_credentials = credentials.is_some();
578
579 cx.update(|cx| {
580 state.update(cx, |state, cx| {
581 state.is_authenticated = has_credentials;
582 cx.notify();
583 });
584 })
585 .log_err();
586 })
587 .detach();
588 }
589
590 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
591 let api_key = self.api_key_editor.read(cx).text(cx);
592 if api_key.is_empty() {
593 return;
594 }
595
596 // Clear the editor
597 self.api_key_editor
598 .update(cx, |input, cx| input.clear(window, cx));
599
600 let credential_key = self.credential_key.clone();
601 let credentials_provider = <dyn CredentialsProvider>::global(cx);
602 let state = self.state.clone();
603
604 cx.spawn(async move |_this, cx| {
605 // Store in system keychain
606 credentials_provider
607 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
608 .await
609 .log_err();
610
611 // Update state to authenticated
612 cx.update(|cx| {
613 state.update(cx, |state, cx| {
614 state.is_authenticated = true;
615 cx.notify();
616 });
617 })
618 .log_err();
619 })
620 .detach();
621 }
622
623 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
624 // Clear the editor
625 self.api_key_editor
626 .update(cx, |input, cx| input.clear(window, cx));
627
628 let credential_key = self.credential_key.clone();
629 let credentials_provider = <dyn CredentialsProvider>::global(cx);
630 let state = self.state.clone();
631
632 cx.spawn(async move |_this, cx| {
633 // Delete from system keychain
634 credentials_provider
635 .delete_credentials(&credential_key, cx)
636 .await
637 .log_err();
638
639 // Update state to unauthenticated
640 cx.update(|cx| {
641 state.update(cx, |state, cx| {
642 state.is_authenticated = false;
643 cx.notify();
644 });
645 })
646 .log_err();
647 })
648 .detach();
649 }
650
651 fn start_oauth_sign_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
652 if self.oauth_in_progress {
653 return;
654 }
655
656 self.oauth_in_progress = true;
657 self.oauth_error = None;
658 cx.notify();
659
660 let extension = self.extension.clone();
661 let provider_id = self.extension_provider_id.clone();
662 let state = self.state.clone();
663 let icon_path = self.icon_path.clone();
664 let this_handle = cx.weak_entity();
665 let use_popup_window = self.is_edit_prediction_mode();
666
667 // Get current window bounds for positioning popup
668 let current_window_center = window.bounds().center();
669
670 // For workspace modal mode, find the workspace window
671 let workspace_window = if !use_popup_window {
672 log::info!("OAuth: Looking for workspace window");
673 let ws = window.window_handle().downcast::<Workspace>().or_else(|| {
674 log::info!("OAuth: Current window is not a workspace, searching other windows");
675 cx.windows()
676 .into_iter()
677 .find_map(|window_handle| window_handle.downcast::<Workspace>())
678 });
679
680 if ws.is_none() {
681 log::error!("OAuth: Could not find any workspace window");
682 self.oauth_in_progress = false;
683 self.oauth_error =
684 Some("Could not access workspace to show sign-in modal".to_string());
685 cx.notify();
686 return;
687 }
688 ws
689 } else {
690 None
691 };
692
693 log::info!(
694 "OAuth: Using {} mode",
695 if use_popup_window {
696 "popup window"
697 } else {
698 "workspace modal"
699 }
700 );
701 let state = state.downgrade();
702 cx.spawn(async move |_this, cx| {
703 // Step 1: Start device flow - get prompt info from extension
704 let start_result = extension
705 .call({
706 let provider_id = provider_id.clone();
707 |ext, store| {
708 async move {
709 ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
710 .await
711 }
712 .boxed()
713 }
714 })
715 .await;
716
717 log::info!(
718 "OAuth: Device flow start result: {:?}",
719 start_result.is_ok()
720 );
721 let prompt_info: LlmDeviceFlowPromptInfo = match start_result {
722 Ok(Ok(Ok(info))) => {
723 log::info!(
724 "OAuth: Got device flow prompt info, user_code: {}",
725 info.user_code
726 );
727 info
728 }
729 Ok(Ok(Err(e))) => {
730 log::error!("OAuth: Device flow start failed: {}", e);
731 this_handle
732 .update(cx, |this, cx| {
733 this.oauth_in_progress = false;
734 this.oauth_error = Some(e);
735 cx.notify();
736 })
737 .log_err();
738 return;
739 }
740 Ok(Err(e)) | Err(e) => {
741 log::error!("OAuth: Device flow start error: {}", e);
742 this_handle
743 .update(cx, |this, cx| {
744 this.oauth_in_progress = false;
745 this.oauth_error = Some(e.to_string());
746 cx.notify();
747 })
748 .log_err();
749 return;
750 }
751 };
752
753 // Step 2: Create state entity and show the modal/window
754 let modal_config = OAuthDeviceFlowModalConfig {
755 user_code: prompt_info.user_code,
756 verification_url: prompt_info.verification_url,
757 headline: prompt_info.headline,
758 description: prompt_info.description,
759 connect_button_label: prompt_info.connect_button_label,
760 success_headline: prompt_info.success_headline,
761 success_message: prompt_info.success_message,
762 icon_path,
763 };
764
765 let flow_state: Option<Entity<OAuthDeviceFlowState>> = if use_popup_window {
766 // Open a popup window like Copilot does
767 log::info!("OAuth: Opening popup window");
768 cx.update(|cx| {
769 let height = px(450.);
770 let width = px(350.);
771 let window_bounds = WindowBounds::Windowed(gpui::bounds(
772 current_window_center - point(height / 2.0, width / 2.0),
773 gpui::size(height, width),
774 ));
775
776 let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config.clone()));
777 let flow_state_for_window = flow_state.clone();
778
779 cx.open_window(
780 WindowOptions {
781 kind: gpui::WindowKind::PopUp,
782 window_bounds: Some(window_bounds),
783 is_resizable: false,
784 is_movable: true,
785 titlebar: Some(gpui::TitlebarOptions {
786 appears_transparent: true,
787 ..Default::default()
788 }),
789 ..Default::default()
790 },
791 |window, cx| {
792 cx.new(|cx| {
793 OAuthCodeVerificationWindow::new(
794 modal_config,
795 flow_state_for_window,
796 window,
797 cx,
798 )
799 })
800 },
801 )
802 .log_err();
803
804 Some(flow_state)
805 })
806 .ok()
807 .flatten()
808 } else {
809 // Use workspace modal
810 log::info!("OAuth: Attempting to show modal in workspace window");
811 workspace_window.as_ref().and_then(|ws| {
812 ws.update(cx, |workspace, window, cx| {
813 log::info!("OAuth: Inside workspace.update, creating modal");
814 window.activate_window();
815 let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config));
816 let flow_state_clone = flow_state.clone();
817 workspace.toggle_modal(window, cx, |_window, cx| {
818 log::info!("OAuth: Inside toggle_modal callback");
819 OAuthDeviceFlowModal::new(flow_state_clone, cx)
820 });
821 flow_state
822 })
823 .ok()
824 })
825 };
826
827 log::info!("OAuth: flow_state created: {:?}", flow_state.is_some());
828 let Some(flow_state) = flow_state else {
829 log::error!("OAuth: Failed to show sign-in modal/window");
830 this_handle
831 .update(cx, |this, cx| {
832 this.oauth_in_progress = false;
833 this.oauth_error = Some("Failed to show sign-in modal".to_string());
834 cx.notify();
835 })
836 .log_err();
837 return;
838 };
839 log::info!("OAuth: Modal/window shown successfully, starting poll");
840
841 // Step 3: Poll for authentication completion
842 let poll_result = extension
843 .call({
844 let provider_id = provider_id.clone();
845 |ext, store| {
846 async move {
847 ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
848 .await
849 }
850 .boxed()
851 }
852 })
853 .await;
854
855 match poll_result {
856 Ok(Ok(Ok(()))) => {
857 // After successful auth, refresh the models list
858 let models_result = extension
859 .call({
860 let provider_id = provider_id.clone();
861 |ext, store| {
862 async move {
863 ext.call_llm_provider_models(store, &provider_id).await
864 }
865 .boxed()
866 }
867 })
868 .await;
869
870 let new_models: Vec<LlmModelInfo> = match models_result {
871 Ok(Ok(Ok(models))) => models,
872 _ => Vec::new(),
873 };
874
875 state
876 .update(cx, |state, cx| {
877 state.is_authenticated = true;
878 state.available_models = new_models;
879 cx.notify();
880 })
881 .log_err();
882
883 // Update flow state to show success
884 flow_state
885 .update(cx, |state, cx| {
886 state.set_status(OAuthDeviceFlowStatus::Authorized, cx);
887 })
888 .log_err();
889 }
890 Ok(Ok(Err(e))) => {
891 log::error!("Device flow poll failed: {}", e);
892 flow_state
893 .update(cx, |state, cx| {
894 state.set_status(OAuthDeviceFlowStatus::Failed(e.clone()), cx);
895 })
896 .log_err();
897 this_handle
898 .update(cx, |this, cx| {
899 this.oauth_error = Some(e);
900 cx.notify();
901 })
902 .log_err();
903 }
904 Ok(Err(e)) | Err(e) => {
905 log::error!("Device flow poll error: {}", e);
906 let error_string = e.to_string();
907 flow_state
908 .update(cx, |state, cx| {
909 state.set_status(
910 OAuthDeviceFlowStatus::Failed(error_string.clone()),
911 cx,
912 );
913 })
914 .log_err();
915 this_handle
916 .update(cx, |this, cx| {
917 this.oauth_error = Some(error_string);
918 cx.notify();
919 })
920 .log_err();
921 }
922 };
923
924 this_handle
925 .update(cx, |this, cx| {
926 this.oauth_in_progress = false;
927 cx.notify();
928 })
929 .log_err();
930 })
931 .detach();
932 }
933
934 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
935 self.state.read(cx).is_authenticated
936 }
937
938 fn has_oauth_config(&self) -> bool {
939 self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
940 }
941
942 fn oauth_config(&self) -> Option<&OAuthConfig> {
943 self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
944 }
945
946 fn has_api_key_config(&self) -> bool {
947 // API key is available if there's a credential_label or no oauth-only config
948 self.auth_config
949 .as_ref()
950 .map(|c| c.credential_label.is_some() || c.oauth.is_none())
951 .unwrap_or(true)
952 }
953
954 fn is_edit_prediction_mode(&self) -> bool {
955 self.target_agent == ConfigurationViewTargetAgent::EditPrediction
956 }
957
958 fn render_for_edit_prediction(
959 &mut self,
960 _window: &mut Window,
961 cx: &mut Context<Self>,
962 ) -> impl IntoElement {
963 let is_loading = self.loading_settings || self.loading_credentials;
964 let is_authenticated = self.is_authenticated(cx);
965 let has_oauth = self.has_oauth_config();
966
967 // Helper to create the horizontal container layout matching Copilot
968 let container = |description: SharedString, action: AnyElement| {
969 h_flex()
970 .pt_2p5()
971 .w_full()
972 .justify_between()
973 .child(
974 v_flex()
975 .w_full()
976 .max_w_1_2()
977 .child(Label::new("Authenticate To Use"))
978 .child(
979 Label::new(description)
980 .color(Color::Muted)
981 .size(LabelSize::Small),
982 ),
983 )
984 .child(action)
985 };
986
987 // Get the description from OAuth config or use a default
988 let oauth_config = self.oauth_config();
989 let description: SharedString = oauth_config
990 .and_then(|c| c.sign_in_description.clone())
991 .unwrap_or_else(|| "Sign in to authenticate with this provider.".to_string())
992 .into();
993
994 if is_loading {
995 return container(
996 description,
997 Button::new("loading", "Loading...")
998 .style(ButtonStyle::Outlined)
999 .disabled(true)
1000 .into_any_element(),
1001 )
1002 .into_any_element();
1003 }
1004
1005 // If authenticated, show the configured card
1006 if is_authenticated {
1007 let (status_label, button_label) = if has_oauth {
1008 ("Authorized", "Sign Out")
1009 } else {
1010 ("API key configured", "Reset Key")
1011 };
1012
1013 return ConfiguredApiCard::new(status_label)
1014 .button_label(button_label)
1015 .on_click(cx.listener(|this, _, window, cx| {
1016 this.reset_api_key(window, cx);
1017 }))
1018 .into_any_element();
1019 }
1020
1021 // Not authenticated - show sign in button
1022 if has_oauth {
1023 let button_label = oauth_config
1024 .and_then(|c| c.sign_in_button_label.clone())
1025 .unwrap_or_else(|| "Sign In".to_string());
1026 let button_icon = oauth_config
1027 .and_then(|c| c.sign_in_button_icon.as_ref())
1028 .and_then(|icon_name| match icon_name.as_str() {
1029 "github" => Some(ui::IconName::Github),
1030 _ => None,
1031 });
1032
1033 let oauth_in_progress = self.oauth_in_progress;
1034
1035 let mut button = Button::new("oauth-sign-in", button_label)
1036 .size(ButtonSize::Medium)
1037 .style(ButtonStyle::Outlined)
1038 .disabled(oauth_in_progress)
1039 .on_click(cx.listener(|this, _, window, cx| {
1040 this.start_oauth_sign_in(window, cx);
1041 }));
1042
1043 if let Some(icon) = button_icon {
1044 button = button
1045 .icon(icon)
1046 .icon_position(ui::IconPosition::Start)
1047 .icon_size(ui::IconSize::Small)
1048 .icon_color(Color::Muted);
1049 }
1050
1051 return container(description, button.into_any_element()).into_any_element();
1052 }
1053
1054 // Fallback for API key only providers - show a simple message
1055 container(
1056 description,
1057 Button::new("configure", "Configure")
1058 .size(ButtonSize::Medium)
1059 .style(ButtonStyle::Outlined)
1060 .disabled(true)
1061 .into_any_element(),
1062 )
1063 .into_any_element()
1064 }
1065}
1066
impl Render for ExtensionProviderConfigurationView {
    /// Renders the provider's configuration UI.
    ///
    /// In edit-prediction mode this delegates to `render_for_edit_prediction`.
    /// Otherwise:
    /// - while settings or credentials are loading, only a spinner is shown;
    /// - the extension's settings markdown (if any) is shown first;
    /// - for each environment variable listed in the auth config, an opt-in
    ///   checkbox is shown, and if a key is already being read from one of
    ///   them a disabled "configured" card is shown and rendering stops;
    /// - if already authenticated, a card with a sign-out / reset button is
    ///   shown and rendering stops;
    /// - otherwise the available sign-in options are shown (OAuth button and/or
    ///   API key input). For the provider id "openai", a note about
    ///   OpenAI-compatible providers is appended in this last case.
    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        if self.is_edit_prediction_mode() {
            return self
                .render_for_edit_prediction(window, cx)
                .into_any_element();
        }

        let is_loading = self.loading_settings || self.loading_credentials;
        let is_authenticated = self.is_authenticated(cx);
        let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone();
        let env_var_name_used = self.state.read(cx).env_var_name_used.clone();
        let has_oauth = self.has_oauth_config();
        let has_api_key = self.has_api_key_config();

        // Nothing but a spinner until both settings and credentials are loaded.
        if is_loading {
            return h_flex()
                .gap_2()
                .child(
                    h_flex()
                        .w_2()
                        .child(SpinnerLabel::sand().size(LabelSize::Small)),
                )
                .child(LoadingLabel::new("Loading").size(LabelSize::Small))
                .into_any_element();
        }

        let mut content = v_flex().size_full().gap_2();

        // Extension-provided setup instructions, rendered as markdown.
        if let Some(markdown) = &self.settings_markdown {
            content = content.text_sm().child(MarkdownElement::new(
                markdown.clone(),
                markdown_styles(window, cx),
            ));
        }

        if let Some(auth_config) = &self.auth_config {
            if let Some(env_vars) = &auth_config.env_vars {
                // One opt-in checkbox per environment variable the extension may read
                // its API key from; clicking toggles that variable's permission.
                for env_var_name in env_vars {
                    let is_allowed = allowed_env_vars.contains(env_var_name);
                    let checkbox_label =
                        format!("Read API key from {} environment variable.", env_var_name);
                    let env_var_for_click = env_var_name.clone();

                    content = content.child(
                        Checkbox::new(
                            SharedString::from(format!("env-var-{}", env_var_name)),
                            if is_allowed {
                                ToggleState::Selected
                            } else {
                                ToggleState::Unselected
                            },
                        )
                        .label(checkbox_label)
                        .on_click(cx.listener(
                            move |this, _, _window, cx| {
                                this.toggle_env_var_permission(env_var_for_click.clone(), cx);
                            },
                        )),
                    );
                }

                // A key is already coming from an environment variable: show a
                // disabled card whose tooltip explains how to reset it, and stop here.
                if let Some(used_var) = &env_var_name_used {
                    content = content.child(
                        ConfiguredApiCard::new(format!(
                            "API key set in {} environment variable",
                            used_var
                        ))
                        .tooltip_label(format!(
                            "To reset this API key, unset the {} environment variable.",
                            used_var
                        ))
                        .disabled(true),
                    );

                    return content.into_any_element();
                }
            }
        }

        // Authenticated without an environment variable. OAuth-only providers
        // read "Signed in"; any provider with an API key option reads "API key
        // configured". Either button clears credentials via `reset_api_key`.
        if is_authenticated && env_var_name_used.is_none() {
            let (status_label, button_label) = if has_oauth && !has_api_key {
                ("Signed in", "Sign Out")
            } else {
                ("API key configured", "Reset Key")
            };

            content = content.child(
                ConfiguredApiCard::new(status_label)
                    .button_label(button_label)
                    .on_click(cx.listener(|this, _, window, cx| {
                        this.reset_api_key(window, cx);
                    })),
            );

            return content.into_any_element();
        }

        // Not authenticated - show available auth options
        if env_var_name_used.is_none() {
            // Render OAuth sign-in button if configured
            if has_oauth {
                // The label and icon come from the extension's OAuth config; only the
                // "github" icon name maps to a built-in icon, and others are ignored.
                let oauth_config = self.oauth_config();
                let button_label = oauth_config
                    .and_then(|c| c.sign_in_button_label.clone())
                    .unwrap_or_else(|| "Sign In".to_string());
                let button_icon = oauth_config
                    .and_then(|c| c.sign_in_button_icon.as_ref())
                    .and_then(|icon_name| match icon_name.as_str() {
                        "github" => Some(ui::IconName::Github),
                        _ => None,
                    });

                let oauth_in_progress = self.oauth_in_progress;

                let oauth_error = self.oauth_error.clone();

                // Disabled while a sign-in is in progress so it cannot be started twice.
                let mut button = Button::new("oauth-sign-in", button_label)
                    .full_width()
                    .style(ButtonStyle::Outlined)
                    .disabled(oauth_in_progress)
                    .on_click(cx.listener(|this, _, window, cx| {
                        this.start_oauth_sign_in(window, cx);
                    }));
                if let Some(icon) = button_icon {
                    button = button
                        .icon(icon)
                        .icon_position(IconPosition::Start)
                        .icon_size(IconSize::Small)
                        .icon_color(Color::Muted);
                }

                // Button, then an optional progress note, then an optional error
                // block (warning headline plus the error text).
                content = content.child(
                    v_flex()
                        .gap_2()
                        .child(button)
                        .when(oauth_in_progress, |this| {
                            this.child(
                                Label::new("Sign-in in progress...")
                                    .size(LabelSize::Small)
                                    .color(Color::Muted),
                            )
                        })
                        .when_some(oauth_error, |this, error| {
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .child(
                                        h_flex()
                                            .gap_2()
                                            .child(
                                                Icon::new(IconName::Warning)
                                                    .color(Color::Error)
                                                    .size(IconSize::Small),
                                            )
                                            .child(
                                                Label::new("Authentication failed")
                                                    .color(Color::Error)
                                                    .size(LabelSize::Small),
                                            ),
                                    )
                                    .child(
                                        div().pl_6().child(
                                            Label::new(error)
                                                .color(Color::Error)
                                                .size(LabelSize::Small),
                                        ),
                                    ),
                            )
                        }),
                );
            }

            // Render the API key input if configured; when OAuth is also offered,
            // separate the two options with an "or" divider.
            if has_api_key {
                if has_oauth {
                    content = content.child(
                        h_flex()
                            .gap_2()
                            .items_center()
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border_variant))
                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border_variant)),
                    );
                }

                // The `save_api_key` action is handled on the wrapper around the input.
                content = content.child(
                    div()
                        .on_action(cx.listener(Self::save_api_key))
                        .child(self.api_key_editor.clone()),
                );
            }
        }

        // The extension with provider id "openai" also points users at Zed's
        // built-in support for OpenAI-compatible providers.
        if self.extension_provider_id == "openai" {
            content = content.child(
                h_flex()
                    .gap_1()
                    .child(
                        Icon::new(IconName::Info)
                            .size(IconSize::XSmall)
                            .color(Color::Muted),
                    )
                    .child(
                        Label::new("Zed also supports OpenAI-compatible models.")
                            .size(LabelSize::Small)
                            .color(Color::Muted),
                    )
                    .child(
                        ButtonLink::new(
                            "Learn More",
                            "https://zed.dev/docs/configuring-llm-providers#openai-compatible-providers",
                        )
                        .label_size(LabelSize::Small),
                    ),
            );
        }

        content.into_any_element()
    }
}
1288
1289impl Focusable for ExtensionProviderConfigurationView {
1290 fn focus_handle(&self, cx: &App) -> FocusHandle {
1291 self.api_key_editor.read(cx).focus_handle(cx)
1292 }
1293}
1294
/// A popup window for OAuth device flow, similar to CopilotCodeVerification.
/// This is used in edit prediction mode so that the settings panel is not moved behind
/// another window while the user completes sign-in.
pub struct OAuthCodeVerificationWindow {
    /// Headlines, messages, user code, verification URL and optional icon shown in the window.
    config: OAuthDeviceFlowModalConfig,
    /// Latest device-flow status, copied from the observed `OAuthDeviceFlowState` entity.
    status: OAuthDeviceFlowStatus,
    /// Set once the connect button is clicked; switches that button's label to a waiting message.
    connect_clicked: bool,
    /// Focus handle tracked by the window's root element.
    focus_handle: FocusHandle,
    /// Keeps the observation of the device-flow state alive for as long as the view exists.
    _subscription: Option<Subscription>,
}
1304
1305impl Focusable for OAuthCodeVerificationWindow {
1306 fn focus_handle(&self, _: &App) -> FocusHandle {
1307 self.focus_handle.clone()
1308 }
1309}
1310
// The window emits `DismissEvent` from its Cancel/Done/Close buttons and the
// `menu::Cancel` action; its constructor subscribes to it and removes the window.
impl EventEmitter<DismissEvent> for OAuthCodeVerificationWindow {}
1312
1313impl OAuthCodeVerificationWindow {
1314 pub fn new(
1315 config: OAuthDeviceFlowModalConfig,
1316 state: Entity<OAuthDeviceFlowState>,
1317 window: &mut Window,
1318 cx: &mut Context<Self>,
1319 ) -> Self {
1320 window.on_window_should_close(cx, |window, cx| {
1321 if let Some(this) = window.root::<OAuthCodeVerificationWindow>().flatten() {
1322 this.update(cx, |_, cx| {
1323 cx.emit(DismissEvent);
1324 });
1325 }
1326 true
1327 });
1328 cx.subscribe_in(
1329 &cx.entity(),
1330 window,
1331 |_, _, _: &DismissEvent, window, _cx| {
1332 window.remove_window();
1333 },
1334 )
1335 .detach();
1336
1337 let subscription = cx.observe(&state, |this, state, cx| {
1338 let status = state.read(cx).status.clone();
1339 this.status = status;
1340 cx.notify();
1341 });
1342
1343 Self {
1344 config,
1345 status: state.read(cx).status.clone(),
1346 connect_clicked: false,
1347 focus_handle: cx.focus_handle(),
1348 _subscription: Some(subscription),
1349 }
1350 }
1351
1352 fn render_icon(&self, cx: &mut Context<Self>) -> impl IntoElement {
1353 let icon_color = Color::Custom(cx.theme().colors().icon);
1354 let icon_size = rems(2.5);
1355 let plus_size = rems(0.875);
1356 let plus_color = cx.theme().colors().icon.opacity(0.5);
1357
1358 if let Some(icon_path) = &self.config.icon_path {
1359 h_flex()
1360 .gap_2()
1361 .items_center()
1362 .child(
1363 Icon::from_external_svg(icon_path.clone())
1364 .size(IconSize::Custom(icon_size))
1365 .color(icon_color),
1366 )
1367 .child(
1368 gpui::svg()
1369 .size(plus_size)
1370 .path("icons/plus.svg")
1371 .text_color(plus_color),
1372 )
1373 .child(Vector::new(VectorName::ZedLogo, icon_size, icon_size).color(icon_color))
1374 .into_any_element()
1375 } else {
1376 Vector::new(VectorName::ZedLogo, icon_size, icon_size)
1377 .color(icon_color)
1378 .into_any_element()
1379 }
1380 }
1381
1382 fn render_device_code(&self, cx: &mut Context<Self>) -> impl IntoElement {
1383 let user_code = self.config.user_code.clone();
1384 let copied = cx
1385 .read_from_clipboard()
1386 .map(|item| item.text().as_ref() == Some(&user_code))
1387 .unwrap_or(false);
1388 let user_code_for_click = user_code.clone();
1389
1390 ButtonLike::new("copy-button")
1391 .full_width()
1392 .style(ButtonStyle::Tinted(ui::TintColor::Accent))
1393 .size(ButtonSize::Medium)
1394 .child(
1395 h_flex()
1396 .w_full()
1397 .p_1()
1398 .justify_between()
1399 .child(Label::new(user_code))
1400 .child(Label::new(if copied { "Copied!" } else { "Copy" })),
1401 )
1402 .on_click(move |_, window, cx| {
1403 cx.write_to_clipboard(ClipboardItem::new_string(user_code_for_click.clone()));
1404 window.refresh();
1405 })
1406 }
1407
1408 fn render_prompting_modal(&self, cx: &mut Context<Self>) -> impl IntoElement {
1409 let connect_button_label: String = if self.connect_clicked {
1410 "Waiting for connection…".to_string()
1411 } else {
1412 self.config.connect_button_label.clone()
1413 };
1414 let verification_url = self.config.verification_url.clone();
1415
1416 v_flex()
1417 .flex_1()
1418 .gap_2p5()
1419 .items_center()
1420 .text_center()
1421 .child(Headline::new(self.config.headline.clone()).size(HeadlineSize::Large))
1422 .child(Label::new(self.config.description.clone()).color(Color::Muted))
1423 .child(self.render_device_code(cx))
1424 .child(
1425 Label::new("Paste this code after clicking the button below.").color(Color::Muted),
1426 )
1427 .child(
1428 v_flex()
1429 .w_full()
1430 .gap_1()
1431 .child(
1432 Button::new("connect-button", connect_button_label)
1433 .full_width()
1434 .style(ButtonStyle::Outlined)
1435 .size(ButtonSize::Medium)
1436 .on_click(cx.listener(move |this, _, _window, cx| {
1437 cx.open_url(&verification_url);
1438 this.connect_clicked = true;
1439 })),
1440 )
1441 .child(
1442 Button::new("cancel-button", "Cancel")
1443 .full_width()
1444 .size(ButtonSize::Medium)
1445 .on_click(cx.listener(|_, _, _, cx| {
1446 cx.emit(DismissEvent);
1447 })),
1448 ),
1449 )
1450 }
1451
1452 fn render_authorized_modal(&self, cx: &mut Context<Self>) -> impl IntoElement {
1453 v_flex()
1454 .gap_2()
1455 .text_center()
1456 .justify_center()
1457 .child(Headline::new(self.config.success_headline.clone()).size(HeadlineSize::Large))
1458 .child(Label::new(self.config.success_message.clone()).color(Color::Muted))
1459 .child(
1460 Button::new("done-button", "Done")
1461 .full_width()
1462 .style(ButtonStyle::Outlined)
1463 .size(ButtonSize::Medium)
1464 .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
1465 )
1466 }
1467
1468 fn render_failed_modal(&self, error: &str, cx: &mut Context<Self>) -> impl IntoElement {
1469 v_flex()
1470 .gap_2()
1471 .text_center()
1472 .justify_center()
1473 .child(Headline::new("Authorization Failed").size(HeadlineSize::Large))
1474 .child(Label::new(error.to_string()).color(Color::Error))
1475 .child(
1476 Button::new("close-button", "Close")
1477 .full_width()
1478 .size(ButtonSize::Medium)
1479 .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
1480 )
1481 }
1482}
1483
1484impl Render for OAuthCodeVerificationWindow {
1485 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1486 let prompt = match &self.status {
1487 OAuthDeviceFlowStatus::Prompting | OAuthDeviceFlowStatus::WaitingForAuthorization => {
1488 self.render_prompting_modal(cx).into_any_element()
1489 }
1490 OAuthDeviceFlowStatus::Authorized => {
1491 self.render_authorized_modal(cx).into_any_element()
1492 }
1493 OAuthDeviceFlowStatus::Failed(error) => {
1494 self.render_failed_modal(error, cx).into_any_element()
1495 }
1496 };
1497
1498 v_flex()
1499 .id("oauth_code_verification")
1500 .track_focus(&self.focus_handle(cx))
1501 .size_full()
1502 .px_4()
1503 .py_8()
1504 .gap_2()
1505 .items_center()
1506 .justify_center()
1507 .elevation_3(cx)
1508 .on_action(cx.listener(|_, _: &menu::Cancel, _, cx| {
1509 cx.emit(DismissEvent);
1510 }))
1511 .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _| {
1512 window.focus(&this.focus_handle);
1513 }))
1514 .child(self.render_icon(cx))
1515 .child(prompt)
1516 }
1517}
1518
1519fn markdown_styles(window: &Window, cx: &App) -> MarkdownStyle {
1520 let settings = ThemeSettings::get_global(cx);
1521 let colors = cx.theme().colors();
1522
1523 let mut text_style = window.text_style();
1524 text_style.refine(&TextStyleRefinement {
1525 font_family: Some(settings.ui_font.family.clone()),
1526 font_fallbacks: settings.ui_font.fallbacks.clone(),
1527 font_features: Some(settings.ui_font.features.clone()),
1528 font_size: Some(settings.ui_font_size(cx).into()),
1529 line_height: Some(relative(1.5)),
1530 color: Some(colors.text_muted),
1531 ..Default::default()
1532 });
1533
1534 MarkdownStyle {
1535 base_text_style: text_style.clone(),
1536 syntax: cx.theme().syntax().clone(),
1537 selection_background_color: colors.element_selection_background,
1538 heading_level_styles: Some(HeadingLevelStyles {
1539 h1: Some(TextStyleRefinement {
1540 font_size: Some(rems(1.15).into()),
1541 ..Default::default()
1542 }),
1543 h2: Some(TextStyleRefinement {
1544 font_size: Some(rems(1.1).into()),
1545 ..Default::default()
1546 }),
1547 h3: Some(TextStyleRefinement {
1548 font_size: Some(rems(1.05).into()),
1549 ..Default::default()
1550 }),
1551 h4: Some(TextStyleRefinement {
1552 font_size: Some(rems(1.).into()),
1553 ..Default::default()
1554 }),
1555 h5: Some(TextStyleRefinement {
1556 font_size: Some(rems(0.95).into()),
1557 ..Default::default()
1558 }),
1559 h6: Some(TextStyleRefinement {
1560 font_size: Some(rems(0.875).into()),
1561 ..Default::default()
1562 }),
1563 }),
1564 inline_code: TextStyleRefinement {
1565 font_family: Some(settings.buffer_font.family.clone()),
1566 font_fallbacks: settings.buffer_font.fallbacks.clone(),
1567 font_features: Some(settings.buffer_font.features.clone()),
1568 font_size: Some(settings.buffer_font_size(cx).into()),
1569 background_color: Some(colors.editor_foreground.opacity(0.08)),
1570 ..Default::default()
1571 },
1572 link: TextStyleRefinement {
1573 background_color: Some(colors.editor_foreground.opacity(0.025)),
1574 color: Some(colors.text_accent),
1575 underline: Some(UnderlineStyle {
1576 color: Some(colors.text_accent.opacity(0.5)),
1577 thickness: px(1.),
1578 ..Default::default()
1579 }),
1580 ..Default::default()
1581 },
1582 ..Default::default()
1583 }
1584}
1585
/// An extension-based language model.
///
/// Wraps a single model advertised by an extension's LLM provider. Token
/// counting and completions are forwarded to the extension through `extension`.
pub struct ExtensionLanguageModel {
    /// Handle to the WASM extension that implements the model.
    extension: WasmExtension,
    /// Model metadata reported by the extension (id, name, capabilities, token limits).
    model_info: LlmModelInfo,
    /// Provider id reported to Zed through `LanguageModel::provider_id`.
    provider_id: LanguageModelProviderId,
    /// Provider display name reported through `LanguageModel::provider_name`.
    provider_name: LanguageModelProviderName,
    /// Provider metadata from the extension; its `id` is the provider identifier
    /// passed back to the extension on every call.
    provider_info: LlmProviderInfo,
}
1594
1595impl LanguageModel for ExtensionLanguageModel {
1596 fn id(&self) -> LanguageModelId {
1597 LanguageModelId::from(self.model_info.id.clone())
1598 }
1599
1600 fn name(&self) -> LanguageModelName {
1601 LanguageModelName::from(self.model_info.name.clone())
1602 }
1603
1604 fn provider_id(&self) -> LanguageModelProviderId {
1605 self.provider_id.clone()
1606 }
1607
1608 fn provider_name(&self) -> LanguageModelProviderName {
1609 self.provider_name.clone()
1610 }
1611
1612 fn telemetry_id(&self) -> String {
1613 format!("extension-{}", self.model_info.id)
1614 }
1615
1616 fn supports_images(&self) -> bool {
1617 self.model_info.capabilities.supports_images
1618 }
1619
1620 fn supports_tools(&self) -> bool {
1621 self.model_info.capabilities.supports_tools
1622 }
1623
1624 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1625 match choice {
1626 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1627 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1628 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1629 }
1630 }
1631
1632 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1633 match self.model_info.capabilities.tool_input_format {
1634 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1635 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1636 }
1637 }
1638
1639 fn max_token_count(&self) -> u64 {
1640 self.model_info.max_token_count
1641 }
1642
1643 fn max_output_tokens(&self) -> Option<u64> {
1644 self.model_info.max_output_tokens
1645 }
1646
1647 fn count_tokens(
1648 &self,
1649 request: LanguageModelRequest,
1650 cx: &App,
1651 ) -> BoxFuture<'static, Result<u64>> {
1652 let extension = self.extension.clone();
1653 let provider_id = self.provider_info.id.clone();
1654 let model_id = self.model_info.id.clone();
1655
1656 let wit_request = convert_request_to_wit(request);
1657
1658 cx.background_spawn(async move {
1659 extension
1660 .call({
1661 let provider_id = provider_id.clone();
1662 let model_id = model_id.clone();
1663 let wit_request = wit_request.clone();
1664 |ext, store| {
1665 async move {
1666 let count = ext
1667 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1668 .await?
1669 .map_err(|e| anyhow!("{}", e))?;
1670 Ok(count)
1671 }
1672 .boxed()
1673 }
1674 })
1675 .await?
1676 })
1677 .boxed()
1678 }
1679
1680 fn stream_completion(
1681 &self,
1682 request: LanguageModelRequest,
1683 _cx: &AsyncApp,
1684 ) -> BoxFuture<
1685 'static,
1686 Result<
1687 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1688 LanguageModelCompletionError,
1689 >,
1690 > {
1691 let extension = self.extension.clone();
1692 let provider_id = self.provider_info.id.clone();
1693 let model_id = self.model_info.id.clone();
1694
1695 let wit_request = convert_request_to_wit(request);
1696
1697 async move {
1698 // Start the stream
1699 let stream_id_result = extension
1700 .call({
1701 let provider_id = provider_id.clone();
1702 let model_id = model_id.clone();
1703 let wit_request = wit_request.clone();
1704 |ext, store| {
1705 async move {
1706 let id = ext
1707 .call_llm_stream_completion_start(
1708 store,
1709 &provider_id,
1710 &model_id,
1711 &wit_request,
1712 )
1713 .await?
1714 .map_err(|e| anyhow!("{}", e))?;
1715 Ok(id)
1716 }
1717 .boxed()
1718 }
1719 })
1720 .await;
1721
1722 let stream_id = stream_id_result
1723 .map_err(LanguageModelCompletionError::Other)?
1724 .map_err(LanguageModelCompletionError::Other)?;
1725
1726 // Create a stream that polls for events
1727 let stream = futures::stream::unfold(
1728 (extension.clone(), stream_id, false),
1729 move |(extension, stream_id, done)| async move {
1730 if done {
1731 return None;
1732 }
1733
1734 let result = extension
1735 .call({
1736 let stream_id = stream_id.clone();
1737 |ext, store| {
1738 async move {
1739 let event = ext
1740 .call_llm_stream_completion_next(store, &stream_id)
1741 .await?
1742 .map_err(|e| anyhow!("{}", e))?;
1743 Ok(event)
1744 }
1745 .boxed()
1746 }
1747 })
1748 .await
1749 .and_then(|inner| inner);
1750
1751 match result {
1752 Ok(Some(event)) => {
1753 let converted = convert_completion_event(event);
1754 let is_done =
1755 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1756 Some((converted, (extension, stream_id, is_done)))
1757 }
1758 Ok(None) => {
1759 // Stream complete, close it
1760 let _ = extension
1761 .call({
1762 let stream_id = stream_id.clone();
1763 |ext, store| {
1764 async move {
1765 ext.call_llm_stream_completion_close(store, &stream_id)
1766 .await?;
1767 Ok::<(), anyhow::Error>(())
1768 }
1769 .boxed()
1770 }
1771 })
1772 .await;
1773 None
1774 }
1775 Err(e) => Some((
1776 Err(LanguageModelCompletionError::Other(e)),
1777 (extension, stream_id, true),
1778 )),
1779 }
1780 },
1781 );
1782
1783 Ok(stream.boxed())
1784 }
1785 .boxed()
1786 }
1787
1788 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1789 // Extensions can implement this via llm_cache_configuration
1790 None
1791 }
1792}
1793
1794fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1795 use language_model::{MessageContent, Role};
1796
1797 let messages: Vec<LlmRequestMessage> = request
1798 .messages
1799 .into_iter()
1800 .map(|msg| {
1801 let role = match msg.role {
1802 Role::User => LlmMessageRole::User,
1803 Role::Assistant => LlmMessageRole::Assistant,
1804 Role::System => LlmMessageRole::System,
1805 };
1806
1807 let content: Vec<LlmMessageContent> = msg
1808 .content
1809 .into_iter()
1810 .map(|c| match c {
1811 MessageContent::Text(text) => LlmMessageContent::Text(text),
1812 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1813 source: image.source.to_string(),
1814 width: Some(image.size.width.0 as u32),
1815 height: Some(image.size.height.0 as u32),
1816 }),
1817 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1818 id: tool_use.id.to_string(),
1819 name: tool_use.name.to_string(),
1820 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1821 is_input_complete: tool_use.is_input_complete,
1822 thought_signature: tool_use.thought_signature,
1823 }),
1824 MessageContent::ToolResult(tool_result) => {
1825 let content = match tool_result.content {
1826 language_model::LanguageModelToolResultContent::Text(text) => {
1827 LlmToolResultContent::Text(text.to_string())
1828 }
1829 language_model::LanguageModelToolResultContent::Image(image) => {
1830 LlmToolResultContent::Image(LlmImageData {
1831 source: image.source.to_string(),
1832 width: Some(image.size.width.0 as u32),
1833 height: Some(image.size.height.0 as u32),
1834 })
1835 }
1836 };
1837 LlmMessageContent::ToolResult(LlmToolResult {
1838 tool_use_id: tool_result.tool_use_id.to_string(),
1839 tool_name: tool_result.tool_name.to_string(),
1840 is_error: tool_result.is_error,
1841 content,
1842 })
1843 }
1844 MessageContent::Thinking { text, signature } => {
1845 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1846 }
1847 MessageContent::RedactedThinking(data) => {
1848 LlmMessageContent::RedactedThinking(data)
1849 }
1850 })
1851 .collect();
1852
1853 LlmRequestMessage {
1854 role,
1855 content,
1856 cache: msg.cache,
1857 }
1858 })
1859 .collect();
1860
1861 let tools: Vec<LlmToolDefinition> = request
1862 .tools
1863 .into_iter()
1864 .map(|tool| LlmToolDefinition {
1865 name: tool.name,
1866 description: tool.description,
1867 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1868 })
1869 .collect();
1870
1871 let tool_choice = request.tool_choice.map(|tc| match tc {
1872 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1873 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1874 LanguageModelToolChoice::None => LlmToolChoice::None,
1875 });
1876
1877 LlmCompletionRequest {
1878 messages,
1879 tools,
1880 tool_choice,
1881 stop_sequences: request.stop,
1882 temperature: request.temperature,
1883 thinking_allowed: false,
1884 max_tokens: None,
1885 }
1886}
1887
1888fn convert_completion_event(
1889 event: LlmCompletionEvent,
1890) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1891 match event {
1892 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1893 message_id: String::new(),
1894 }),
1895 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1896 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1897 text: thinking.text,
1898 signature: thinking.signature,
1899 }),
1900 LlmCompletionEvent::RedactedThinking(data) => {
1901 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1902 }
1903 LlmCompletionEvent::ToolUse(tool_use) => {
1904 let raw_input = tool_use.input.clone();
1905 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1906 Ok(LanguageModelCompletionEvent::ToolUse(
1907 LanguageModelToolUse {
1908 id: LanguageModelToolUseId::from(tool_use.id),
1909 name: tool_use.name.into(),
1910 raw_input,
1911 input,
1912 is_input_complete: tool_use.is_input_complete,
1913 thought_signature: tool_use.thought_signature,
1914 },
1915 ))
1916 }
1917 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1918 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1919 id: LanguageModelToolUseId::from(error.id),
1920 tool_name: error.tool_name.into(),
1921 raw_input: error.raw_input.into(),
1922 json_parse_error: error.error,
1923 })
1924 }
1925 LlmCompletionEvent::Stop(reason) => {
1926 let stop_reason = match reason {
1927 LlmStopReason::EndTurn => StopReason::EndTurn,
1928 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1929 LlmStopReason::ToolUse => StopReason::ToolUse,
1930 LlmStopReason::Refusal => StopReason::Refusal,
1931 };
1932 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1933 }
1934 LlmCompletionEvent::Usage(usage) => {
1935 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1936 input_tokens: usage.input_tokens,
1937 output_tokens: usage.output_tokens,
1938 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1939 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1940 }))
1941 }
1942 LlmCompletionEvent::ReasoningDetails(json) => {
1943 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1944 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1945 ))
1946 }
1947 }
1948}