llm_provider.rs

   1use crate::ExtensionSettings;
   2use crate::LEGACY_LLM_EXTENSION_IDS;
   3use crate::wasm_host::WasmExtension;
   4use crate::wasm_host::wit::LlmDeviceFlowPromptInfo;
   5use collections::HashSet;
   6
   7use crate::wasm_host::wit::{
   8    LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmImageData,
   9    LlmMessageContent, LlmMessageRole, LlmModelInfo, LlmProviderInfo, LlmRequestMessage,
  10    LlmStopReason, LlmThinkingContent, LlmToolChoice, LlmToolDefinition, LlmToolInputFormat,
  11    LlmToolResult, LlmToolResultContent, LlmToolUse,
  12};
  13use anyhow::{Result, anyhow};
  14use collections::HashMap;
  15use credentials_provider::CredentialsProvider;
  16use extension::{LanguageModelAuthConfig, OAuthConfig};
  17use futures::future::BoxFuture;
  18use futures::stream::BoxStream;
  19use futures::{FutureExt, StreamExt};
  20
  21use gpui::{
  22    AnyView, App, AsyncApp, ClipboardItem, DismissEvent, Entity, EventEmitter, FocusHandle,
  23    Focusable, MouseDownEvent, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window,
  24    WindowBounds, WindowOptions, point, prelude::*, px,
  25};
  26use language_model::tool_schema::LanguageModelToolSchemaFormat;
  27use language_model::{
  28    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
  29    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
  30    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  31    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  32    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, RateLimiter, StopReason,
  33    TokenUsage,
  34};
  35use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
  36use settings::Settings;
  37use std::sync::Arc;
  38use theme::ThemeSettings;
  39use ui::{
  40    ButtonLike, ButtonLink, Checkbox, ConfiguredApiCard, SpinnerLabel, ToggleState, Vector,
  41    VectorName, prelude::*,
  42};
  43use ui_input::InputField;
  44use util::ResultExt as _;
  45use workspace::Workspace;
  46use workspace::oauth_device_flow_modal::{
  47    OAuthDeviceFlowModal, OAuthDeviceFlowModalConfig, OAuthDeviceFlowState, OAuthDeviceFlowStatus,
  48};
  49
  50/// An extension-based language model provider.
  51pub struct ExtensionLanguageModelProvider {
  52    pub extension: WasmExtension,
  53    pub provider_info: LlmProviderInfo,
  54    icon_path: Option<SharedString>,
  55    auth_config: Option<LanguageModelAuthConfig>,
  56    state: Entity<ExtensionLlmProviderState>,
  57}
  58
  59pub struct ExtensionLlmProviderState {
  60    is_authenticated: bool,
  61    available_models: Vec<LlmModelInfo>,
  62    /// Cache configurations for each model, keyed by model ID.
  63    cache_configs: HashMap<String, LlmCacheConfiguration>,
  64    /// Set of env var names that are allowed to be read for this provider.
  65    allowed_env_vars: HashSet<String>,
  66    /// If authenticated via env var, which one was used.
  67    env_var_name_used: Option<String>,
  68}
  69
  70impl EventEmitter<()> for ExtensionLlmProviderState {}
  71
  72impl ExtensionLanguageModelProvider {
  73    pub fn new(
  74        extension: WasmExtension,
  75        provider_info: LlmProviderInfo,
  76        models: Vec<LlmModelInfo>,
  77        cache_configs: HashMap<String, LlmCacheConfiguration>,
  78        is_authenticated: bool,
  79        icon_path: Option<SharedString>,
  80        auth_config: Option<LanguageModelAuthConfig>,
  81        cx: &mut App,
  82    ) -> Self {
  83        let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
  84
  85        // Build set of allowed env vars for this provider
  86        let settings = ExtensionSettings::get_global(cx);
  87        let is_legacy_extension =
  88            LEGACY_LLM_EXTENSION_IDS.contains(&extension.manifest.id.as_ref());
  89
  90        let mut allowed_env_vars = HashSet::default();
  91        if let Some(env_vars) = auth_config.as_ref().and_then(|c| c.env_vars.as_ref()) {
  92            for env_var_name in env_vars {
  93                let key = format!("{}:{}", provider_id_string, env_var_name);
  94                // For legacy extensions, auto-allow if env var is set (migration will persist this)
  95                let env_var_is_set = std::env::var(env_var_name)
  96                    .map(|v| !v.is_empty())
  97                    .unwrap_or(false);
  98                if settings.allowed_env_var_providers.contains(key.as_str())
  99                    || (is_legacy_extension && env_var_is_set)
 100                {
 101                    allowed_env_vars.insert(env_var_name.clone());
 102                }
 103            }
 104        }
 105
 106        // Check if any allowed env var is set
 107        let env_var_name_used = allowed_env_vars.iter().find_map(|env_var_name| {
 108            if let Ok(value) = std::env::var(env_var_name) {
 109                if !value.is_empty() {
 110                    return Some(env_var_name.clone());
 111                }
 112            }
 113            None
 114        });
 115
 116        let is_authenticated = if env_var_name_used.is_some() {
 117            true
 118        } else {
 119            is_authenticated
 120        };
 121
 122        let state = cx.new(|_| ExtensionLlmProviderState {
 123            is_authenticated,
 124            available_models: models,
 125            cache_configs,
 126            allowed_env_vars,
 127            env_var_name_used,
 128        });
 129
 130        Self {
 131            extension,
 132            provider_info,
 133            icon_path,
 134            auth_config,
 135            state,
 136        }
 137    }
 138
 139    fn provider_id_string(&self) -> String {
 140        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
 141    }
 142
 143    /// The credential key used for storing the API key in the system keychain.
 144    fn credential_key(&self) -> String {
 145        format!("extension-llm-{}", self.provider_id_string())
 146    }
 147
 148    fn create_model(
 149        &self,
 150        model_info: &LlmModelInfo,
 151        cache_configs: &HashMap<String, LlmCacheConfiguration>,
 152    ) -> Arc<dyn LanguageModel> {
 153        let cache_config =
 154            cache_configs
 155                .get(&model_info.id)
 156                .map(|config| LanguageModelCacheConfiguration {
 157                    max_cache_anchors: config.max_cache_anchors as usize,
 158                    should_speculate: false,
 159                    min_total_token: config.min_total_token_count,
 160                });
 161
 162        Arc::new(ExtensionLanguageModel {
 163            extension: self.extension.clone(),
 164            model_info: model_info.clone(),
 165            provider_id: self.id(),
 166            provider_name: self.name(),
 167            provider_info: self.provider_info.clone(),
 168            request_limiter: RateLimiter::new(4),
 169            cache_config,
 170        })
 171    }
 172}
 173
 174impl LanguageModelProvider for ExtensionLanguageModelProvider {
 175    fn id(&self) -> LanguageModelProviderId {
 176        LanguageModelProviderId::from(self.provider_id_string())
 177    }
 178
 179    fn name(&self) -> LanguageModelProviderName {
 180        LanguageModelProviderName::from(self.provider_info.name.clone())
 181    }
 182
 183    fn icon(&self) -> ui::IconName {
 184        ui::IconName::ZedAssistant
 185    }
 186
 187    fn icon_path(&self) -> Option<SharedString> {
 188        self.icon_path.clone()
 189    }
 190
 191    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 192        let state = self.state.read(cx);
 193        state
 194            .available_models
 195            .iter()
 196            .find(|m| m.is_default)
 197            .or_else(|| state.available_models.first())
 198            .map(|model_info| self.create_model(model_info, &state.cache_configs))
 199    }
 200
 201    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 202        let state = self.state.read(cx);
 203        state
 204            .available_models
 205            .iter()
 206            .find(|m| m.is_default_fast)
 207            .map(|model_info| self.create_model(model_info, &state.cache_configs))
 208    }
 209
 210    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 211        let state = self.state.read(cx);
 212        state
 213            .available_models
 214            .iter()
 215            .map(|model_info| self.create_model(model_info, &state.cache_configs))
 216            .collect()
 217    }
 218
 219    fn is_authenticated(&self, cx: &App) -> bool {
 220        // First check cached state
 221        if self.state.read(cx).is_authenticated {
 222            return true;
 223        }
 224
 225        // Also check env var dynamically (in case settings changed after provider creation)
 226        if let Some(ref auth_config) = self.auth_config {
 227            if let Some(ref env_vars) = auth_config.env_vars {
 228                let provider_id_string = self.provider_id_string();
 229                let settings = ExtensionSettings::get_global(cx);
 230                let is_legacy_extension =
 231                    LEGACY_LLM_EXTENSION_IDS.contains(&self.extension.manifest.id.as_ref());
 232
 233                for env_var_name in env_vars {
 234                    let key = format!("{}:{}", provider_id_string, env_var_name);
 235                    // For legacy extensions, auto-allow if env var is set
 236                    let env_var_is_set = std::env::var(env_var_name)
 237                        .map(|v| !v.is_empty())
 238                        .unwrap_or(false);
 239                    if settings.allowed_env_var_providers.contains(key.as_str())
 240                        || (is_legacy_extension && env_var_is_set)
 241                    {
 242                        if let Ok(value) = std::env::var(env_var_name) {
 243                            if !value.is_empty() {
 244                                return true;
 245                            }
 246                        }
 247                    }
 248                }
 249            }
 250        }
 251
 252        false
 253    }
 254
 255    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 256        // Check if already authenticated via is_authenticated
 257        if self.is_authenticated(cx) {
 258            return Task::ready(Ok(()));
 259        }
 260
 261        // Not authenticated - return error indicating credentials not found
 262        Task::ready(Err(AuthenticateError::CredentialsNotFound))
 263    }
 264
 265    fn configuration_view(
 266        &self,
 267        target_agent: ConfigurationViewTargetAgent,
 268        window: &mut Window,
 269        cx: &mut App,
 270    ) -> AnyView {
 271        let credential_key = self.credential_key();
 272        let extension = self.extension.clone();
 273        let extension_provider_id = self.provider_info.id.clone();
 274        let full_provider_id = self.provider_id_string();
 275        let state = self.state.clone();
 276        let auth_config = self.auth_config.clone();
 277
 278        let icon_path = self.icon_path.clone();
 279        cx.new(|cx| {
 280            ExtensionProviderConfigurationView::new(
 281                credential_key,
 282                extension,
 283                extension_provider_id,
 284                full_provider_id,
 285                auth_config,
 286                state,
 287                icon_path,
 288                target_agent,
 289                window,
 290                cx,
 291            )
 292        })
 293        .into()
 294    }
 295
 296    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 297        let extension = self.extension.clone();
 298        let provider_id = self.provider_info.id.clone();
 299        let state = self.state.clone();
 300        let credential_key = self.credential_key();
 301
 302        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 303
 304        cx.spawn(async move |cx| {
 305            // Delete from system keychain
 306            credentials_provider
 307                .delete_credentials(&credential_key, cx)
 308                .await
 309                .log_err();
 310
 311            // Call extension's reset_credentials
 312            let result = extension
 313                .call(|extension, store| {
 314                    async move {
 315                        extension
 316                            .call_llm_provider_reset_credentials(store, &provider_id)
 317                            .await
 318                    }
 319                    .boxed()
 320                })
 321                .await;
 322
 323            // Update state
 324            cx.update(|cx| {
 325                state.update(cx, |state, _| {
 326                    state.is_authenticated = false;
 327                });
 328            })?;
 329
 330            match result {
 331                Ok(Ok(Ok(()))) => Ok(()),
 332                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
 333                Ok(Err(e)) => Err(e),
 334                Err(e) => Err(e),
 335            }
 336        })
 337    }
 338}
 339
 340impl LanguageModelProviderState for ExtensionLanguageModelProvider {
 341    type ObservableEntity = ExtensionLlmProviderState;
 342
 343    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 344        Some(self.state.clone())
 345    }
 346
 347    fn subscribe<T: 'static>(
 348        &self,
 349        cx: &mut Context<T>,
 350        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
 351    ) -> Option<Subscription> {
 352        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
 353    }
 354}
 355
 356/// Configuration view for extension-based LLM providers.
 357struct ExtensionProviderConfigurationView {
 358    credential_key: String,
 359    extension: WasmExtension,
 360    extension_provider_id: String,
 361    full_provider_id: String,
 362    auth_config: Option<LanguageModelAuthConfig>,
 363    state: Entity<ExtensionLlmProviderState>,
 364    settings_markdown: Option<Entity<Markdown>>,
 365    api_key_editor: Entity<InputField>,
 366    loading_settings: bool,
 367    loading_credentials: bool,
 368    oauth_in_progress: bool,
 369    oauth_error: Option<String>,
 370    icon_path: Option<SharedString>,
 371    target_agent: ConfigurationViewTargetAgent,
 372    _subscriptions: Vec<Subscription>,
 373}
 374
 375impl ExtensionProviderConfigurationView {
 376    fn new(
 377        credential_key: String,
 378        extension: WasmExtension,
 379        extension_provider_id: String,
 380        full_provider_id: String,
 381        auth_config: Option<LanguageModelAuthConfig>,
 382        state: Entity<ExtensionLlmProviderState>,
 383        icon_path: Option<SharedString>,
 384        target_agent: ConfigurationViewTargetAgent,
 385        window: &mut Window,
 386        cx: &mut Context<Self>,
 387    ) -> Self {
 388        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
 389            cx.notify();
 390        });
 391
 392        let credential_label = auth_config
 393            .as_ref()
 394            .and_then(|c| c.credential_label.clone())
 395            .unwrap_or_else(|| "API Key".to_string());
 396
 397        let api_key_editor = cx.new(|cx| {
 398            InputField::new(window, cx, "Enter API key and hit enter").label(credential_label)
 399        });
 400
 401        let mut this = Self {
 402            credential_key,
 403            extension,
 404            extension_provider_id,
 405            full_provider_id,
 406            auth_config,
 407            state,
 408            settings_markdown: None,
 409            api_key_editor,
 410            loading_settings: true,
 411            loading_credentials: true,
 412            oauth_in_progress: false,
 413            oauth_error: None,
 414            icon_path,
 415            target_agent,
 416            _subscriptions: vec![state_subscription],
 417        };
 418
 419        this.load_settings_text(cx);
 420        this.load_credentials(cx);
 421        this
 422    }
 423
 424    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
 425        let extension = self.extension.clone();
 426        let provider_id = self.extension_provider_id.clone();
 427
 428        cx.spawn(async move |this, cx| {
 429            let result = extension
 430                .call({
 431                    let provider_id = provider_id.clone();
 432                    |ext, store| {
 433                        async move {
 434                            ext.call_llm_provider_settings_markdown(store, &provider_id)
 435                                .await
 436                        }
 437                        .boxed()
 438                    }
 439                })
 440                .await;
 441
 442            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
 443
 444            this.update(cx, |this, cx| {
 445                this.loading_settings = false;
 446                if let Some(text) = settings_text {
 447                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
 448                    this.settings_markdown = Some(markdown);
 449                }
 450                cx.notify();
 451            })
 452            .log_err();
 453        })
 454        .detach();
 455    }
 456
 457    fn load_credentials(&mut self, cx: &mut Context<Self>) {
 458        let credential_key = self.credential_key.clone();
 459        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 460        let state = self.state.clone();
 461
 462        // Check if we should use env var (already set in state during provider construction)
 463        let using_env_var = self.state.read(cx).env_var_name_used.is_some();
 464
 465        cx.spawn(async move |this, cx| {
 466            // If using env var, we're already authenticated
 467            if using_env_var {
 468                this.update(cx, |this, cx| {
 469                    this.loading_credentials = false;
 470                    cx.notify();
 471                })
 472                .log_err();
 473                return;
 474            }
 475
 476            let credentials = credentials_provider
 477                .read_credentials(&credential_key, cx)
 478                .await
 479                .log_err()
 480                .flatten();
 481
 482            let has_credentials = credentials.is_some();
 483
 484            // Update authentication state based on stored credentials
 485            cx.update(|cx| {
 486                state.update(cx, |state, cx| {
 487                    state.is_authenticated = has_credentials;
 488                    cx.notify();
 489                });
 490            })
 491            .log_err();
 492
 493            this.update(cx, |this, cx| {
 494                this.loading_credentials = false;
 495                cx.notify();
 496            })
 497            .log_err();
 498        })
 499        .detach();
 500    }
 501
 502    fn toggle_env_var_permission(&mut self, env_var_name: String, cx: &mut Context<Self>) {
 503        let full_provider_id = self.full_provider_id.clone();
 504        let settings_key: Arc<str> = format!("{}:{}", full_provider_id, env_var_name).into();
 505
 506        let state = self.state.clone();
 507        let currently_allowed = self.state.read(cx).allowed_env_vars.contains(&env_var_name);
 508
 509        // Update settings file
 510        settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
 511            move |settings, _| {
 512                let allowed = settings
 513                    .extension
 514                    .allowed_env_var_providers
 515                    .get_or_insert_with(Vec::new);
 516
 517                if currently_allowed {
 518                    allowed.retain(|id| id.as_ref() != settings_key.as_ref());
 519                } else {
 520                    if !allowed
 521                        .iter()
 522                        .any(|id| id.as_ref() == settings_key.as_ref())
 523                    {
 524                        allowed.push(settings_key.clone());
 525                    }
 526                }
 527            }
 528        });
 529
 530        // Update local state
 531        let new_allowed = !currently_allowed;
 532
 533        state.update(cx, |state, cx| {
 534            if new_allowed {
 535                state.allowed_env_vars.insert(env_var_name.clone());
 536                // Check if this env var is set and update env_var_name_used
 537                if let Ok(value) = std::env::var(&env_var_name) {
 538                    if !value.is_empty() && state.env_var_name_used.is_none() {
 539                        state.env_var_name_used = Some(env_var_name.clone());
 540                        state.is_authenticated = true;
 541                    }
 542                }
 543            } else {
 544                state.allowed_env_vars.remove(&env_var_name);
 545                // If this was the env var being used, clear it and find another
 546                if state.env_var_name_used.as_ref() == Some(&env_var_name) {
 547                    state.env_var_name_used = state.allowed_env_vars.iter().find_map(|var| {
 548                        if let Ok(value) = std::env::var(var) {
 549                            if !value.is_empty() {
 550                                return Some(var.clone());
 551                            }
 552                        }
 553                        None
 554                    });
 555                    if state.env_var_name_used.is_none() {
 556                        // No env var auth available, need to check keychain
 557                        state.is_authenticated = false;
 558                    }
 559                }
 560            }
 561            cx.notify();
 562        });
 563
 564        // If all env vars are being disabled, reload credentials from keychain
 565        if !new_allowed && self.state.read(cx).allowed_env_vars.is_empty() {
 566            self.reload_keychain_credentials(cx);
 567        }
 568
 569        cx.notify();
 570    }
 571
 572    fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
 573        let credential_key = self.credential_key.clone();
 574        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 575        let state = self.state.clone();
 576
 577        cx.spawn(async move |_this, cx| {
 578            let credentials = credentials_provider
 579                .read_credentials(&credential_key, cx)
 580                .await
 581                .log_err()
 582                .flatten();
 583
 584            let has_credentials = credentials.is_some();
 585
 586            cx.update(|cx| {
 587                state.update(cx, |state, cx| {
 588                    state.is_authenticated = has_credentials;
 589                    cx.notify();
 590                });
 591            })
 592            .log_err();
 593        })
 594        .detach();
 595    }
 596
 597    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 598        let api_key = self.api_key_editor.read(cx).text(cx);
 599        if api_key.is_empty() {
 600            return;
 601        }
 602
 603        // Clear the editor
 604        self.api_key_editor
 605            .update(cx, |input, cx| input.clear(window, cx));
 606
 607        let credential_key = self.credential_key.clone();
 608        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 609        let state = self.state.clone();
 610
 611        cx.spawn(async move |_this, cx| {
 612            // Store in system keychain
 613            credentials_provider
 614                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
 615                .await
 616                .log_err();
 617
 618            // Update state to authenticated
 619            cx.update(|cx| {
 620                state.update(cx, |state, cx| {
 621                    state.is_authenticated = true;
 622                    cx.notify();
 623                });
 624            })
 625            .log_err();
 626        })
 627        .detach();
 628    }
 629
 630    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 631        // Clear the editor
 632        self.api_key_editor
 633            .update(cx, |input, cx| input.clear(window, cx));
 634
 635        let credential_key = self.credential_key.clone();
 636        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 637        let state = self.state.clone();
 638
 639        cx.spawn(async move |_this, cx| {
 640            // Delete from system keychain
 641            credentials_provider
 642                .delete_credentials(&credential_key, cx)
 643                .await
 644                .log_err();
 645
 646            // Update state to unauthenticated
 647            cx.update(|cx| {
 648                state.update(cx, |state, cx| {
 649                    state.is_authenticated = false;
 650                    cx.notify();
 651                });
 652            })
 653            .log_err();
 654        })
 655        .detach();
 656    }
 657
 658    fn start_oauth_sign_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 659        if self.oauth_in_progress {
 660            return;
 661        }
 662
 663        self.oauth_in_progress = true;
 664        self.oauth_error = None;
 665        cx.notify();
 666
 667        let extension = self.extension.clone();
 668        let provider_id = self.extension_provider_id.clone();
 669        let state = self.state.clone();
 670        let icon_path = self.icon_path.clone();
 671        let this_handle = cx.weak_entity();
 672        let use_popup_window = self.is_edit_prediction_mode();
 673
 674        // Get current window bounds for positioning popup
 675        let current_window_center = window.bounds().center();
 676
 677        // For workspace modal mode, find the workspace window
 678        let workspace_window = if !use_popup_window {
 679            log::info!("OAuth: Looking for workspace window");
 680            let ws = window.window_handle().downcast::<Workspace>().or_else(|| {
 681                log::info!("OAuth: Current window is not a workspace, searching other windows");
 682                cx.windows()
 683                    .into_iter()
 684                    .find_map(|window_handle| window_handle.downcast::<Workspace>())
 685            });
 686
 687            if ws.is_none() {
 688                log::error!("OAuth: Could not find any workspace window");
 689                self.oauth_in_progress = false;
 690                self.oauth_error =
 691                    Some("Could not access workspace to show sign-in modal".to_string());
 692                cx.notify();
 693                return;
 694            }
 695            ws
 696        } else {
 697            None
 698        };
 699
 700        log::info!(
 701            "OAuth: Using {} mode",
 702            if use_popup_window {
 703                "popup window"
 704            } else {
 705                "workspace modal"
 706            }
 707        );
 708        let state = state.downgrade();
 709        cx.spawn(async move |_this, cx| {
 710            // Step 1: Start device flow - get prompt info from extension
 711            let start_result = extension
 712                .call({
 713                    let provider_id = provider_id.clone();
 714                    |ext, store| {
 715                        async move {
 716                            ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
 717                                .await
 718                        }
 719                        .boxed()
 720                    }
 721                })
 722                .await;
 723
 724            log::info!(
 725                "OAuth: Device flow start result: {:?}",
 726                start_result.is_ok()
 727            );
 728            let prompt_info: LlmDeviceFlowPromptInfo = match start_result {
 729                Ok(Ok(Ok(info))) => {
 730                    log::info!(
 731                        "OAuth: Got device flow prompt info, user_code: {}",
 732                        info.user_code
 733                    );
 734                    info
 735                }
 736                Ok(Ok(Err(e))) => {
 737                    log::error!("OAuth: Device flow start failed: {}", e);
 738                    this_handle
 739                        .update(cx, |this, cx| {
 740                            this.oauth_in_progress = false;
 741                            this.oauth_error = Some(e);
 742                            cx.notify();
 743                        })
 744                        .log_err();
 745                    return;
 746                }
 747                Ok(Err(e)) | Err(e) => {
 748                    log::error!("OAuth: Device flow start error: {}", e);
 749                    this_handle
 750                        .update(cx, |this, cx| {
 751                            this.oauth_in_progress = false;
 752                            this.oauth_error = Some(e.to_string());
 753                            cx.notify();
 754                        })
 755                        .log_err();
 756                    return;
 757                }
 758            };
 759
 760            // Step 2: Create state entity and show the modal/window
 761            let modal_config = OAuthDeviceFlowModalConfig {
 762                user_code: prompt_info.user_code,
 763                verification_url: prompt_info.verification_url,
 764                headline: prompt_info.headline,
 765                description: prompt_info.description,
 766                connect_button_label: prompt_info.connect_button_label,
 767                success_headline: prompt_info.success_headline,
 768                success_message: prompt_info.success_message,
 769                icon_path,
 770            };
 771
 772            let flow_state: Option<Entity<OAuthDeviceFlowState>> = if use_popup_window {
 773                // Open a popup window like Copilot does
 774                log::info!("OAuth: Opening popup window");
 775                cx.update(|cx| {
 776                    let height = px(450.);
 777                    let width = px(350.);
 778                    let window_bounds = WindowBounds::Windowed(gpui::bounds(
 779                        current_window_center - point(height / 2.0, width / 2.0),
 780                        gpui::size(height, width),
 781                    ));
 782
 783                    let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config.clone()));
 784                    let flow_state_for_window = flow_state.clone();
 785
 786                    cx.open_window(
 787                        WindowOptions {
 788                            kind: gpui::WindowKind::PopUp,
 789                            window_bounds: Some(window_bounds),
 790                            is_resizable: false,
 791                            is_movable: true,
 792                            titlebar: Some(gpui::TitlebarOptions {
 793                                appears_transparent: true,
 794                                ..Default::default()
 795                            }),
 796                            ..Default::default()
 797                        },
 798                        |window, cx| {
 799                            cx.new(|cx| {
 800                                OAuthCodeVerificationWindow::new(
 801                                    modal_config,
 802                                    flow_state_for_window,
 803                                    window,
 804                                    cx,
 805                                )
 806                            })
 807                        },
 808                    )
 809                    .log_err();
 810
 811                    Some(flow_state)
 812                })
 813                .ok()
 814                .flatten()
 815            } else {
 816                // Use workspace modal
 817                log::info!("OAuth: Attempting to show modal in workspace window");
 818                workspace_window.as_ref().and_then(|ws| {
 819                    ws.update(cx, |workspace, window, cx| {
 820                        log::info!("OAuth: Inside workspace.update, creating modal");
 821                        window.activate_window();
 822                        let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config));
 823                        let flow_state_clone = flow_state.clone();
 824                        workspace.toggle_modal(window, cx, |_window, cx| {
 825                            log::info!("OAuth: Inside toggle_modal callback");
 826                            OAuthDeviceFlowModal::new(flow_state_clone, cx)
 827                        });
 828                        flow_state
 829                    })
 830                    .ok()
 831                })
 832            };
 833
 834            log::info!("OAuth: flow_state created: {:?}", flow_state.is_some());
 835            let Some(flow_state) = flow_state else {
 836                log::error!("OAuth: Failed to show sign-in modal/window");
 837                this_handle
 838                    .update(cx, |this, cx| {
 839                        this.oauth_in_progress = false;
 840                        this.oauth_error = Some("Failed to show sign-in modal".to_string());
 841                        cx.notify();
 842                    })
 843                    .log_err();
 844                return;
 845            };
 846            log::info!("OAuth: Modal/window shown successfully, starting poll");
 847
 848            // Step 3: Poll for authentication completion
 849            let poll_result = extension
 850                .call({
 851                    let provider_id = provider_id.clone();
 852                    |ext, store| {
 853                        async move {
 854                            ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
 855                                .await
 856                        }
 857                        .boxed()
 858                    }
 859                })
 860                .await;
 861
 862            match poll_result {
 863                Ok(Ok(Ok(()))) => {
 864                    // After successful auth, refresh the models list
 865                    let models_result = extension
 866                        .call({
 867                            let provider_id = provider_id.clone();
 868                            |ext, store| {
 869                                async move {
 870                                    ext.call_llm_provider_models(store, &provider_id).await
 871                                }
 872                                .boxed()
 873                            }
 874                        })
 875                        .await;
 876
 877                    let new_models: Vec<LlmModelInfo> = match models_result {
 878                        Ok(Ok(Ok(models))) => models,
 879                        _ => Vec::new(),
 880                    };
 881
 882                    state
 883                        .update(cx, |state, cx| {
 884                            state.is_authenticated = true;
 885                            state.available_models = new_models;
 886                            cx.notify();
 887                        })
 888                        .log_err();
 889
 890                    // Update flow state to show success
 891                    flow_state
 892                        .update(cx, |state, cx| {
 893                            state.set_status(OAuthDeviceFlowStatus::Authorized, cx);
 894                        })
 895                        .log_err();
 896                }
 897                Ok(Ok(Err(e))) => {
 898                    log::error!("Device flow poll failed: {}", e);
 899                    flow_state
 900                        .update(cx, |state, cx| {
 901                            state.set_status(OAuthDeviceFlowStatus::Failed(e.clone()), cx);
 902                        })
 903                        .log_err();
 904                    this_handle
 905                        .update(cx, |this, cx| {
 906                            this.oauth_error = Some(e);
 907                            cx.notify();
 908                        })
 909                        .log_err();
 910                }
 911                Ok(Err(e)) | Err(e) => {
 912                    log::error!("Device flow poll error: {}", e);
 913                    let error_string = e.to_string();
 914                    flow_state
 915                        .update(cx, |state, cx| {
 916                            state.set_status(
 917                                OAuthDeviceFlowStatus::Failed(error_string.clone()),
 918                                cx,
 919                            );
 920                        })
 921                        .log_err();
 922                    this_handle
 923                        .update(cx, |this, cx| {
 924                            this.oauth_error = Some(error_string);
 925                            cx.notify();
 926                        })
 927                        .log_err();
 928                }
 929            };
 930
 931            this_handle
 932                .update(cx, |this, cx| {
 933                    this.oauth_in_progress = false;
 934                    cx.notify();
 935                })
 936                .log_err();
 937        })
 938        .detach();
 939    }
 940
 941    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
 942        self.state.read(cx).is_authenticated
 943    }
 944
 945    fn has_oauth_config(&self) -> bool {
 946        self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
 947    }
 948
 949    fn oauth_config(&self) -> Option<&OAuthConfig> {
 950        self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
 951    }
 952
 953    fn has_api_key_config(&self) -> bool {
 954        // API key is available if there's a credential_label or no oauth-only config
 955        self.auth_config
 956            .as_ref()
 957            .map(|c| c.credential_label.is_some() || c.oauth.is_none())
 958            .unwrap_or(true)
 959    }
 960
 961    fn is_edit_prediction_mode(&self) -> bool {
 962        self.target_agent == ConfigurationViewTargetAgent::EditPrediction
 963    }
 964
 965    fn render_for_edit_prediction(
 966        &mut self,
 967        _window: &mut Window,
 968        cx: &mut Context<Self>,
 969    ) -> impl IntoElement {
 970        let is_loading = self.loading_settings || self.loading_credentials;
 971        let is_authenticated = self.is_authenticated(cx);
 972        let has_oauth = self.has_oauth_config();
 973
 974        // Helper to create the horizontal container layout matching Copilot
 975        let container = |description: SharedString, action: AnyElement| {
 976            h_flex()
 977                .pt_2p5()
 978                .w_full()
 979                .justify_between()
 980                .child(
 981                    v_flex()
 982                        .w_full()
 983                        .max_w_1_2()
 984                        .child(Label::new("Authenticate To Use"))
 985                        .child(
 986                            Label::new(description)
 987                                .color(Color::Muted)
 988                                .size(LabelSize::Small),
 989                        ),
 990                )
 991                .child(action)
 992        };
 993
 994        // Get the description from OAuth config or use a default
 995        let oauth_config = self.oauth_config();
 996        let description: SharedString = oauth_config
 997            .and_then(|c| c.sign_in_description.clone())
 998            .unwrap_or_else(|| "Sign in to authenticate with this provider.".to_string())
 999            .into();
1000
1001        if is_loading {
1002            return container(
1003                description,
1004                Button::new("loading", "Loading...")
1005                    .style(ButtonStyle::Outlined)
1006                    .disabled(true)
1007                    .into_any_element(),
1008            )
1009            .into_any_element();
1010        }
1011
1012        // If authenticated, show the configured card
1013        if is_authenticated {
1014            let (status_label, button_label) = if has_oauth {
1015                ("Authorized", "Sign Out")
1016            } else {
1017                ("API key configured", "Reset Key")
1018            };
1019
1020            return ConfiguredApiCard::new(status_label)
1021                .button_label(button_label)
1022                .on_click(cx.listener(|this, _, window, cx| {
1023                    this.reset_api_key(window, cx);
1024                }))
1025                .into_any_element();
1026        }
1027
1028        // Not authenticated - show sign in button
1029        if has_oauth {
1030            let button_label = oauth_config
1031                .and_then(|c| c.sign_in_button_label.clone())
1032                .unwrap_or_else(|| "Sign In".to_string());
1033            let button_icon = oauth_config
1034                .and_then(|c| c.sign_in_button_icon.as_ref())
1035                .and_then(|icon_name| match icon_name.as_str() {
1036                    "github" => Some(ui::IconName::Github),
1037                    _ => None,
1038                });
1039
1040            let oauth_in_progress = self.oauth_in_progress;
1041
1042            let mut button = Button::new("oauth-sign-in", button_label)
1043                .size(ButtonSize::Medium)
1044                .style(ButtonStyle::Outlined)
1045                .disabled(oauth_in_progress)
1046                .on_click(cx.listener(|this, _, window, cx| {
1047                    this.start_oauth_sign_in(window, cx);
1048                }));
1049
1050            if let Some(icon) = button_icon {
1051                button = button
1052                    .icon(icon)
1053                    .icon_position(ui::IconPosition::Start)
1054                    .icon_size(ui::IconSize::Small)
1055                    .icon_color(Color::Muted);
1056            }
1057
1058            return container(description, button.into_any_element()).into_any_element();
1059        }
1060
1061        // Fallback for API key only providers - show a simple message
1062        container(
1063            description,
1064            Button::new("configure", "Configure")
1065                .size(ButtonSize::Medium)
1066                .style(ButtonStyle::Outlined)
1067                .disabled(true)
1068                .into_any_element(),
1069        )
1070        .into_any_element()
1071    }
1072}
1073
1074impl Render for ExtensionProviderConfigurationView {
1075    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1076        if self.is_edit_prediction_mode() {
1077            return self
1078                .render_for_edit_prediction(window, cx)
1079                .into_any_element();
1080        }
1081
1082        let is_loading = self.loading_settings || self.loading_credentials;
1083        let is_authenticated = self.is_authenticated(cx);
1084        let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone();
1085        let env_var_name_used = self.state.read(cx).env_var_name_used.clone();
1086        let has_oauth = self.has_oauth_config();
1087        let has_api_key = self.has_api_key_config();
1088
1089        if is_loading {
1090            return h_flex()
1091                .gap_2()
1092                .child(
1093                    h_flex()
1094                        .w_2()
1095                        .child(SpinnerLabel::sand().size(LabelSize::Small)),
1096                )
1097                .child(LoadingLabel::new("Loading").size(LabelSize::Small))
1098                .into_any_element();
1099        }
1100
1101        let mut content = v_flex().size_full().gap_2();
1102
1103        if let Some(markdown) = &self.settings_markdown {
1104            content = content.text_sm().child(MarkdownElement::new(
1105                markdown.clone(),
1106                markdown_styles(window, cx),
1107            ));
1108        }
1109
1110        if let Some(auth_config) = &self.auth_config {
1111            if let Some(env_vars) = &auth_config.env_vars {
1112                for env_var_name in env_vars {
1113                    let is_allowed = allowed_env_vars.contains(env_var_name);
1114                    let checkbox_label =
1115                        format!("Read API key from {} environment variable.", env_var_name);
1116                    let env_var_for_click = env_var_name.clone();
1117
1118                    content = content.child(
1119                        Checkbox::new(
1120                            SharedString::from(format!("env-var-{}", env_var_name)),
1121                            if is_allowed {
1122                                ToggleState::Selected
1123                            } else {
1124                                ToggleState::Unselected
1125                            },
1126                        )
1127                        .label(checkbox_label)
1128                        .on_click(cx.listener(
1129                            move |this, _, _window, cx| {
1130                                this.toggle_env_var_permission(env_var_for_click.clone(), cx);
1131                            },
1132                        )),
1133                    );
1134                }
1135
1136                if let Some(used_var) = &env_var_name_used {
1137                    content = content.child(
1138                        ConfiguredApiCard::new(format!(
1139                            "API key set in {} environment variable",
1140                            used_var
1141                        ))
1142                        .tooltip_label(format!(
1143                            "To reset this API key, unset the {} environment variable.",
1144                            used_var
1145                        ))
1146                        .disabled(true),
1147                    );
1148
1149                    return content.into_any_element();
1150                }
1151            }
1152        }
1153
1154        if is_authenticated && env_var_name_used.is_none() {
1155            let (status_label, button_label) = if has_oauth && !has_api_key {
1156                ("Signed in", "Sign Out")
1157            } else {
1158                ("API key configured", "Reset Key")
1159            };
1160
1161            content = content.child(
1162                ConfiguredApiCard::new(status_label)
1163                    .button_label(button_label)
1164                    .on_click(cx.listener(|this, _, window, cx| {
1165                        this.reset_api_key(window, cx);
1166                    })),
1167            );
1168
1169            return content.into_any_element();
1170        }
1171
1172        // Not authenticated - show available auth options
1173        if env_var_name_used.is_none() {
1174            // Render OAuth sign-in button if configured
1175            if has_oauth {
1176                let oauth_config = self.oauth_config();
1177                let button_label = oauth_config
1178                    .and_then(|c| c.sign_in_button_label.clone())
1179                    .unwrap_or_else(|| "Sign In".to_string());
1180                let button_icon = oauth_config
1181                    .and_then(|c| c.sign_in_button_icon.as_ref())
1182                    .and_then(|icon_name| match icon_name.as_str() {
1183                        "github" => Some(ui::IconName::Github),
1184                        _ => None,
1185                    });
1186
1187                let oauth_in_progress = self.oauth_in_progress;
1188
1189                let oauth_error = self.oauth_error.clone();
1190
1191                let mut button = Button::new("oauth-sign-in", button_label)
1192                    .full_width()
1193                    .style(ButtonStyle::Outlined)
1194                    .disabled(oauth_in_progress)
1195                    .on_click(cx.listener(|this, _, window, cx| {
1196                        this.start_oauth_sign_in(window, cx);
1197                    }));
1198                if let Some(icon) = button_icon {
1199                    button = button
1200                        .icon(icon)
1201                        .icon_position(IconPosition::Start)
1202                        .icon_size(IconSize::Small)
1203                        .icon_color(Color::Muted);
1204                }
1205
1206                content = content.child(
1207                    v_flex()
1208                        .gap_2()
1209                        .child(button)
1210                        .when(oauth_in_progress, |this| {
1211                            this.child(
1212                                Label::new("Sign-in in progress...")
1213                                    .size(LabelSize::Small)
1214                                    .color(Color::Muted),
1215                            )
1216                        })
1217                        .when_some(oauth_error, |this, error| {
1218                            this.child(
1219                                v_flex()
1220                                    .gap_1()
1221                                    .child(
1222                                        h_flex()
1223                                            .gap_2()
1224                                            .child(
1225                                                Icon::new(IconName::Warning)
1226                                                    .color(Color::Error)
1227                                                    .size(IconSize::Small),
1228                                            )
1229                                            .child(
1230                                                Label::new("Authentication failed")
1231                                                    .color(Color::Error)
1232                                                    .size(LabelSize::Small),
1233                                            ),
1234                                    )
1235                                    .child(
1236                                        div().pl_6().child(
1237                                            Label::new(error)
1238                                                .color(Color::Error)
1239                                                .size(LabelSize::Small),
1240                                        ),
1241                                    ),
1242                            )
1243                        }),
1244                );
1245            }
1246
1247            // Render API key input if configured (and we have both options, show a separator)
1248            if has_api_key {
1249                if has_oauth {
1250                    content = content.child(
1251                        h_flex()
1252                            .gap_2()
1253                            .items_center()
1254                            .child(div().h_px().flex_1().bg(cx.theme().colors().border_variant))
1255                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
1256                            .child(div().h_px().flex_1().bg(cx.theme().colors().border_variant)),
1257                    );
1258                }
1259
1260                content = content.child(
1261                    div()
1262                        .on_action(cx.listener(Self::save_api_key))
1263                        .child(self.api_key_editor.clone()),
1264                );
1265            }
1266        }
1267
1268        if self.extension_provider_id == "openai" {
1269            content = content.child(
1270                h_flex()
1271                    .gap_1()
1272                    .child(
1273                        Icon::new(IconName::Info)
1274                            .size(IconSize::XSmall)
1275                            .color(Color::Muted),
1276                    )
1277                    .child(
1278                        Label::new("Zed also supports OpenAI-compatible models.")
1279                            .size(LabelSize::Small)
1280                            .color(Color::Muted),
1281                    )
1282                    .child(
1283                        ButtonLink::new(
1284                            "Learn More",
1285                            "https://zed.dev/docs/configuring-llm-providers#openai-compatible-providers",
1286                        )
1287                        .label_size(LabelSize::Small),
1288                    ),
1289            );
1290        }
1291
1292        content.into_any_element()
1293    }
1294}
1295
1296impl Focusable for ExtensionProviderConfigurationView {
1297    fn focus_handle(&self, cx: &App) -> FocusHandle {
1298        self.api_key_editor.read(cx).focus_handle(cx)
1299    }
1300}
1301
/// A popup window for OAuth device flow, similar to CopilotCodeVerification.
/// This is used when in edit prediction mode to avoid moving the settings panel behind.
pub struct OAuthCodeVerificationWindow {
    /// Texts, user code, verification URL and optional icon shown by the popup.
    config: OAuthDeviceFlowModalConfig,
    /// Copy of the observed `OAuthDeviceFlowState`'s status; selects which body is rendered.
    status: OAuthDeviceFlowStatus,
    /// Set once the user clicks the connect button; switches that button's label to a
    /// "waiting" message.
    connect_clicked: bool,
    /// Focus handle tracked by the popup's root element.
    focus_handle: FocusHandle,
    /// Holds the observation of the flow state; dropping it ends the observation.
    _subscription: Option<Subscription>,
}
1311
1312impl Focusable for OAuthCodeVerificationWindow {
1313    fn focus_handle(&self, _: &App) -> FocusHandle {
1314        self.focus_handle.clone()
1315    }
1316}
1317
// Emitting `DismissEvent` from this view removes its window (subscription set up in `new`).
impl EventEmitter<DismissEvent> for OAuthCodeVerificationWindow {}
1319
1320impl OAuthCodeVerificationWindow {
1321    pub fn new(
1322        config: OAuthDeviceFlowModalConfig,
1323        state: Entity<OAuthDeviceFlowState>,
1324        window: &mut Window,
1325        cx: &mut Context<Self>,
1326    ) -> Self {
1327        window.on_window_should_close(cx, |window, cx| {
1328            if let Some(this) = window.root::<OAuthCodeVerificationWindow>().flatten() {
1329                this.update(cx, |_, cx| {
1330                    cx.emit(DismissEvent);
1331                });
1332            }
1333            true
1334        });
1335        cx.subscribe_in(
1336            &cx.entity(),
1337            window,
1338            |_, _, _: &DismissEvent, window, _cx| {
1339                window.remove_window();
1340            },
1341        )
1342        .detach();
1343
1344        let subscription = cx.observe(&state, |this, state, cx| {
1345            let status = state.read(cx).status.clone();
1346            this.status = status;
1347            cx.notify();
1348        });
1349
1350        Self {
1351            config,
1352            status: state.read(cx).status.clone(),
1353            connect_clicked: false,
1354            focus_handle: cx.focus_handle(),
1355            _subscription: Some(subscription),
1356        }
1357    }
1358
1359    fn render_icon(&self, cx: &mut Context<Self>) -> impl IntoElement {
1360        let icon_color = Color::Custom(cx.theme().colors().icon);
1361        let icon_size = rems(2.5);
1362        let plus_size = rems(0.875);
1363        let plus_color = cx.theme().colors().icon.opacity(0.5);
1364
1365        if let Some(icon_path) = &self.config.icon_path {
1366            h_flex()
1367                .gap_2()
1368                .items_center()
1369                .child(
1370                    Icon::from_external_svg(icon_path.clone())
1371                        .size(IconSize::Custom(icon_size))
1372                        .color(icon_color),
1373                )
1374                .child(
1375                    gpui::svg()
1376                        .size(plus_size)
1377                        .path("icons/plus.svg")
1378                        .text_color(plus_color),
1379                )
1380                .child(Vector::new(VectorName::ZedLogo, icon_size, icon_size).color(icon_color))
1381                .into_any_element()
1382        } else {
1383            Vector::new(VectorName::ZedLogo, icon_size, icon_size)
1384                .color(icon_color)
1385                .into_any_element()
1386        }
1387    }
1388
1389    fn render_device_code(&self, cx: &mut Context<Self>) -> impl IntoElement {
1390        let user_code = self.config.user_code.clone();
1391        let copied = cx
1392            .read_from_clipboard()
1393            .map(|item| item.text().as_ref() == Some(&user_code))
1394            .unwrap_or(false);
1395        let user_code_for_click = user_code.clone();
1396
1397        ButtonLike::new("copy-button")
1398            .full_width()
1399            .style(ButtonStyle::Tinted(ui::TintColor::Accent))
1400            .size(ButtonSize::Medium)
1401            .child(
1402                h_flex()
1403                    .w_full()
1404                    .p_1()
1405                    .justify_between()
1406                    .child(Label::new(user_code))
1407                    .child(Label::new(if copied { "Copied!" } else { "Copy" })),
1408            )
1409            .on_click(move |_, window, cx| {
1410                cx.write_to_clipboard(ClipboardItem::new_string(user_code_for_click.clone()));
1411                window.refresh();
1412            })
1413    }
1414
1415    fn render_prompting_modal(&self, cx: &mut Context<Self>) -> impl IntoElement {
1416        let connect_button_label: String = if self.connect_clicked {
1417            "Waiting for connection…".to_string()
1418        } else {
1419            self.config.connect_button_label.clone()
1420        };
1421        let verification_url = self.config.verification_url.clone();
1422
1423        v_flex()
1424            .flex_1()
1425            .gap_2p5()
1426            .items_center()
1427            .text_center()
1428            .child(Headline::new(self.config.headline.clone()).size(HeadlineSize::Large))
1429            .child(Label::new(self.config.description.clone()).color(Color::Muted))
1430            .child(self.render_device_code(cx))
1431            .child(
1432                Label::new("Paste this code after clicking the button below.").color(Color::Muted),
1433            )
1434            .child(
1435                v_flex()
1436                    .w_full()
1437                    .gap_1()
1438                    .child(
1439                        Button::new("connect-button", connect_button_label)
1440                            .full_width()
1441                            .style(ButtonStyle::Outlined)
1442                            .size(ButtonSize::Medium)
1443                            .on_click(cx.listener(move |this, _, _window, cx| {
1444                                cx.open_url(&verification_url);
1445                                this.connect_clicked = true;
1446                            })),
1447                    )
1448                    .child(
1449                        Button::new("cancel-button", "Cancel")
1450                            .full_width()
1451                            .size(ButtonSize::Medium)
1452                            .on_click(cx.listener(|_, _, _, cx| {
1453                                cx.emit(DismissEvent);
1454                            })),
1455                    ),
1456            )
1457    }
1458
1459    fn render_authorized_modal(&self, cx: &mut Context<Self>) -> impl IntoElement {
1460        v_flex()
1461            .gap_2()
1462            .text_center()
1463            .justify_center()
1464            .child(Headline::new(self.config.success_headline.clone()).size(HeadlineSize::Large))
1465            .child(Label::new(self.config.success_message.clone()).color(Color::Muted))
1466            .child(
1467                Button::new("done-button", "Done")
1468                    .full_width()
1469                    .style(ButtonStyle::Outlined)
1470                    .size(ButtonSize::Medium)
1471                    .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
1472            )
1473    }
1474
1475    fn render_failed_modal(&self, error: &str, cx: &mut Context<Self>) -> impl IntoElement {
1476        v_flex()
1477            .gap_2()
1478            .text_center()
1479            .justify_center()
1480            .child(Headline::new("Authorization Failed").size(HeadlineSize::Large))
1481            .child(Label::new(error.to_string()).color(Color::Error))
1482            .child(
1483                Button::new("close-button", "Close")
1484                    .full_width()
1485                    .size(ButtonSize::Medium)
1486                    .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
1487            )
1488    }
1489}
1490
1491impl Render for OAuthCodeVerificationWindow {
1492    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1493        let prompt = match &self.status {
1494            OAuthDeviceFlowStatus::Prompting | OAuthDeviceFlowStatus::WaitingForAuthorization => {
1495                self.render_prompting_modal(cx).into_any_element()
1496            }
1497            OAuthDeviceFlowStatus::Authorized => {
1498                self.render_authorized_modal(cx).into_any_element()
1499            }
1500            OAuthDeviceFlowStatus::Failed(error) => {
1501                self.render_failed_modal(error, cx).into_any_element()
1502            }
1503        };
1504
1505        v_flex()
1506            .id("oauth_code_verification")
1507            .track_focus(&self.focus_handle(cx))
1508            .size_full()
1509            .px_4()
1510            .py_8()
1511            .gap_2()
1512            .items_center()
1513            .justify_center()
1514            .elevation_3(cx)
1515            .on_action(cx.listener(|_, _: &menu::Cancel, _, cx| {
1516                cx.emit(DismissEvent);
1517            }))
1518            .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, cx| {
1519                window.focus(&this.focus_handle, cx);
1520            }))
1521            .child(self.render_icon(cx))
1522            .child(prompt)
1523    }
1524}
1525
1526fn markdown_styles(window: &Window, cx: &App) -> MarkdownStyle {
1527    let settings = ThemeSettings::get_global(cx);
1528    let colors = cx.theme().colors();
1529
1530    let mut text_style = window.text_style();
1531    text_style.refine(&TextStyleRefinement {
1532        font_family: Some(settings.ui_font.family.clone()),
1533        font_fallbacks: settings.ui_font.fallbacks.clone(),
1534        font_features: Some(settings.ui_font.features.clone()),
1535        font_size: Some(settings.ui_font_size(cx).into()),
1536        line_height: Some(relative(1.5)),
1537        color: Some(colors.text_muted),
1538        ..Default::default()
1539    });
1540
1541    MarkdownStyle {
1542        base_text_style: text_style.clone(),
1543        syntax: cx.theme().syntax().clone(),
1544        selection_background_color: colors.element_selection_background,
1545        heading_level_styles: Some(HeadingLevelStyles {
1546            h1: Some(TextStyleRefinement {
1547                font_size: Some(rems(1.15).into()),
1548                ..Default::default()
1549            }),
1550            h2: Some(TextStyleRefinement {
1551                font_size: Some(rems(1.1).into()),
1552                ..Default::default()
1553            }),
1554            h3: Some(TextStyleRefinement {
1555                font_size: Some(rems(1.05).into()),
1556                ..Default::default()
1557            }),
1558            h4: Some(TextStyleRefinement {
1559                font_size: Some(rems(1.).into()),
1560                ..Default::default()
1561            }),
1562            h5: Some(TextStyleRefinement {
1563                font_size: Some(rems(0.95).into()),
1564                ..Default::default()
1565            }),
1566            h6: Some(TextStyleRefinement {
1567                font_size: Some(rems(0.875).into()),
1568                ..Default::default()
1569            }),
1570        }),
1571        inline_code: TextStyleRefinement {
1572            font_family: Some(settings.buffer_font.family.clone()),
1573            font_fallbacks: settings.buffer_font.fallbacks.clone(),
1574            font_features: Some(settings.buffer_font.features.clone()),
1575            font_size: Some(settings.buffer_font_size(cx).into()),
1576            background_color: Some(colors.editor_foreground.opacity(0.08)),
1577            ..Default::default()
1578        },
1579        link: TextStyleRefinement {
1580            background_color: Some(colors.editor_foreground.opacity(0.025)),
1581            color: Some(colors.text_accent),
1582            underline: Some(UnderlineStyle {
1583                color: Some(colors.text_accent.opacity(0.5)),
1584                thickness: px(1.),
1585                ..Default::default()
1586            }),
1587            ..Default::default()
1588        },
1589        ..Default::default()
1590    }
1591}
1592
/// An extension-based language model.
///
/// Pairs one model reported by an extension with the provider it belongs to.
/// Token counting and completion requests are forwarded to the extension through
/// `WasmExtension::call`.
pub struct ExtensionLanguageModel {
    /// Handle to the extension implementing this model; cloned into each request future.
    extension: WasmExtension,
    /// Model metadata and capabilities as reported by the extension.
    model_info: LlmModelInfo,
    /// Host-side provider id, returned from `LanguageModel::provider_id`.
    provider_id: LanguageModelProviderId,
    /// Host-side provider display name, returned from `LanguageModel::provider_name`.
    provider_name: LanguageModelProviderName,
    /// Provider metadata from the extension; its `id` is passed to the extension's LLM
    /// calls and used as the prefix of the telemetry id.
    provider_info: LlmProviderInfo,
    /// Limiter that `stream_completion` runs its request under.
    request_limiter: RateLimiter,
    /// Prompt-cache configuration returned from `LanguageModel::cache_configuration`, if any.
    cache_config: Option<LanguageModelCacheConfiguration>,
}
1603
1604impl LanguageModel for ExtensionLanguageModel {
1605    fn id(&self) -> LanguageModelId {
1606        LanguageModelId::from(self.model_info.id.clone())
1607    }
1608
1609    fn name(&self) -> LanguageModelName {
1610        LanguageModelName::from(self.model_info.name.clone())
1611    }
1612
1613    fn provider_id(&self) -> LanguageModelProviderId {
1614        self.provider_id.clone()
1615    }
1616
1617    fn provider_name(&self) -> LanguageModelProviderName {
1618        self.provider_name.clone()
1619    }
1620
1621    fn telemetry_id(&self) -> String {
1622        format!("{}/{}", self.provider_info.id, self.model_info.id)
1623    }
1624
1625    fn supports_images(&self) -> bool {
1626        self.model_info.capabilities.supports_images
1627    }
1628
1629    fn supports_tools(&self) -> bool {
1630        self.model_info.capabilities.supports_tools
1631    }
1632
1633    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1634        match choice {
1635            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1636            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1637            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1638        }
1639    }
1640
1641    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1642        match self.model_info.capabilities.tool_input_format {
1643            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1644            LlmToolInputFormat::JsonSchemaSubset => LanguageModelToolSchemaFormat::JsonSchemaSubset,
1645            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1646        }
1647    }
1648
1649    fn max_token_count(&self) -> u64 {
1650        self.model_info.max_token_count
1651    }
1652
1653    fn max_output_tokens(&self) -> Option<u64> {
1654        self.model_info.max_output_tokens
1655    }
1656
1657    fn count_tokens(
1658        &self,
1659        request: LanguageModelRequest,
1660        cx: &App,
1661    ) -> BoxFuture<'static, Result<u64>> {
1662        let extension = self.extension.clone();
1663        let provider_id = self.provider_info.id.clone();
1664        let model_id = self.model_info.id.clone();
1665
1666        let wit_request = convert_request_to_wit(request);
1667
1668        cx.background_spawn(async move {
1669            extension
1670                .call({
1671                    let provider_id = provider_id.clone();
1672                    let model_id = model_id.clone();
1673                    let wit_request = wit_request.clone();
1674                    |ext, store| {
1675                        async move {
1676                            let count = ext
1677                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1678                                .await?
1679                                .map_err(|e| anyhow!("{}", e))?;
1680                            Ok(count)
1681                        }
1682                        .boxed()
1683                    }
1684                })
1685                .await?
1686        })
1687        .boxed()
1688    }
1689
1690    fn stream_completion(
1691        &self,
1692        request: LanguageModelRequest,
1693        _cx: &AsyncApp,
1694    ) -> BoxFuture<
1695        'static,
1696        Result<
1697            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1698            LanguageModelCompletionError,
1699        >,
1700    > {
1701        let extension = self.extension.clone();
1702        let provider_id = self.provider_info.id.clone();
1703        let model_id = self.model_info.id.clone();
1704
1705        let wit_request = convert_request_to_wit(request);
1706
1707        let future = self.request_limiter.stream(async move {
1708            // Start the stream
1709            let stream_id_result = extension
1710                .call({
1711                    let provider_id = provider_id.clone();
1712                    let model_id = model_id.clone();
1713                    let wit_request = wit_request.clone();
1714                    |ext, store| {
1715                        async move {
1716                            let id = ext
1717                                .call_llm_stream_completion_start(
1718                                    store,
1719                                    &provider_id,
1720                                    &model_id,
1721                                    &wit_request,
1722                                )
1723                                .await?
1724                                .map_err(|e| anyhow!("{}", e))?;
1725                            Ok(id)
1726                        }
1727                        .boxed()
1728                    }
1729                })
1730                .await;
1731
1732            let stream_id = stream_id_result
1733                .map_err(LanguageModelCompletionError::Other)?
1734                .map_err(LanguageModelCompletionError::Other)?;
1735
1736            // Create a stream that polls for events
1737            let stream = futures::stream::unfold(
1738                (extension.clone(), stream_id, false),
1739                move |(extension, stream_id, done)| async move {
1740                    if done {
1741                        return None;
1742                    }
1743
1744                    let result = extension
1745                        .call({
1746                            let stream_id = stream_id.clone();
1747                            |ext, store| {
1748                                async move {
1749                                    let event = ext
1750                                        .call_llm_stream_completion_next(store, &stream_id)
1751                                        .await?
1752                                        .map_err(|e| anyhow!("{}", e))?;
1753                                    Ok(event)
1754                                }
1755                                .boxed()
1756                            }
1757                        })
1758                        .await
1759                        .and_then(|inner| inner);
1760
1761                    match result {
1762                        Ok(Some(event)) => {
1763                            let converted = convert_completion_event(event);
1764                            let is_done =
1765                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1766                            Some((converted, (extension, stream_id, is_done)))
1767                        }
1768                        Ok(None) => {
1769                            // Stream complete, close it
1770                            let _ = extension
1771                                .call({
1772                                    let stream_id = stream_id.clone();
1773                                    |ext, store| {
1774                                        async move {
1775                                            ext.call_llm_stream_completion_close(store, &stream_id)
1776                                                .await?;
1777                                            Ok::<(), anyhow::Error>(())
1778                                        }
1779                                        .boxed()
1780                                    }
1781                                })
1782                                .await;
1783                            None
1784                        }
1785                        Err(e) => Some((
1786                            Err(LanguageModelCompletionError::Other(e)),
1787                            (extension, stream_id, true),
1788                        )),
1789                    }
1790                },
1791            );
1792
1793            Ok(stream.boxed())
1794        });
1795
1796        async move { Ok(future.await?.boxed()) }.boxed()
1797    }
1798
1799    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1800        self.cache_config.clone()
1801    }
1802}
1803
1804fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1805    use language_model::{MessageContent, Role};
1806
1807    let messages: Vec<LlmRequestMessage> = request
1808        .messages
1809        .into_iter()
1810        .map(|msg| {
1811            let role = match msg.role {
1812                Role::User => LlmMessageRole::User,
1813                Role::Assistant => LlmMessageRole::Assistant,
1814                Role::System => LlmMessageRole::System,
1815            };
1816
1817            let content: Vec<LlmMessageContent> = msg
1818                .content
1819                .into_iter()
1820                .map(|c| match c {
1821                    MessageContent::Text(text) => LlmMessageContent::Text(text),
1822                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1823                        source: image.source.to_string(),
1824                        width: image.size.map(|s| s.width.0 as u32),
1825                        height: image.size.map(|s| s.height.0 as u32),
1826                    }),
1827                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1828                        id: tool_use.id.to_string(),
1829                        name: tool_use.name.to_string(),
1830                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1831                        is_input_complete: tool_use.is_input_complete,
1832                        thought_signature: tool_use.thought_signature,
1833                    }),
1834                    MessageContent::ToolResult(tool_result) => {
1835                        let content = match tool_result.content {
1836                            language_model::LanguageModelToolResultContent::Text(text) => {
1837                                LlmToolResultContent::Text(text.to_string())
1838                            }
1839                            language_model::LanguageModelToolResultContent::Image(image) => {
1840                                LlmToolResultContent::Image(LlmImageData {
1841                                    source: image.source.to_string(),
1842                                    width: image.size.map(|s| s.width.0 as u32),
1843                                    height: image.size.map(|s| s.height.0 as u32),
1844                                })
1845                            }
1846                        };
1847                        LlmMessageContent::ToolResult(LlmToolResult {
1848                            tool_use_id: tool_result.tool_use_id.to_string(),
1849                            tool_name: tool_result.tool_name.to_string(),
1850                            is_error: tool_result.is_error,
1851                            content,
1852                        })
1853                    }
1854                    MessageContent::Thinking { text, signature } => {
1855                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1856                    }
1857                    MessageContent::RedactedThinking(data) => {
1858                        LlmMessageContent::RedactedThinking(data)
1859                    }
1860                })
1861                .collect();
1862
1863            LlmRequestMessage {
1864                role,
1865                content,
1866                cache: msg.cache,
1867            }
1868        })
1869        .collect();
1870
1871    let tools: Vec<LlmToolDefinition> = request
1872        .tools
1873        .into_iter()
1874        .map(|tool| LlmToolDefinition {
1875            name: tool.name,
1876            description: tool.description,
1877            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1878        })
1879        .collect();
1880
1881    let tool_choice = request.tool_choice.map(|tc| match tc {
1882        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1883        LanguageModelToolChoice::Any => LlmToolChoice::Any,
1884        LanguageModelToolChoice::None => LlmToolChoice::None,
1885    });
1886
1887    LlmCompletionRequest {
1888        messages,
1889        tools,
1890        tool_choice,
1891        stop_sequences: request.stop,
1892        temperature: request.temperature,
1893        thinking_allowed: request.thinking_allowed,
1894        max_tokens: None,
1895    }
1896}
1897
1898fn convert_completion_event(
1899    event: LlmCompletionEvent,
1900) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1901    match event {
1902        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1903            message_id: String::new(),
1904        }),
1905        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1906        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1907            text: thinking.text,
1908            signature: thinking.signature,
1909        }),
1910        LlmCompletionEvent::RedactedThinking(data) => {
1911            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1912        }
1913        LlmCompletionEvent::ToolUse(tool_use) => {
1914            let raw_input = tool_use.input.clone();
1915            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1916            Ok(LanguageModelCompletionEvent::ToolUse(
1917                LanguageModelToolUse {
1918                    id: LanguageModelToolUseId::from(tool_use.id),
1919                    name: tool_use.name.into(),
1920                    raw_input,
1921                    input,
1922                    is_input_complete: tool_use.is_input_complete,
1923                    thought_signature: tool_use.thought_signature,
1924                },
1925            ))
1926        }
1927        LlmCompletionEvent::ToolUseJsonParseError(error) => {
1928            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1929                id: LanguageModelToolUseId::from(error.id),
1930                tool_name: error.tool_name.into(),
1931                raw_input: error.raw_input.into(),
1932                json_parse_error: error.error,
1933            })
1934        }
1935        LlmCompletionEvent::Stop(reason) => {
1936            let stop_reason = match reason {
1937                LlmStopReason::EndTurn => StopReason::EndTurn,
1938                LlmStopReason::MaxTokens => StopReason::MaxTokens,
1939                LlmStopReason::ToolUse => StopReason::ToolUse,
1940                LlmStopReason::Refusal => StopReason::Refusal,
1941            };
1942            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1943        }
1944        LlmCompletionEvent::Usage(usage) => {
1945            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1946                input_tokens: usage.input_tokens,
1947                output_tokens: usage.output_tokens,
1948                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1949                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1950            }))
1951        }
1952        LlmCompletionEvent::ReasoningDetails(json) => {
1953            Ok(LanguageModelCompletionEvent::ReasoningDetails(
1954                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1955            ))
1956        }
1957    }
1958}