1use crate::ExtensionSettings;
2use crate::LEGACY_LLM_EXTENSION_IDS;
3use crate::wasm_host::WasmExtension;
4use crate::wasm_host::wit::LlmDeviceFlowPromptInfo;
5use collections::HashSet;
6
7use crate::wasm_host::wit::{
8 LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmImageData,
9 LlmMessageContent, LlmMessageRole, LlmModelInfo, LlmProviderInfo, LlmRequestMessage,
10 LlmStopReason, LlmThinkingContent, LlmToolChoice, LlmToolDefinition, LlmToolInputFormat,
11 LlmToolResult, LlmToolResultContent, LlmToolUse,
12};
13use anyhow::{Result, anyhow};
14use collections::HashMap;
15use credentials_provider::CredentialsProvider;
16use extension::{LanguageModelAuthConfig, OAuthConfig};
17use futures::future::BoxFuture;
18use futures::stream::BoxStream;
19use futures::{FutureExt, StreamExt};
20
21use gpui::{
22 AnyView, App, AsyncApp, ClipboardItem, DismissEvent, Entity, EventEmitter, FocusHandle,
23 Focusable, MouseDownEvent, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window,
24 WindowBounds, WindowOptions, point, prelude::*, px,
25};
26use language_model::tool_schema::LanguageModelToolSchemaFormat;
27use language_model::{
28 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
29 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
30 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
31 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
32 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, RateLimiter, StopReason,
33 TokenUsage,
34};
35use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
36use settings::Settings;
37use std::sync::Arc;
38use theme::ThemeSettings;
39use ui::{
40 ButtonLike, ButtonLink, Checkbox, ConfiguredApiCard, SpinnerLabel, ToggleState, Vector,
41 VectorName, prelude::*,
42};
43use ui_input::InputField;
44use util::ResultExt as _;
45use workspace::Workspace;
46use workspace::oauth_device_flow_modal::{
47 OAuthDeviceFlowModal, OAuthDeviceFlowModalConfig, OAuthDeviceFlowState, OAuthDeviceFlowStatus,
48};
49
/// An extension-based language model provider.
///
/// Adapts an LLM provider implemented by a WASM extension to the
/// [`LanguageModelProvider`] trait. Mutable, observable data (authentication
/// status, model list, env-var permissions) lives in the shared
/// [`ExtensionLlmProviderState`] entity.
pub struct ExtensionLanguageModelProvider {
    /// The WASM extension that implements this provider.
    pub extension: WasmExtension,
    /// Provider metadata; its `id` and `name` feed the provider id and display name.
    pub provider_info: LlmProviderInfo,
    /// Optional custom icon path, returned by `icon_path()`.
    icon_path: Option<SharedString>,
    /// Authentication configuration (env vars, credential label, OAuth), if any.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared observable state, also handed to configuration views.
    state: Entity<ExtensionLlmProviderState>,
}
58
/// Mutable state shared between an [`ExtensionLanguageModelProvider`] and its
/// configuration views. Changes are published with `cx.notify()`.
pub struct ExtensionLlmProviderState {
    /// Whether the provider is considered authenticated (set from an allowed env
    /// var, stored keychain credentials, or a completed OAuth flow).
    is_authenticated: bool,
    /// Models reported by the extension.
    available_models: Vec<LlmModelInfo>,
    /// Cache configurations for each model, keyed by model ID.
    cache_configs: HashMap<String, LlmCacheConfiguration>,
    /// Set of env var names that are allowed to be read for this provider.
    allowed_env_vars: HashSet<String>,
    /// If authenticated via env var, which one was used.
    env_var_name_used: Option<String>,
}

// NOTE(review): no `cx.emit(())` for this entity is visible in this file; state
// changes are signalled with `cx.notify()`, so listeners should `observe` it.
impl EventEmitter<()> for ExtensionLlmProviderState {}
71
72impl ExtensionLanguageModelProvider {
73 pub fn new(
74 extension: WasmExtension,
75 provider_info: LlmProviderInfo,
76 models: Vec<LlmModelInfo>,
77 cache_configs: HashMap<String, LlmCacheConfiguration>,
78 is_authenticated: bool,
79 icon_path: Option<SharedString>,
80 auth_config: Option<LanguageModelAuthConfig>,
81 cx: &mut App,
82 ) -> Self {
83 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
84
85 // Build set of allowed env vars for this provider
86 let settings = ExtensionSettings::get_global(cx);
87 let is_legacy_extension =
88 LEGACY_LLM_EXTENSION_IDS.contains(&extension.manifest.id.as_ref());
89
90 let mut allowed_env_vars = HashSet::default();
91 if let Some(env_vars) = auth_config.as_ref().and_then(|c| c.env_vars.as_ref()) {
92 for env_var_name in env_vars {
93 let key = format!("{}:{}", provider_id_string, env_var_name);
94 // For legacy extensions, auto-allow if env var is set (migration will persist this)
95 let env_var_is_set = std::env::var(env_var_name)
96 .map(|v| !v.is_empty())
97 .unwrap_or(false);
98 if settings.allowed_env_var_providers.contains(key.as_str())
99 || (is_legacy_extension && env_var_is_set)
100 {
101 allowed_env_vars.insert(env_var_name.clone());
102 }
103 }
104 }
105
106 // Check if any allowed env var is set
107 let env_var_name_used = allowed_env_vars.iter().find_map(|env_var_name| {
108 if let Ok(value) = std::env::var(env_var_name) {
109 if !value.is_empty() {
110 return Some(env_var_name.clone());
111 }
112 }
113 None
114 });
115
116 let is_authenticated = if env_var_name_used.is_some() {
117 true
118 } else {
119 is_authenticated
120 };
121
122 let state = cx.new(|_| ExtensionLlmProviderState {
123 is_authenticated,
124 available_models: models,
125 cache_configs,
126 allowed_env_vars,
127 env_var_name_used,
128 });
129
130 Self {
131 extension,
132 provider_info,
133 icon_path,
134 auth_config,
135 state,
136 }
137 }
138
139 fn provider_id_string(&self) -> String {
140 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
141 }
142
143 /// The credential key used for storing the API key in the system keychain.
144 fn credential_key(&self) -> String {
145 format!("extension-llm-{}", self.provider_id_string())
146 }
147
148 fn create_model(
149 &self,
150 model_info: &LlmModelInfo,
151 cache_configs: &HashMap<String, LlmCacheConfiguration>,
152 ) -> Arc<dyn LanguageModel> {
153 let cache_config =
154 cache_configs
155 .get(&model_info.id)
156 .map(|config| LanguageModelCacheConfiguration {
157 max_cache_anchors: config.max_cache_anchors as usize,
158 should_speculate: false,
159 min_total_token: config.min_total_token_count,
160 });
161
162 Arc::new(ExtensionLanguageModel {
163 extension: self.extension.clone(),
164 model_info: model_info.clone(),
165 provider_id: self.id(),
166 provider_name: self.name(),
167 provider_info: self.provider_info.clone(),
168 request_limiter: RateLimiter::new(4),
169 cache_config,
170 })
171 }
172}
173
174impl LanguageModelProvider for ExtensionLanguageModelProvider {
175 fn id(&self) -> LanguageModelProviderId {
176 LanguageModelProviderId::from(self.provider_id_string())
177 }
178
179 fn name(&self) -> LanguageModelProviderName {
180 LanguageModelProviderName::from(self.provider_info.name.clone())
181 }
182
183 fn icon(&self) -> ui::IconName {
184 ui::IconName::ZedAssistant
185 }
186
187 fn icon_path(&self) -> Option<SharedString> {
188 self.icon_path.clone()
189 }
190
191 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
192 let state = self.state.read(cx);
193 state
194 .available_models
195 .iter()
196 .find(|m| m.is_default)
197 .or_else(|| state.available_models.first())
198 .map(|model_info| self.create_model(model_info, &state.cache_configs))
199 }
200
201 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
202 let state = self.state.read(cx);
203 state
204 .available_models
205 .iter()
206 .find(|m| m.is_default_fast)
207 .map(|model_info| self.create_model(model_info, &state.cache_configs))
208 }
209
210 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
211 let state = self.state.read(cx);
212 state
213 .available_models
214 .iter()
215 .map(|model_info| self.create_model(model_info, &state.cache_configs))
216 .collect()
217 }
218
219 fn is_authenticated(&self, cx: &App) -> bool {
220 // First check cached state
221 if self.state.read(cx).is_authenticated {
222 return true;
223 }
224
225 // Also check env var dynamically (in case settings changed after provider creation)
226 if let Some(ref auth_config) = self.auth_config {
227 if let Some(ref env_vars) = auth_config.env_vars {
228 let provider_id_string = self.provider_id_string();
229 let settings = ExtensionSettings::get_global(cx);
230 let is_legacy_extension =
231 LEGACY_LLM_EXTENSION_IDS.contains(&self.extension.manifest.id.as_ref());
232
233 for env_var_name in env_vars {
234 let key = format!("{}:{}", provider_id_string, env_var_name);
235 // For legacy extensions, auto-allow if env var is set
236 let env_var_is_set = std::env::var(env_var_name)
237 .map(|v| !v.is_empty())
238 .unwrap_or(false);
239 if settings.allowed_env_var_providers.contains(key.as_str())
240 || (is_legacy_extension && env_var_is_set)
241 {
242 if let Ok(value) = std::env::var(env_var_name) {
243 if !value.is_empty() {
244 return true;
245 }
246 }
247 }
248 }
249 }
250 }
251
252 false
253 }
254
255 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
256 // Check if already authenticated via is_authenticated
257 if self.is_authenticated(cx) {
258 return Task::ready(Ok(()));
259 }
260
261 // Not authenticated - return error indicating credentials not found
262 Task::ready(Err(AuthenticateError::CredentialsNotFound))
263 }
264
265 fn configuration_view(
266 &self,
267 target_agent: ConfigurationViewTargetAgent,
268 window: &mut Window,
269 cx: &mut App,
270 ) -> AnyView {
271 let credential_key = self.credential_key();
272 let extension = self.extension.clone();
273 let extension_provider_id = self.provider_info.id.clone();
274 let full_provider_id = self.provider_id_string();
275 let state = self.state.clone();
276 let auth_config = self.auth_config.clone();
277
278 let icon_path = self.icon_path.clone();
279 cx.new(|cx| {
280 ExtensionProviderConfigurationView::new(
281 credential_key,
282 extension,
283 extension_provider_id,
284 full_provider_id,
285 auth_config,
286 state,
287 icon_path,
288 target_agent,
289 window,
290 cx,
291 )
292 })
293 .into()
294 }
295
296 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
297 let extension = self.extension.clone();
298 let provider_id = self.provider_info.id.clone();
299 let state = self.state.clone();
300 let credential_key = self.credential_key();
301
302 let credentials_provider = <dyn CredentialsProvider>::global(cx);
303
304 cx.spawn(async move |cx| {
305 // Delete from system keychain
306 credentials_provider
307 .delete_credentials(&credential_key, cx)
308 .await
309 .log_err();
310
311 // Call extension's reset_credentials
312 let result = extension
313 .call(|extension, store| {
314 async move {
315 extension
316 .call_llm_provider_reset_credentials(store, &provider_id)
317 .await
318 }
319 .boxed()
320 })
321 .await;
322
323 // Update state
324 cx.update(|cx| {
325 state.update(cx, |state, _| {
326 state.is_authenticated = false;
327 });
328 })?;
329
330 match result {
331 Ok(Ok(Ok(()))) => Ok(()),
332 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
333 Ok(Err(e)) => Err(e),
334 Err(e) => Err(e),
335 }
336 })
337 }
338}
339
340impl LanguageModelProviderState for ExtensionLanguageModelProvider {
341 type ObservableEntity = ExtensionLlmProviderState;
342
343 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
344 Some(self.state.clone())
345 }
346
347 fn subscribe<T: 'static>(
348 &self,
349 cx: &mut Context<T>,
350 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
351 ) -> Option<Subscription> {
352 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
353 }
354}
355
/// Configuration view for extension-based LLM providers.
///
/// Lets the user grant env-var access, enter/reset an API key, or run an OAuth
/// device-flow sign-in, depending on the provider's auth configuration.
struct ExtensionProviderConfigurationView {
    /// System keychain key under which the API key is stored.
    credential_key: String,
    /// The extension backing the provider.
    extension: WasmExtension,
    /// Provider id as passed to extension calls (no extension-id prefix).
    extension_provider_id: String,
    /// `"{extension_id}:{provider_id}"`, used to build env-var permission keys.
    full_provider_id: String,
    /// Authentication configuration (env vars, credential label, OAuth), if any.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared provider state (authentication, models, env-var permissions).
    state: Entity<ExtensionLlmProviderState>,
    /// Settings markdown returned by the extension, once loaded.
    settings_markdown: Option<Entity<Markdown>>,
    /// Input used to type an API key.
    api_key_editor: Entity<InputField>,
    /// True until the settings markdown request completes.
    loading_settings: bool,
    /// True until the initial credential check completes.
    loading_credentials: bool,
    /// Whether an OAuth device-flow sign-in is currently running.
    oauth_in_progress: bool,
    /// Most recent OAuth error message, if any.
    oauth_error: Option<String>,
    /// Optional provider icon, forwarded to the OAuth modal/window.
    icon_path: Option<SharedString>,
    /// Which surface the view is rendered for (selects layout and popup vs modal).
    target_agent: ConfigurationViewTargetAgent,
    /// Keeps the view's subscriptions alive.
    _subscriptions: Vec<Subscription>,
}
374
375impl ExtensionProviderConfigurationView {
376 fn new(
377 credential_key: String,
378 extension: WasmExtension,
379 extension_provider_id: String,
380 full_provider_id: String,
381 auth_config: Option<LanguageModelAuthConfig>,
382 state: Entity<ExtensionLlmProviderState>,
383 icon_path: Option<SharedString>,
384 target_agent: ConfigurationViewTargetAgent,
385 window: &mut Window,
386 cx: &mut Context<Self>,
387 ) -> Self {
388 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
389 cx.notify();
390 });
391
392 let credential_label = auth_config
393 .as_ref()
394 .and_then(|c| c.credential_label.clone())
395 .unwrap_or_else(|| "API Key".to_string());
396
397 let api_key_editor = cx.new(|cx| {
398 InputField::new(window, cx, "Enter API key and hit enter").label(credential_label)
399 });
400
401 let mut this = Self {
402 credential_key,
403 extension,
404 extension_provider_id,
405 full_provider_id,
406 auth_config,
407 state,
408 settings_markdown: None,
409 api_key_editor,
410 loading_settings: true,
411 loading_credentials: true,
412 oauth_in_progress: false,
413 oauth_error: None,
414 icon_path,
415 target_agent,
416 _subscriptions: vec![state_subscription],
417 };
418
419 this.load_settings_text(cx);
420 this.load_credentials(cx);
421 this
422 }
423
424 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
425 let extension = self.extension.clone();
426 let provider_id = self.extension_provider_id.clone();
427
428 cx.spawn(async move |this, cx| {
429 let result = extension
430 .call({
431 let provider_id = provider_id.clone();
432 |ext, store| {
433 async move {
434 ext.call_llm_provider_settings_markdown(store, &provider_id)
435 .await
436 }
437 .boxed()
438 }
439 })
440 .await;
441
442 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
443
444 this.update(cx, |this, cx| {
445 this.loading_settings = false;
446 if let Some(text) = settings_text {
447 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
448 this.settings_markdown = Some(markdown);
449 }
450 cx.notify();
451 })
452 .log_err();
453 })
454 .detach();
455 }
456
457 fn load_credentials(&mut self, cx: &mut Context<Self>) {
458 let credential_key = self.credential_key.clone();
459 let credentials_provider = <dyn CredentialsProvider>::global(cx);
460 let state = self.state.clone();
461
462 // Check if we should use env var (already set in state during provider construction)
463 let using_env_var = self.state.read(cx).env_var_name_used.is_some();
464
465 cx.spawn(async move |this, cx| {
466 // If using env var, we're already authenticated
467 if using_env_var {
468 this.update(cx, |this, cx| {
469 this.loading_credentials = false;
470 cx.notify();
471 })
472 .log_err();
473 return;
474 }
475
476 let credentials = credentials_provider
477 .read_credentials(&credential_key, cx)
478 .await
479 .log_err()
480 .flatten();
481
482 let has_credentials = credentials.is_some();
483
484 // Update authentication state based on stored credentials
485 cx.update(|cx| {
486 state.update(cx, |state, cx| {
487 state.is_authenticated = has_credentials;
488 cx.notify();
489 });
490 })
491 .log_err();
492
493 this.update(cx, |this, cx| {
494 this.loading_credentials = false;
495 cx.notify();
496 })
497 .log_err();
498 })
499 .detach();
500 }
501
502 fn toggle_env_var_permission(&mut self, env_var_name: String, cx: &mut Context<Self>) {
503 let full_provider_id = self.full_provider_id.clone();
504 let settings_key: Arc<str> = format!("{}:{}", full_provider_id, env_var_name).into();
505
506 let state = self.state.clone();
507 let currently_allowed = self.state.read(cx).allowed_env_vars.contains(&env_var_name);
508
509 // Update settings file
510 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
511 move |settings, _| {
512 let allowed = settings
513 .extension
514 .allowed_env_var_providers
515 .get_or_insert_with(Vec::new);
516
517 if currently_allowed {
518 allowed.retain(|id| id.as_ref() != settings_key.as_ref());
519 } else {
520 if !allowed
521 .iter()
522 .any(|id| id.as_ref() == settings_key.as_ref())
523 {
524 allowed.push(settings_key.clone());
525 }
526 }
527 }
528 });
529
530 // Update local state
531 let new_allowed = !currently_allowed;
532
533 state.update(cx, |state, cx| {
534 if new_allowed {
535 state.allowed_env_vars.insert(env_var_name.clone());
536 // Check if this env var is set and update env_var_name_used
537 if let Ok(value) = std::env::var(&env_var_name) {
538 if !value.is_empty() && state.env_var_name_used.is_none() {
539 state.env_var_name_used = Some(env_var_name.clone());
540 state.is_authenticated = true;
541 }
542 }
543 } else {
544 state.allowed_env_vars.remove(&env_var_name);
545 // If this was the env var being used, clear it and find another
546 if state.env_var_name_used.as_ref() == Some(&env_var_name) {
547 state.env_var_name_used = state.allowed_env_vars.iter().find_map(|var| {
548 if let Ok(value) = std::env::var(var) {
549 if !value.is_empty() {
550 return Some(var.clone());
551 }
552 }
553 None
554 });
555 if state.env_var_name_used.is_none() {
556 // No env var auth available, need to check keychain
557 state.is_authenticated = false;
558 }
559 }
560 }
561 cx.notify();
562 });
563
564 // If all env vars are being disabled, reload credentials from keychain
565 if !new_allowed && self.state.read(cx).allowed_env_vars.is_empty() {
566 self.reload_keychain_credentials(cx);
567 }
568
569 cx.notify();
570 }
571
572 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
573 let credential_key = self.credential_key.clone();
574 let credentials_provider = <dyn CredentialsProvider>::global(cx);
575 let state = self.state.clone();
576
577 cx.spawn(async move |_this, cx| {
578 let credentials = credentials_provider
579 .read_credentials(&credential_key, cx)
580 .await
581 .log_err()
582 .flatten();
583
584 let has_credentials = credentials.is_some();
585
586 cx.update(|cx| {
587 state.update(cx, |state, cx| {
588 state.is_authenticated = has_credentials;
589 cx.notify();
590 });
591 })
592 .log_err();
593 })
594 .detach();
595 }
596
597 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
598 let api_key = self.api_key_editor.read(cx).text(cx);
599 if api_key.is_empty() {
600 return;
601 }
602
603 // Clear the editor
604 self.api_key_editor
605 .update(cx, |input, cx| input.clear(window, cx));
606
607 let credential_key = self.credential_key.clone();
608 let credentials_provider = <dyn CredentialsProvider>::global(cx);
609 let state = self.state.clone();
610
611 cx.spawn(async move |_this, cx| {
612 // Store in system keychain
613 credentials_provider
614 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
615 .await
616 .log_err();
617
618 // Update state to authenticated
619 cx.update(|cx| {
620 state.update(cx, |state, cx| {
621 state.is_authenticated = true;
622 cx.notify();
623 });
624 })
625 .log_err();
626 })
627 .detach();
628 }
629
630 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
631 // Clear the editor
632 self.api_key_editor
633 .update(cx, |input, cx| input.clear(window, cx));
634
635 let credential_key = self.credential_key.clone();
636 let credentials_provider = <dyn CredentialsProvider>::global(cx);
637 let state = self.state.clone();
638
639 cx.spawn(async move |_this, cx| {
640 // Delete from system keychain
641 credentials_provider
642 .delete_credentials(&credential_key, cx)
643 .await
644 .log_err();
645
646 // Update state to unauthenticated
647 cx.update(|cx| {
648 state.update(cx, |state, cx| {
649 state.is_authenticated = false;
650 cx.notify();
651 });
652 })
653 .log_err();
654 })
655 .detach();
656 }
657
658 fn start_oauth_sign_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
659 if self.oauth_in_progress {
660 return;
661 }
662
663 self.oauth_in_progress = true;
664 self.oauth_error = None;
665 cx.notify();
666
667 let extension = self.extension.clone();
668 let provider_id = self.extension_provider_id.clone();
669 let state = self.state.clone();
670 let icon_path = self.icon_path.clone();
671 let this_handle = cx.weak_entity();
672 let use_popup_window = self.is_edit_prediction_mode();
673
674 // Get current window bounds for positioning popup
675 let current_window_center = window.bounds().center();
676
677 // For workspace modal mode, find the workspace window
678 let workspace_window = if !use_popup_window {
679 log::info!("OAuth: Looking for workspace window");
680 let ws = window.window_handle().downcast::<Workspace>().or_else(|| {
681 log::info!("OAuth: Current window is not a workspace, searching other windows");
682 cx.windows()
683 .into_iter()
684 .find_map(|window_handle| window_handle.downcast::<Workspace>())
685 });
686
687 if ws.is_none() {
688 log::error!("OAuth: Could not find any workspace window");
689 self.oauth_in_progress = false;
690 self.oauth_error =
691 Some("Could not access workspace to show sign-in modal".to_string());
692 cx.notify();
693 return;
694 }
695 ws
696 } else {
697 None
698 };
699
700 log::info!(
701 "OAuth: Using {} mode",
702 if use_popup_window {
703 "popup window"
704 } else {
705 "workspace modal"
706 }
707 );
708 let state = state.downgrade();
709 cx.spawn(async move |_this, cx| {
710 // Step 1: Start device flow - get prompt info from extension
711 let start_result = extension
712 .call({
713 let provider_id = provider_id.clone();
714 |ext, store| {
715 async move {
716 ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
717 .await
718 }
719 .boxed()
720 }
721 })
722 .await;
723
724 log::info!(
725 "OAuth: Device flow start result: {:?}",
726 start_result.is_ok()
727 );
728 let prompt_info: LlmDeviceFlowPromptInfo = match start_result {
729 Ok(Ok(Ok(info))) => {
730 log::info!(
731 "OAuth: Got device flow prompt info, user_code: {}",
732 info.user_code
733 );
734 info
735 }
736 Ok(Ok(Err(e))) => {
737 log::error!("OAuth: Device flow start failed: {}", e);
738 this_handle
739 .update(cx, |this, cx| {
740 this.oauth_in_progress = false;
741 this.oauth_error = Some(e);
742 cx.notify();
743 })
744 .log_err();
745 return;
746 }
747 Ok(Err(e)) | Err(e) => {
748 log::error!("OAuth: Device flow start error: {}", e);
749 this_handle
750 .update(cx, |this, cx| {
751 this.oauth_in_progress = false;
752 this.oauth_error = Some(e.to_string());
753 cx.notify();
754 })
755 .log_err();
756 return;
757 }
758 };
759
760 // Step 2: Create state entity and show the modal/window
761 let modal_config = OAuthDeviceFlowModalConfig {
762 user_code: prompt_info.user_code,
763 verification_url: prompt_info.verification_url,
764 headline: prompt_info.headline,
765 description: prompt_info.description,
766 connect_button_label: prompt_info.connect_button_label,
767 success_headline: prompt_info.success_headline,
768 success_message: prompt_info.success_message,
769 icon_path,
770 };
771
772 let flow_state: Option<Entity<OAuthDeviceFlowState>> = if use_popup_window {
773 // Open a popup window like Copilot does
774 log::info!("OAuth: Opening popup window");
775 cx.update(|cx| {
776 let height = px(450.);
777 let width = px(350.);
778 let window_bounds = WindowBounds::Windowed(gpui::bounds(
779 current_window_center - point(height / 2.0, width / 2.0),
780 gpui::size(height, width),
781 ));
782
783 let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config.clone()));
784 let flow_state_for_window = flow_state.clone();
785
786 cx.open_window(
787 WindowOptions {
788 kind: gpui::WindowKind::PopUp,
789 window_bounds: Some(window_bounds),
790 is_resizable: false,
791 is_movable: true,
792 titlebar: Some(gpui::TitlebarOptions {
793 appears_transparent: true,
794 ..Default::default()
795 }),
796 ..Default::default()
797 },
798 |window, cx| {
799 cx.new(|cx| {
800 OAuthCodeVerificationWindow::new(
801 modal_config,
802 flow_state_for_window,
803 window,
804 cx,
805 )
806 })
807 },
808 )
809 .log_err();
810
811 Some(flow_state)
812 })
813 .ok()
814 .flatten()
815 } else {
816 // Use workspace modal
817 log::info!("OAuth: Attempting to show modal in workspace window");
818 workspace_window.as_ref().and_then(|ws| {
819 ws.update(cx, |workspace, window, cx| {
820 log::info!("OAuth: Inside workspace.update, creating modal");
821 window.activate_window();
822 let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config));
823 let flow_state_clone = flow_state.clone();
824 workspace.toggle_modal(window, cx, |_window, cx| {
825 log::info!("OAuth: Inside toggle_modal callback");
826 OAuthDeviceFlowModal::new(flow_state_clone, cx)
827 });
828 flow_state
829 })
830 .ok()
831 })
832 };
833
834 log::info!("OAuth: flow_state created: {:?}", flow_state.is_some());
835 let Some(flow_state) = flow_state else {
836 log::error!("OAuth: Failed to show sign-in modal/window");
837 this_handle
838 .update(cx, |this, cx| {
839 this.oauth_in_progress = false;
840 this.oauth_error = Some("Failed to show sign-in modal".to_string());
841 cx.notify();
842 })
843 .log_err();
844 return;
845 };
846 log::info!("OAuth: Modal/window shown successfully, starting poll");
847
848 // Step 3: Poll for authentication completion
849 let poll_result = extension
850 .call({
851 let provider_id = provider_id.clone();
852 |ext, store| {
853 async move {
854 ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
855 .await
856 }
857 .boxed()
858 }
859 })
860 .await;
861
862 match poll_result {
863 Ok(Ok(Ok(()))) => {
864 // After successful auth, refresh the models list
865 let models_result = extension
866 .call({
867 let provider_id = provider_id.clone();
868 |ext, store| {
869 async move {
870 ext.call_llm_provider_models(store, &provider_id).await
871 }
872 .boxed()
873 }
874 })
875 .await;
876
877 let new_models: Vec<LlmModelInfo> = match models_result {
878 Ok(Ok(Ok(models))) => models,
879 _ => Vec::new(),
880 };
881
882 state
883 .update(cx, |state, cx| {
884 state.is_authenticated = true;
885 state.available_models = new_models;
886 cx.notify();
887 })
888 .log_err();
889
890 // Update flow state to show success
891 flow_state
892 .update(cx, |state, cx| {
893 state.set_status(OAuthDeviceFlowStatus::Authorized, cx);
894 })
895 .log_err();
896 }
897 Ok(Ok(Err(e))) => {
898 log::error!("Device flow poll failed: {}", e);
899 flow_state
900 .update(cx, |state, cx| {
901 state.set_status(OAuthDeviceFlowStatus::Failed(e.clone()), cx);
902 })
903 .log_err();
904 this_handle
905 .update(cx, |this, cx| {
906 this.oauth_error = Some(e);
907 cx.notify();
908 })
909 .log_err();
910 }
911 Ok(Err(e)) | Err(e) => {
912 log::error!("Device flow poll error: {}", e);
913 let error_string = e.to_string();
914 flow_state
915 .update(cx, |state, cx| {
916 state.set_status(
917 OAuthDeviceFlowStatus::Failed(error_string.clone()),
918 cx,
919 );
920 })
921 .log_err();
922 this_handle
923 .update(cx, |this, cx| {
924 this.oauth_error = Some(error_string);
925 cx.notify();
926 })
927 .log_err();
928 }
929 };
930
931 this_handle
932 .update(cx, |this, cx| {
933 this.oauth_in_progress = false;
934 cx.notify();
935 })
936 .log_err();
937 })
938 .detach();
939 }
940
941 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
942 self.state.read(cx).is_authenticated
943 }
944
945 fn has_oauth_config(&self) -> bool {
946 self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
947 }
948
949 fn oauth_config(&self) -> Option<&OAuthConfig> {
950 self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
951 }
952
953 fn has_api_key_config(&self) -> bool {
954 // API key is available if there's a credential_label or no oauth-only config
955 self.auth_config
956 .as_ref()
957 .map(|c| c.credential_label.is_some() || c.oauth.is_none())
958 .unwrap_or(true)
959 }
960
961 fn is_edit_prediction_mode(&self) -> bool {
962 self.target_agent == ConfigurationViewTargetAgent::EditPrediction
963 }
964
965 fn render_for_edit_prediction(
966 &mut self,
967 _window: &mut Window,
968 cx: &mut Context<Self>,
969 ) -> impl IntoElement {
970 let is_loading = self.loading_settings || self.loading_credentials;
971 let is_authenticated = self.is_authenticated(cx);
972 let has_oauth = self.has_oauth_config();
973
974 // Helper to create the horizontal container layout matching Copilot
975 let container = |description: SharedString, action: AnyElement| {
976 h_flex()
977 .pt_2p5()
978 .w_full()
979 .justify_between()
980 .child(
981 v_flex()
982 .w_full()
983 .max_w_1_2()
984 .child(Label::new("Authenticate To Use"))
985 .child(
986 Label::new(description)
987 .color(Color::Muted)
988 .size(LabelSize::Small),
989 ),
990 )
991 .child(action)
992 };
993
994 // Get the description from OAuth config or use a default
995 let oauth_config = self.oauth_config();
996 let description: SharedString = oauth_config
997 .and_then(|c| c.sign_in_description.clone())
998 .unwrap_or_else(|| "Sign in to authenticate with this provider.".to_string())
999 .into();
1000
1001 if is_loading {
1002 return container(
1003 description,
1004 Button::new("loading", "Loading...")
1005 .style(ButtonStyle::Outlined)
1006 .disabled(true)
1007 .into_any_element(),
1008 )
1009 .into_any_element();
1010 }
1011
1012 // If authenticated, show the configured card
1013 if is_authenticated {
1014 let (status_label, button_label) = if has_oauth {
1015 ("Authorized", "Sign Out")
1016 } else {
1017 ("API key configured", "Reset Key")
1018 };
1019
1020 return ConfiguredApiCard::new(status_label)
1021 .button_label(button_label)
1022 .on_click(cx.listener(|this, _, window, cx| {
1023 this.reset_api_key(window, cx);
1024 }))
1025 .into_any_element();
1026 }
1027
1028 // Not authenticated - show sign in button
1029 if has_oauth {
1030 let button_label = oauth_config
1031 .and_then(|c| c.sign_in_button_label.clone())
1032 .unwrap_or_else(|| "Sign In".to_string());
1033 let button_icon = oauth_config
1034 .and_then(|c| c.sign_in_button_icon.as_ref())
1035 .and_then(|icon_name| match icon_name.as_str() {
1036 "github" => Some(ui::IconName::Github),
1037 _ => None,
1038 });
1039
1040 let oauth_in_progress = self.oauth_in_progress;
1041
1042 let mut button = Button::new("oauth-sign-in", button_label)
1043 .size(ButtonSize::Medium)
1044 .style(ButtonStyle::Outlined)
1045 .disabled(oauth_in_progress)
1046 .on_click(cx.listener(|this, _, window, cx| {
1047 this.start_oauth_sign_in(window, cx);
1048 }));
1049
1050 if let Some(icon) = button_icon {
1051 button = button
1052 .icon(icon)
1053 .icon_position(ui::IconPosition::Start)
1054 .icon_size(ui::IconSize::Small)
1055 .icon_color(Color::Muted);
1056 }
1057
1058 return container(description, button.into_any_element()).into_any_element();
1059 }
1060
1061 // Fallback for API key only providers - show a simple message
1062 container(
1063 description,
1064 Button::new("configure", "Configure")
1065 .size(ButtonSize::Medium)
1066 .style(ButtonStyle::Outlined)
1067 .disabled(true)
1068 .into_any_element(),
1069 )
1070 .into_any_element()
1071 }
1072}
1073
impl Render for ExtensionProviderConfigurationView {
    /// Builds the configuration UI for this extension-provided language model
    /// provider.
    ///
    /// The function returns early at several points. In order it renders:
    /// - the edit-prediction variant when `is_edit_prediction_mode()` is true;
    /// - a spinner with a "Loading" label while settings or credentials load;
    /// - the extension's settings markdown, if any;
    /// - one checkbox per environment variable listed in the auth config and,
    ///   when an API key was taken from one of them, a read-only card (early
    ///   return);
    /// - a "configured" card with a "Sign Out" / "Reset Key" button when
    ///   authenticated without an environment variable (early return);
    /// - otherwise the OAuth sign-in button and/or the API key input;
    /// - finally, for the `openai` provider id, a hint about OpenAI-compatible
    ///   providers.
    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        if self.is_edit_prediction_mode() {
            return self
                .render_for_edit_prediction(window, cx)
                .into_any_element();
        }

        let is_loading = self.loading_settings || self.loading_credentials;
        let is_authenticated = self.is_authenticated(cx);
        let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone();
        let env_var_name_used = self.state.read(cx).env_var_name_used.clone();
        let has_oauth = self.has_oauth_config();
        let has_api_key = self.has_api_key_config();

        // Until both settings and credentials are known, show only a spinner.
        if is_loading {
            return h_flex()
                .gap_2()
                .child(
                    h_flex()
                        .w_2()
                        .child(SpinnerLabel::sand().size(LabelSize::Small)),
                )
                .child(LoadingLabel::new("Loading").size(LabelSize::Small))
                .into_any_element();
        }

        let mut content = v_flex().size_full().gap_2();

        // Extension-provided settings documentation, rendered as markdown.
        if let Some(markdown) = &self.settings_markdown {
            content = content.text_sm().child(MarkdownElement::new(
                markdown.clone(),
                markdown_styles(window, cx),
            ));
        }

        if let Some(auth_config) = &self.auth_config {
            if let Some(env_vars) = &auth_config.env_vars {
                // One opt-in checkbox per declared environment variable; the
                // checkbox state reflects whether the variable is in
                // `allowed_env_vars`.
                for env_var_name in env_vars {
                    let is_allowed = allowed_env_vars.contains(env_var_name);
                    let checkbox_label =
                        format!("Read API key from {} environment variable.", env_var_name);
                    let env_var_for_click = env_var_name.clone();

                    content = content.child(
                        Checkbox::new(
                            SharedString::from(format!("env-var-{}", env_var_name)),
                            if is_allowed {
                                ToggleState::Selected
                            } else {
                                ToggleState::Unselected
                            },
                        )
                        .label(checkbox_label)
                        .on_click(cx.listener(
                            move |this, _, _window, cx| {
                                this.toggle_env_var_permission(env_var_for_click.clone(), cx);
                            },
                        )),
                    );
                }

                // The key came from an environment variable: it cannot be reset
                // from the UI, so show a disabled card and skip the other options.
                if let Some(used_var) = &env_var_name_used {
                    content = content.child(
                        ConfiguredApiCard::new(format!(
                            "API key set in {} environment variable",
                            used_var
                        ))
                        .tooltip_label(format!(
                            "To reset this API key, unset the {} environment variable.",
                            used_var
                        ))
                        .disabled(true),
                    );

                    return content.into_any_element();
                }
            }
        }

        // Authenticated through OAuth or a stored API key. Wording depends on
        // which auth mechanisms the provider declares, not on which one was used.
        if is_authenticated && env_var_name_used.is_none() {
            let (status_label, button_label) = if has_oauth && !has_api_key {
                ("Signed in", "Sign Out")
            } else {
                ("API key configured", "Reset Key")
            };

            content = content.child(
                ConfiguredApiCard::new(status_label)
                    .button_label(button_label)
                    .on_click(cx.listener(|this, _, window, cx| {
                        this.reset_api_key(window, cx);
                    })),
            );

            return content.into_any_element();
        }

        // Not authenticated - show available auth options
        if env_var_name_used.is_none() {
            // Render OAuth sign-in button if configured
            if has_oauth {
                let oauth_config = self.oauth_config();
                let button_label = oauth_config
                    .and_then(|c| c.sign_in_button_label.clone())
                    .unwrap_or_else(|| "Sign In".to_string());
                // Only the "github" icon name is recognized; anything else
                // renders the button without an icon.
                let button_icon = oauth_config
                    .and_then(|c| c.sign_in_button_icon.as_ref())
                    .and_then(|icon_name| match icon_name.as_str() {
                        "github" => Some(ui::IconName::Github),
                        _ => None,
                    });

                let oauth_in_progress = self.oauth_in_progress;

                let oauth_error = self.oauth_error.clone();

                // Disabled while a sign-in attempt is running to prevent
                // starting a second flow.
                let mut button = Button::new("oauth-sign-in", button_label)
                    .full_width()
                    .style(ButtonStyle::Outlined)
                    .disabled(oauth_in_progress)
                    .on_click(cx.listener(|this, _, window, cx| {
                        this.start_oauth_sign_in(window, cx);
                    }));
                if let Some(icon) = button_icon {
                    button = button
                        .icon(icon)
                        .icon_position(IconPosition::Start)
                        .icon_size(IconSize::Small)
                        .icon_color(Color::Muted);
                }

                // Button, then an optional progress note, then an optional
                // error block with the last sign-in failure message.
                content = content.child(
                    v_flex()
                        .gap_2()
                        .child(button)
                        .when(oauth_in_progress, |this| {
                            this.child(
                                Label::new("Sign-in in progress...")
                                    .size(LabelSize::Small)
                                    .color(Color::Muted),
                            )
                        })
                        .when_some(oauth_error, |this, error| {
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .child(
                                        h_flex()
                                            .gap_2()
                                            .child(
                                                Icon::new(IconName::Warning)
                                                    .color(Color::Error)
                                                    .size(IconSize::Small),
                                            )
                                            .child(
                                                Label::new("Authentication failed")
                                                    .color(Color::Error)
                                                    .size(LabelSize::Small),
                                            ),
                                    )
                                    .child(
                                        div().pl_6().child(
                                            Label::new(error)
                                                .color(Color::Error)
                                                .size(LabelSize::Small),
                                        ),
                                    ),
                            )
                        }),
                );
            }

            // Render API key input if configured (and we have both options, show a separator)
            if has_api_key {
                if has_oauth {
                    content = content.child(
                        h_flex()
                            .gap_2()
                            .items_center()
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border_variant))
                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border_variant)),
                    );
                }

                // The `save_api_key` action is handled on the wrapper around
                // the key editor.
                content = content.child(
                    div()
                        .on_action(cx.listener(Self::save_api_key))
                        .child(self.api_key_editor.clone()),
                );
            }
        }

        // Provider-specific hint pointing OpenAI users at OpenAI-compatible
        // provider docs.
        if self.extension_provider_id == "openai" {
            content = content.child(
                h_flex()
                    .gap_1()
                    .child(
                        Icon::new(IconName::Info)
                            .size(IconSize::XSmall)
                            .color(Color::Muted),
                    )
                    .child(
                        Label::new("Zed also supports OpenAI-compatible models.")
                            .size(LabelSize::Small)
                            .color(Color::Muted),
                    )
                    .child(
                        ButtonLink::new(
                            "Learn More",
                            "https://zed.dev/docs/configuring-llm-providers#openai-compatible-providers",
                        )
                        .label_size(LabelSize::Small),
                    ),
            );
        }

        content.into_any_element()
    }
}
1295
1296impl Focusable for ExtensionProviderConfigurationView {
1297 fn focus_handle(&self, cx: &App) -> FocusHandle {
1298 self.api_key_editor.read(cx).focus_handle(cx)
1299 }
1300}
1301
/// A popup window for OAuth device flow, similar to CopilotCodeVerification.
/// This is used when in edit prediction mode to avoid moving the settings panel behind.
pub struct OAuthCodeVerificationWindow {
    /// Text, icon, URLs and the user code displayed by the window.
    config: OAuthDeviceFlowModalConfig,
    /// Latest device-flow status, copied from the observed `OAuthDeviceFlowState`
    /// whenever that entity notifies.
    status: OAuthDeviceFlowStatus,
    /// Set once the user clicks the connect button; switches the button label
    /// to a waiting message.
    connect_clicked: bool,
    /// Focus handle tracked by the window's root element.
    focus_handle: FocusHandle,
    /// Keeps the observation of the device-flow state entity alive.
    _subscription: Option<Subscription>,
}
1311
1312impl Focusable for OAuthCodeVerificationWindow {
1313 fn focus_handle(&self, _: &App) -> FocusHandle {
1314 self.focus_handle.clone()
1315 }
1316}
1317
1318impl EventEmitter<DismissEvent> for OAuthCodeVerificationWindow {}
1319
1320impl OAuthCodeVerificationWindow {
1321 pub fn new(
1322 config: OAuthDeviceFlowModalConfig,
1323 state: Entity<OAuthDeviceFlowState>,
1324 window: &mut Window,
1325 cx: &mut Context<Self>,
1326 ) -> Self {
1327 window.on_window_should_close(cx, |window, cx| {
1328 if let Some(this) = window.root::<OAuthCodeVerificationWindow>().flatten() {
1329 this.update(cx, |_, cx| {
1330 cx.emit(DismissEvent);
1331 });
1332 }
1333 true
1334 });
1335 cx.subscribe_in(
1336 &cx.entity(),
1337 window,
1338 |_, _, _: &DismissEvent, window, _cx| {
1339 window.remove_window();
1340 },
1341 )
1342 .detach();
1343
1344 let subscription = cx.observe(&state, |this, state, cx| {
1345 let status = state.read(cx).status.clone();
1346 this.status = status;
1347 cx.notify();
1348 });
1349
1350 Self {
1351 config,
1352 status: state.read(cx).status.clone(),
1353 connect_clicked: false,
1354 focus_handle: cx.focus_handle(),
1355 _subscription: Some(subscription),
1356 }
1357 }
1358
1359 fn render_icon(&self, cx: &mut Context<Self>) -> impl IntoElement {
1360 let icon_color = Color::Custom(cx.theme().colors().icon);
1361 let icon_size = rems(2.5);
1362 let plus_size = rems(0.875);
1363 let plus_color = cx.theme().colors().icon.opacity(0.5);
1364
1365 if let Some(icon_path) = &self.config.icon_path {
1366 h_flex()
1367 .gap_2()
1368 .items_center()
1369 .child(
1370 Icon::from_external_svg(icon_path.clone())
1371 .size(IconSize::Custom(icon_size))
1372 .color(icon_color),
1373 )
1374 .child(
1375 gpui::svg()
1376 .size(plus_size)
1377 .path("icons/plus.svg")
1378 .text_color(plus_color),
1379 )
1380 .child(Vector::new(VectorName::ZedLogo, icon_size, icon_size).color(icon_color))
1381 .into_any_element()
1382 } else {
1383 Vector::new(VectorName::ZedLogo, icon_size, icon_size)
1384 .color(icon_color)
1385 .into_any_element()
1386 }
1387 }
1388
1389 fn render_device_code(&self, cx: &mut Context<Self>) -> impl IntoElement {
1390 let user_code = self.config.user_code.clone();
1391 let copied = cx
1392 .read_from_clipboard()
1393 .map(|item| item.text().as_ref() == Some(&user_code))
1394 .unwrap_or(false);
1395 let user_code_for_click = user_code.clone();
1396
1397 ButtonLike::new("copy-button")
1398 .full_width()
1399 .style(ButtonStyle::Tinted(ui::TintColor::Accent))
1400 .size(ButtonSize::Medium)
1401 .child(
1402 h_flex()
1403 .w_full()
1404 .p_1()
1405 .justify_between()
1406 .child(Label::new(user_code))
1407 .child(Label::new(if copied { "Copied!" } else { "Copy" })),
1408 )
1409 .on_click(move |_, window, cx| {
1410 cx.write_to_clipboard(ClipboardItem::new_string(user_code_for_click.clone()));
1411 window.refresh();
1412 })
1413 }
1414
1415 fn render_prompting_modal(&self, cx: &mut Context<Self>) -> impl IntoElement {
1416 let connect_button_label: String = if self.connect_clicked {
1417 "Waiting for connection…".to_string()
1418 } else {
1419 self.config.connect_button_label.clone()
1420 };
1421 let verification_url = self.config.verification_url.clone();
1422
1423 v_flex()
1424 .flex_1()
1425 .gap_2p5()
1426 .items_center()
1427 .text_center()
1428 .child(Headline::new(self.config.headline.clone()).size(HeadlineSize::Large))
1429 .child(Label::new(self.config.description.clone()).color(Color::Muted))
1430 .child(self.render_device_code(cx))
1431 .child(
1432 Label::new("Paste this code after clicking the button below.").color(Color::Muted),
1433 )
1434 .child(
1435 v_flex()
1436 .w_full()
1437 .gap_1()
1438 .child(
1439 Button::new("connect-button", connect_button_label)
1440 .full_width()
1441 .style(ButtonStyle::Outlined)
1442 .size(ButtonSize::Medium)
1443 .on_click(cx.listener(move |this, _, _window, cx| {
1444 cx.open_url(&verification_url);
1445 this.connect_clicked = true;
1446 })),
1447 )
1448 .child(
1449 Button::new("cancel-button", "Cancel")
1450 .full_width()
1451 .size(ButtonSize::Medium)
1452 .on_click(cx.listener(|_, _, _, cx| {
1453 cx.emit(DismissEvent);
1454 })),
1455 ),
1456 )
1457 }
1458
1459 fn render_authorized_modal(&self, cx: &mut Context<Self>) -> impl IntoElement {
1460 v_flex()
1461 .gap_2()
1462 .text_center()
1463 .justify_center()
1464 .child(Headline::new(self.config.success_headline.clone()).size(HeadlineSize::Large))
1465 .child(Label::new(self.config.success_message.clone()).color(Color::Muted))
1466 .child(
1467 Button::new("done-button", "Done")
1468 .full_width()
1469 .style(ButtonStyle::Outlined)
1470 .size(ButtonSize::Medium)
1471 .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
1472 )
1473 }
1474
1475 fn render_failed_modal(&self, error: &str, cx: &mut Context<Self>) -> impl IntoElement {
1476 v_flex()
1477 .gap_2()
1478 .text_center()
1479 .justify_center()
1480 .child(Headline::new("Authorization Failed").size(HeadlineSize::Large))
1481 .child(Label::new(error.to_string()).color(Color::Error))
1482 .child(
1483 Button::new("close-button", "Close")
1484 .full_width()
1485 .size(ButtonSize::Medium)
1486 .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
1487 )
1488 }
1489}
1490
1491impl Render for OAuthCodeVerificationWindow {
1492 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1493 let prompt = match &self.status {
1494 OAuthDeviceFlowStatus::Prompting | OAuthDeviceFlowStatus::WaitingForAuthorization => {
1495 self.render_prompting_modal(cx).into_any_element()
1496 }
1497 OAuthDeviceFlowStatus::Authorized => {
1498 self.render_authorized_modal(cx).into_any_element()
1499 }
1500 OAuthDeviceFlowStatus::Failed(error) => {
1501 self.render_failed_modal(error, cx).into_any_element()
1502 }
1503 };
1504
1505 v_flex()
1506 .id("oauth_code_verification")
1507 .track_focus(&self.focus_handle(cx))
1508 .size_full()
1509 .px_4()
1510 .py_8()
1511 .gap_2()
1512 .items_center()
1513 .justify_center()
1514 .elevation_3(cx)
1515 .on_action(cx.listener(|_, _: &menu::Cancel, _, cx| {
1516 cx.emit(DismissEvent);
1517 }))
1518 .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, cx| {
1519 window.focus(&this.focus_handle, cx);
1520 }))
1521 .child(self.render_icon(cx))
1522 .child(prompt)
1523 }
1524}
1525
1526fn markdown_styles(window: &Window, cx: &App) -> MarkdownStyle {
1527 let settings = ThemeSettings::get_global(cx);
1528 let colors = cx.theme().colors();
1529
1530 let mut text_style = window.text_style();
1531 text_style.refine(&TextStyleRefinement {
1532 font_family: Some(settings.ui_font.family.clone()),
1533 font_fallbacks: settings.ui_font.fallbacks.clone(),
1534 font_features: Some(settings.ui_font.features.clone()),
1535 font_size: Some(settings.ui_font_size(cx).into()),
1536 line_height: Some(relative(1.5)),
1537 color: Some(colors.text_muted),
1538 ..Default::default()
1539 });
1540
1541 MarkdownStyle {
1542 base_text_style: text_style.clone(),
1543 syntax: cx.theme().syntax().clone(),
1544 selection_background_color: colors.element_selection_background,
1545 heading_level_styles: Some(HeadingLevelStyles {
1546 h1: Some(TextStyleRefinement {
1547 font_size: Some(rems(1.15).into()),
1548 ..Default::default()
1549 }),
1550 h2: Some(TextStyleRefinement {
1551 font_size: Some(rems(1.1).into()),
1552 ..Default::default()
1553 }),
1554 h3: Some(TextStyleRefinement {
1555 font_size: Some(rems(1.05).into()),
1556 ..Default::default()
1557 }),
1558 h4: Some(TextStyleRefinement {
1559 font_size: Some(rems(1.).into()),
1560 ..Default::default()
1561 }),
1562 h5: Some(TextStyleRefinement {
1563 font_size: Some(rems(0.95).into()),
1564 ..Default::default()
1565 }),
1566 h6: Some(TextStyleRefinement {
1567 font_size: Some(rems(0.875).into()),
1568 ..Default::default()
1569 }),
1570 }),
1571 inline_code: TextStyleRefinement {
1572 font_family: Some(settings.buffer_font.family.clone()),
1573 font_fallbacks: settings.buffer_font.fallbacks.clone(),
1574 font_features: Some(settings.buffer_font.features.clone()),
1575 font_size: Some(settings.buffer_font_size(cx).into()),
1576 background_color: Some(colors.editor_foreground.opacity(0.08)),
1577 ..Default::default()
1578 },
1579 link: TextStyleRefinement {
1580 background_color: Some(colors.editor_foreground.opacity(0.025)),
1581 color: Some(colors.text_accent),
1582 underline: Some(UnderlineStyle {
1583 color: Some(colors.text_accent.opacity(0.5)),
1584 thickness: px(1.),
1585 ..Default::default()
1586 }),
1587 ..Default::default()
1588 },
1589 ..Default::default()
1590 }
1591}
1592
/// An extension-based language model.
///
/// Each instance wraps one model reported by a WASM extension's LLM provider.
/// Token counting and completion requests are forwarded to the extension
/// through `WasmExtension::call`.
pub struct ExtensionLanguageModel {
    /// Handle used to call into the extension that provides this model.
    extension: WasmExtension,
    /// Model metadata from the extension: id, display name, capabilities and
    /// token limits.
    model_info: LlmModelInfo,
    /// Value returned from `LanguageModel::provider_id`.
    provider_id: LanguageModelProviderId,
    /// Value returned from `LanguageModel::provider_name`.
    provider_name: LanguageModelProviderName,
    /// Provider metadata from the extension; its `id` is the provider id sent
    /// back to the extension in calls and used in the telemetry id.
    provider_info: LlmProviderInfo,
    /// Rate limiter wrapping `stream_completion` requests.
    request_limiter: RateLimiter,
    /// Prompt-caching configuration reported via `cache_configuration`, if any.
    cache_config: Option<LanguageModelCacheConfiguration>,
}
1603
1604impl LanguageModel for ExtensionLanguageModel {
1605 fn id(&self) -> LanguageModelId {
1606 LanguageModelId::from(self.model_info.id.clone())
1607 }
1608
1609 fn name(&self) -> LanguageModelName {
1610 LanguageModelName::from(self.model_info.name.clone())
1611 }
1612
1613 fn provider_id(&self) -> LanguageModelProviderId {
1614 self.provider_id.clone()
1615 }
1616
1617 fn provider_name(&self) -> LanguageModelProviderName {
1618 self.provider_name.clone()
1619 }
1620
1621 fn telemetry_id(&self) -> String {
1622 format!("{}/{}", self.provider_info.id, self.model_info.id)
1623 }
1624
1625 fn supports_images(&self) -> bool {
1626 self.model_info.capabilities.supports_images
1627 }
1628
1629 fn supports_tools(&self) -> bool {
1630 self.model_info.capabilities.supports_tools
1631 }
1632
1633 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1634 match choice {
1635 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1636 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1637 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1638 }
1639 }
1640
1641 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1642 match self.model_info.capabilities.tool_input_format {
1643 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1644 LlmToolInputFormat::JsonSchemaSubset => LanguageModelToolSchemaFormat::JsonSchemaSubset,
1645 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1646 }
1647 }
1648
1649 fn max_token_count(&self) -> u64 {
1650 self.model_info.max_token_count
1651 }
1652
1653 fn max_output_tokens(&self) -> Option<u64> {
1654 self.model_info.max_output_tokens
1655 }
1656
1657 fn count_tokens(
1658 &self,
1659 request: LanguageModelRequest,
1660 cx: &App,
1661 ) -> BoxFuture<'static, Result<u64>> {
1662 let extension = self.extension.clone();
1663 let provider_id = self.provider_info.id.clone();
1664 let model_id = self.model_info.id.clone();
1665
1666 let wit_request = convert_request_to_wit(request);
1667
1668 cx.background_spawn(async move {
1669 extension
1670 .call({
1671 let provider_id = provider_id.clone();
1672 let model_id = model_id.clone();
1673 let wit_request = wit_request.clone();
1674 |ext, store| {
1675 async move {
1676 let count = ext
1677 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1678 .await?
1679 .map_err(|e| anyhow!("{}", e))?;
1680 Ok(count)
1681 }
1682 .boxed()
1683 }
1684 })
1685 .await?
1686 })
1687 .boxed()
1688 }
1689
1690 fn stream_completion(
1691 &self,
1692 request: LanguageModelRequest,
1693 _cx: &AsyncApp,
1694 ) -> BoxFuture<
1695 'static,
1696 Result<
1697 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1698 LanguageModelCompletionError,
1699 >,
1700 > {
1701 let extension = self.extension.clone();
1702 let provider_id = self.provider_info.id.clone();
1703 let model_id = self.model_info.id.clone();
1704
1705 let wit_request = convert_request_to_wit(request);
1706
1707 let future = self.request_limiter.stream(async move {
1708 // Start the stream
1709 let stream_id_result = extension
1710 .call({
1711 let provider_id = provider_id.clone();
1712 let model_id = model_id.clone();
1713 let wit_request = wit_request.clone();
1714 |ext, store| {
1715 async move {
1716 let id = ext
1717 .call_llm_stream_completion_start(
1718 store,
1719 &provider_id,
1720 &model_id,
1721 &wit_request,
1722 )
1723 .await?
1724 .map_err(|e| anyhow!("{}", e))?;
1725 Ok(id)
1726 }
1727 .boxed()
1728 }
1729 })
1730 .await;
1731
1732 let stream_id = stream_id_result
1733 .map_err(LanguageModelCompletionError::Other)?
1734 .map_err(LanguageModelCompletionError::Other)?;
1735
1736 // Create a stream that polls for events
1737 let stream = futures::stream::unfold(
1738 (extension.clone(), stream_id, false),
1739 move |(extension, stream_id, done)| async move {
1740 if done {
1741 return None;
1742 }
1743
1744 let result = extension
1745 .call({
1746 let stream_id = stream_id.clone();
1747 |ext, store| {
1748 async move {
1749 let event = ext
1750 .call_llm_stream_completion_next(store, &stream_id)
1751 .await?
1752 .map_err(|e| anyhow!("{}", e))?;
1753 Ok(event)
1754 }
1755 .boxed()
1756 }
1757 })
1758 .await
1759 .and_then(|inner| inner);
1760
1761 match result {
1762 Ok(Some(event)) => {
1763 let converted = convert_completion_event(event);
1764 let is_done =
1765 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1766 Some((converted, (extension, stream_id, is_done)))
1767 }
1768 Ok(None) => {
1769 // Stream complete, close it
1770 let _ = extension
1771 .call({
1772 let stream_id = stream_id.clone();
1773 |ext, store| {
1774 async move {
1775 ext.call_llm_stream_completion_close(store, &stream_id)
1776 .await?;
1777 Ok::<(), anyhow::Error>(())
1778 }
1779 .boxed()
1780 }
1781 })
1782 .await;
1783 None
1784 }
1785 Err(e) => Some((
1786 Err(LanguageModelCompletionError::Other(e)),
1787 (extension, stream_id, true),
1788 )),
1789 }
1790 },
1791 );
1792
1793 Ok(stream.boxed())
1794 });
1795
1796 async move { Ok(future.await?.boxed()) }.boxed()
1797 }
1798
1799 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1800 self.cache_config.clone()
1801 }
1802}
1803
1804fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1805 use language_model::{MessageContent, Role};
1806
1807 let messages: Vec<LlmRequestMessage> = request
1808 .messages
1809 .into_iter()
1810 .map(|msg| {
1811 let role = match msg.role {
1812 Role::User => LlmMessageRole::User,
1813 Role::Assistant => LlmMessageRole::Assistant,
1814 Role::System => LlmMessageRole::System,
1815 };
1816
1817 let content: Vec<LlmMessageContent> = msg
1818 .content
1819 .into_iter()
1820 .map(|c| match c {
1821 MessageContent::Text(text) => LlmMessageContent::Text(text),
1822 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1823 source: image.source.to_string(),
1824 width: image.size.map(|s| s.width.0 as u32),
1825 height: image.size.map(|s| s.height.0 as u32),
1826 }),
1827 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1828 id: tool_use.id.to_string(),
1829 name: tool_use.name.to_string(),
1830 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1831 is_input_complete: tool_use.is_input_complete,
1832 thought_signature: tool_use.thought_signature,
1833 }),
1834 MessageContent::ToolResult(tool_result) => {
1835 let content = match tool_result.content {
1836 language_model::LanguageModelToolResultContent::Text(text) => {
1837 LlmToolResultContent::Text(text.to_string())
1838 }
1839 language_model::LanguageModelToolResultContent::Image(image) => {
1840 LlmToolResultContent::Image(LlmImageData {
1841 source: image.source.to_string(),
1842 width: image.size.map(|s| s.width.0 as u32),
1843 height: image.size.map(|s| s.height.0 as u32),
1844 })
1845 }
1846 };
1847 LlmMessageContent::ToolResult(LlmToolResult {
1848 tool_use_id: tool_result.tool_use_id.to_string(),
1849 tool_name: tool_result.tool_name.to_string(),
1850 is_error: tool_result.is_error,
1851 content,
1852 })
1853 }
1854 MessageContent::Thinking { text, signature } => {
1855 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1856 }
1857 MessageContent::RedactedThinking(data) => {
1858 LlmMessageContent::RedactedThinking(data)
1859 }
1860 })
1861 .collect();
1862
1863 LlmRequestMessage {
1864 role,
1865 content,
1866 cache: msg.cache,
1867 }
1868 })
1869 .collect();
1870
1871 let tools: Vec<LlmToolDefinition> = request
1872 .tools
1873 .into_iter()
1874 .map(|tool| LlmToolDefinition {
1875 name: tool.name,
1876 description: tool.description,
1877 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1878 })
1879 .collect();
1880
1881 let tool_choice = request.tool_choice.map(|tc| match tc {
1882 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1883 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1884 LanguageModelToolChoice::None => LlmToolChoice::None,
1885 });
1886
1887 LlmCompletionRequest {
1888 messages,
1889 tools,
1890 tool_choice,
1891 stop_sequences: request.stop,
1892 temperature: request.temperature,
1893 thinking_allowed: request.thinking_allowed,
1894 max_tokens: None,
1895 }
1896}
1897
1898fn convert_completion_event(
1899 event: LlmCompletionEvent,
1900) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1901 match event {
1902 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1903 message_id: String::new(),
1904 }),
1905 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1906 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1907 text: thinking.text,
1908 signature: thinking.signature,
1909 }),
1910 LlmCompletionEvent::RedactedThinking(data) => {
1911 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1912 }
1913 LlmCompletionEvent::ToolUse(tool_use) => {
1914 let raw_input = tool_use.input.clone();
1915 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1916 Ok(LanguageModelCompletionEvent::ToolUse(
1917 LanguageModelToolUse {
1918 id: LanguageModelToolUseId::from(tool_use.id),
1919 name: tool_use.name.into(),
1920 raw_input,
1921 input,
1922 is_input_complete: tool_use.is_input_complete,
1923 thought_signature: tool_use.thought_signature,
1924 },
1925 ))
1926 }
1927 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1928 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1929 id: LanguageModelToolUseId::from(error.id),
1930 tool_name: error.tool_name.into(),
1931 raw_input: error.raw_input.into(),
1932 json_parse_error: error.error,
1933 })
1934 }
1935 LlmCompletionEvent::Stop(reason) => {
1936 let stop_reason = match reason {
1937 LlmStopReason::EndTurn => StopReason::EndTurn,
1938 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1939 LlmStopReason::ToolUse => StopReason::ToolUse,
1940 LlmStopReason::Refusal => StopReason::Refusal,
1941 };
1942 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1943 }
1944 LlmCompletionEvent::Usage(usage) => {
1945 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1946 input_tokens: usage.input_tokens,
1947 output_tokens: usage.output_tokens,
1948 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1949 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1950 }))
1951 }
1952 LlmCompletionEvent::ReasoningDetails(json) => {
1953 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1954 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1955 ))
1956 }
1957 }
1958}