llm_provider.rs

   1use crate::ExtensionSettings;
   2use crate::LEGACY_LLM_EXTENSION_IDS;
   3use crate::wasm_host::WasmExtension;
   4use crate::wasm_host::wit::LlmDeviceFlowPromptInfo;
   5use collections::HashSet;
   6
   7use crate::wasm_host::wit::{
   8    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
   9    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
  10    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
  11    LlmToolUse,
  12};
  13use anyhow::{Result, anyhow};
  14use credentials_provider::CredentialsProvider;
  15use extension::{LanguageModelAuthConfig, OAuthConfig};
  16use futures::future::BoxFuture;
  17use futures::stream::BoxStream;
  18use futures::{FutureExt, StreamExt};
  19
  20use gpui::{
  21    AnyView, App, AsyncApp, ClipboardItem, DismissEvent, Entity, EventEmitter, FocusHandle,
  22    Focusable, MouseDownEvent, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window,
  23    WindowBounds, WindowOptions, point, prelude::*, px,
  24};
  25use language_model::tool_schema::LanguageModelToolSchemaFormat;
  26use language_model::{
  27    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
  28    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
  29    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  30    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  31    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
  32};
  33use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
  34use settings::Settings;
  35use std::sync::Arc;
  36use theme::ThemeSettings;
  37use ui::{
  38    ButtonLike, ButtonLink, Checkbox, ConfiguredApiCard, SpinnerLabel, ToggleState, Vector,
  39    VectorName, prelude::*,
  40};
  41use ui_input::InputField;
  42use util::ResultExt as _;
  43use workspace::Workspace;
  44use workspace::oauth_device_flow_modal::{
  45    OAuthDeviceFlowModal, OAuthDeviceFlowModalConfig, OAuthDeviceFlowState, OAuthDeviceFlowStatus,
  46};
  47
  48/// An extension-based language model provider.
  49pub struct ExtensionLanguageModelProvider {
  50    pub extension: WasmExtension,
  51    pub provider_info: LlmProviderInfo,
  52    icon_path: Option<SharedString>,
  53    auth_config: Option<LanguageModelAuthConfig>,
  54    state: Entity<ExtensionLlmProviderState>,
  55}
  56
  57pub struct ExtensionLlmProviderState {
  58    is_authenticated: bool,
  59    available_models: Vec<LlmModelInfo>,
  60    /// Set of env var names that are allowed to be read for this provider.
  61    allowed_env_vars: HashSet<String>,
  62    /// If authenticated via env var, which one was used.
  63    env_var_name_used: Option<String>,
  64}
  65
  66impl EventEmitter<()> for ExtensionLlmProviderState {}
  67
  68impl ExtensionLanguageModelProvider {
  69    pub fn new(
  70        extension: WasmExtension,
  71        provider_info: LlmProviderInfo,
  72        models: Vec<LlmModelInfo>,
  73        is_authenticated: bool,
  74        icon_path: Option<SharedString>,
  75        auth_config: Option<LanguageModelAuthConfig>,
  76        cx: &mut App,
  77    ) -> Self {
  78        let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
  79
  80        // Build set of allowed env vars for this provider
  81        let settings = ExtensionSettings::get_global(cx);
  82        let is_legacy_extension =
  83            LEGACY_LLM_EXTENSION_IDS.contains(&extension.manifest.id.as_ref());
  84
  85        let mut allowed_env_vars = HashSet::default();
  86        if let Some(env_vars) = auth_config.as_ref().and_then(|c| c.env_vars.as_ref()) {
  87            for env_var_name in env_vars {
  88                let key = format!("{}:{}", provider_id_string, env_var_name);
  89                // For legacy extensions, auto-allow if env var is set (migration will persist this)
  90                let env_var_is_set = std::env::var(env_var_name)
  91                    .map(|v| !v.is_empty())
  92                    .unwrap_or(false);
  93                if settings.allowed_env_var_providers.contains(key.as_str())
  94                    || (is_legacy_extension && env_var_is_set)
  95                {
  96                    allowed_env_vars.insert(env_var_name.clone());
  97                }
  98            }
  99        }
 100
 101        // Check if any allowed env var is set
 102        let env_var_name_used = allowed_env_vars.iter().find_map(|env_var_name| {
 103            if let Ok(value) = std::env::var(env_var_name) {
 104                if !value.is_empty() {
 105                    return Some(env_var_name.clone());
 106                }
 107            }
 108            None
 109        });
 110
 111        let is_authenticated = if env_var_name_used.is_some() {
 112            true
 113        } else {
 114            is_authenticated
 115        };
 116
 117        let state = cx.new(|_| ExtensionLlmProviderState {
 118            is_authenticated,
 119            available_models: models,
 120            allowed_env_vars,
 121            env_var_name_used,
 122        });
 123
 124        Self {
 125            extension,
 126            provider_info,
 127            icon_path,
 128            auth_config,
 129            state,
 130        }
 131    }
 132
 133    fn provider_id_string(&self) -> String {
 134        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
 135    }
 136
 137    /// The credential key used for storing the API key in the system keychain.
 138    fn credential_key(&self) -> String {
 139        format!("extension-llm-{}", self.provider_id_string())
 140    }
 141}
 142
 143impl LanguageModelProvider for ExtensionLanguageModelProvider {
 144    fn id(&self) -> LanguageModelProviderId {
 145        LanguageModelProviderId::from(self.provider_id_string())
 146    }
 147
 148    fn name(&self) -> LanguageModelProviderName {
 149        LanguageModelProviderName::from(self.provider_info.name.clone())
 150    }
 151
 152    fn icon(&self) -> ui::IconName {
 153        ui::IconName::ZedAssistant
 154    }
 155
 156    fn icon_path(&self) -> Option<SharedString> {
 157        self.icon_path.clone()
 158    }
 159
 160    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 161        let state = self.state.read(cx);
 162        state
 163            .available_models
 164            .iter()
 165            .find(|m| m.is_default)
 166            .or_else(|| state.available_models.first())
 167            .map(|model_info| {
 168                Arc::new(ExtensionLanguageModel {
 169                    extension: self.extension.clone(),
 170                    model_info: model_info.clone(),
 171                    provider_id: self.id(),
 172                    provider_name: self.name(),
 173                    provider_info: self.provider_info.clone(),
 174                }) as Arc<dyn LanguageModel>
 175            })
 176    }
 177
 178    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 179        let state = self.state.read(cx);
 180        state
 181            .available_models
 182            .iter()
 183            .find(|m| m.is_default_fast)
 184            .map(|model_info| {
 185                Arc::new(ExtensionLanguageModel {
 186                    extension: self.extension.clone(),
 187                    model_info: model_info.clone(),
 188                    provider_id: self.id(),
 189                    provider_name: self.name(),
 190                    provider_info: self.provider_info.clone(),
 191                }) as Arc<dyn LanguageModel>
 192            })
 193    }
 194
 195    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 196        let state = self.state.read(cx);
 197        state
 198            .available_models
 199            .iter()
 200            .map(|model_info| {
 201                Arc::new(ExtensionLanguageModel {
 202                    extension: self.extension.clone(),
 203                    model_info: model_info.clone(),
 204                    provider_id: self.id(),
 205                    provider_name: self.name(),
 206                    provider_info: self.provider_info.clone(),
 207                }) as Arc<dyn LanguageModel>
 208            })
 209            .collect()
 210    }
 211
 212    fn is_authenticated(&self, cx: &App) -> bool {
 213        // First check cached state
 214        if self.state.read(cx).is_authenticated {
 215            return true;
 216        }
 217
 218        // Also check env var dynamically (in case settings changed after provider creation)
 219        if let Some(ref auth_config) = self.auth_config {
 220            if let Some(ref env_vars) = auth_config.env_vars {
 221                let provider_id_string = self.provider_id_string();
 222                let settings = ExtensionSettings::get_global(cx);
 223                let is_legacy_extension =
 224                    LEGACY_LLM_EXTENSION_IDS.contains(&self.extension.manifest.id.as_ref());
 225
 226                for env_var_name in env_vars {
 227                    let key = format!("{}:{}", provider_id_string, env_var_name);
 228                    // For legacy extensions, auto-allow if env var is set
 229                    let env_var_is_set = std::env::var(env_var_name)
 230                        .map(|v| !v.is_empty())
 231                        .unwrap_or(false);
 232                    if settings.allowed_env_var_providers.contains(key.as_str())
 233                        || (is_legacy_extension && env_var_is_set)
 234                    {
 235                        if let Ok(value) = std::env::var(env_var_name) {
 236                            if !value.is_empty() {
 237                                return true;
 238                            }
 239                        }
 240                    }
 241                }
 242            }
 243        }
 244
 245        false
 246    }
 247
 248    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 249        // Check if already authenticated via is_authenticated
 250        if self.is_authenticated(cx) {
 251            return Task::ready(Ok(()));
 252        }
 253
 254        // Not authenticated - return error indicating credentials not found
 255        Task::ready(Err(AuthenticateError::CredentialsNotFound))
 256    }
 257
 258    fn configuration_view(
 259        &self,
 260        target_agent: ConfigurationViewTargetAgent,
 261        window: &mut Window,
 262        cx: &mut App,
 263    ) -> AnyView {
 264        let credential_key = self.credential_key();
 265        let extension = self.extension.clone();
 266        let extension_provider_id = self.provider_info.id.clone();
 267        let full_provider_id = self.provider_id_string();
 268        let state = self.state.clone();
 269        let auth_config = self.auth_config.clone();
 270
 271        let icon_path = self.icon_path.clone();
 272        cx.new(|cx| {
 273            ExtensionProviderConfigurationView::new(
 274                credential_key,
 275                extension,
 276                extension_provider_id,
 277                full_provider_id,
 278                auth_config,
 279                state,
 280                icon_path,
 281                target_agent,
 282                window,
 283                cx,
 284            )
 285        })
 286        .into()
 287    }
 288
 289    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 290        let extension = self.extension.clone();
 291        let provider_id = self.provider_info.id.clone();
 292        let state = self.state.clone();
 293        let credential_key = self.credential_key();
 294
 295        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 296
 297        cx.spawn(async move |cx| {
 298            // Delete from system keychain
 299            credentials_provider
 300                .delete_credentials(&credential_key, cx)
 301                .await
 302                .log_err();
 303
 304            // Call extension's reset_credentials
 305            let result = extension
 306                .call(|extension, store| {
 307                    async move {
 308                        extension
 309                            .call_llm_provider_reset_credentials(store, &provider_id)
 310                            .await
 311                    }
 312                    .boxed()
 313                })
 314                .await;
 315
 316            // Update state
 317            cx.update(|cx| {
 318                state.update(cx, |state, _| {
 319                    state.is_authenticated = false;
 320                });
 321            })?;
 322
 323            match result {
 324                Ok(Ok(Ok(()))) => Ok(()),
 325                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
 326                Ok(Err(e)) => Err(e),
 327                Err(e) => Err(e),
 328            }
 329        })
 330    }
 331}
 332
 333impl LanguageModelProviderState for ExtensionLanguageModelProvider {
 334    type ObservableEntity = ExtensionLlmProviderState;
 335
 336    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 337        Some(self.state.clone())
 338    }
 339
 340    fn subscribe<T: 'static>(
 341        &self,
 342        cx: &mut Context<T>,
 343        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
 344    ) -> Option<Subscription> {
 345        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
 346    }
 347}
 348
 349/// Configuration view for extension-based LLM providers.
 350struct ExtensionProviderConfigurationView {
 351    credential_key: String,
 352    extension: WasmExtension,
 353    extension_provider_id: String,
 354    full_provider_id: String,
 355    auth_config: Option<LanguageModelAuthConfig>,
 356    state: Entity<ExtensionLlmProviderState>,
 357    settings_markdown: Option<Entity<Markdown>>,
 358    api_key_editor: Entity<InputField>,
 359    loading_settings: bool,
 360    loading_credentials: bool,
 361    oauth_in_progress: bool,
 362    oauth_error: Option<String>,
 363    icon_path: Option<SharedString>,
 364    target_agent: ConfigurationViewTargetAgent,
 365    _subscriptions: Vec<Subscription>,
 366}
 367
 368impl ExtensionProviderConfigurationView {
 369    fn new(
 370        credential_key: String,
 371        extension: WasmExtension,
 372        extension_provider_id: String,
 373        full_provider_id: String,
 374        auth_config: Option<LanguageModelAuthConfig>,
 375        state: Entity<ExtensionLlmProviderState>,
 376        icon_path: Option<SharedString>,
 377        target_agent: ConfigurationViewTargetAgent,
 378        window: &mut Window,
 379        cx: &mut Context<Self>,
 380    ) -> Self {
 381        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
 382            cx.notify();
 383        });
 384
 385        let credential_label = auth_config
 386            .as_ref()
 387            .and_then(|c| c.credential_label.clone())
 388            .unwrap_or_else(|| "API Key".to_string());
 389
 390        let api_key_editor = cx.new(|cx| {
 391            InputField::new(window, cx, "Enter API key and hit enter").label(credential_label)
 392        });
 393
 394        let mut this = Self {
 395            credential_key,
 396            extension,
 397            extension_provider_id,
 398            full_provider_id,
 399            auth_config,
 400            state,
 401            settings_markdown: None,
 402            api_key_editor,
 403            loading_settings: true,
 404            loading_credentials: true,
 405            oauth_in_progress: false,
 406            oauth_error: None,
 407            icon_path,
 408            target_agent,
 409            _subscriptions: vec![state_subscription],
 410        };
 411
 412        this.load_settings_text(cx);
 413        this.load_credentials(cx);
 414        this
 415    }
 416
 417    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
 418        let extension = self.extension.clone();
 419        let provider_id = self.extension_provider_id.clone();
 420
 421        cx.spawn(async move |this, cx| {
 422            let result = extension
 423                .call({
 424                    let provider_id = provider_id.clone();
 425                    |ext, store| {
 426                        async move {
 427                            ext.call_llm_provider_settings_markdown(store, &provider_id)
 428                                .await
 429                        }
 430                        .boxed()
 431                    }
 432                })
 433                .await;
 434
 435            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
 436
 437            this.update(cx, |this, cx| {
 438                this.loading_settings = false;
 439                if let Some(text) = settings_text {
 440                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
 441                    this.settings_markdown = Some(markdown);
 442                }
 443                cx.notify();
 444            })
 445            .log_err();
 446        })
 447        .detach();
 448    }
 449
 450    fn load_credentials(&mut self, cx: &mut Context<Self>) {
 451        let credential_key = self.credential_key.clone();
 452        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 453        let state = self.state.clone();
 454
 455        // Check if we should use env var (already set in state during provider construction)
 456        let using_env_var = self.state.read(cx).env_var_name_used.is_some();
 457
 458        cx.spawn(async move |this, cx| {
 459            // If using env var, we're already authenticated
 460            if using_env_var {
 461                this.update(cx, |this, cx| {
 462                    this.loading_credentials = false;
 463                    cx.notify();
 464                })
 465                .log_err();
 466                return;
 467            }
 468
 469            let credentials = credentials_provider
 470                .read_credentials(&credential_key, cx)
 471                .await
 472                .log_err()
 473                .flatten();
 474
 475            let has_credentials = credentials.is_some();
 476
 477            // Update authentication state based on stored credentials
 478            cx.update(|cx| {
 479                state.update(cx, |state, cx| {
 480                    state.is_authenticated = has_credentials;
 481                    cx.notify();
 482                });
 483            })
 484            .log_err();
 485
 486            this.update(cx, |this, cx| {
 487                this.loading_credentials = false;
 488                cx.notify();
 489            })
 490            .log_err();
 491        })
 492        .detach();
 493    }
 494
 495    fn toggle_env_var_permission(&mut self, env_var_name: String, cx: &mut Context<Self>) {
 496        let full_provider_id = self.full_provider_id.clone();
 497        let settings_key: Arc<str> = format!("{}:{}", full_provider_id, env_var_name).into();
 498
 499        let state = self.state.clone();
 500        let currently_allowed = self.state.read(cx).allowed_env_vars.contains(&env_var_name);
 501
 502        // Update settings file
 503        settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
 504            move |settings, _| {
 505                let allowed = settings
 506                    .extension
 507                    .allowed_env_var_providers
 508                    .get_or_insert_with(Vec::new);
 509
 510                if currently_allowed {
 511                    allowed.retain(|id| id.as_ref() != settings_key.as_ref());
 512                } else {
 513                    if !allowed
 514                        .iter()
 515                        .any(|id| id.as_ref() == settings_key.as_ref())
 516                    {
 517                        allowed.push(settings_key.clone());
 518                    }
 519                }
 520            }
 521        });
 522
 523        // Update local state
 524        let new_allowed = !currently_allowed;
 525
 526        state.update(cx, |state, cx| {
 527            if new_allowed {
 528                state.allowed_env_vars.insert(env_var_name.clone());
 529                // Check if this env var is set and update env_var_name_used
 530                if let Ok(value) = std::env::var(&env_var_name) {
 531                    if !value.is_empty() && state.env_var_name_used.is_none() {
 532                        state.env_var_name_used = Some(env_var_name.clone());
 533                        state.is_authenticated = true;
 534                    }
 535                }
 536            } else {
 537                state.allowed_env_vars.remove(&env_var_name);
 538                // If this was the env var being used, clear it and find another
 539                if state.env_var_name_used.as_ref() == Some(&env_var_name) {
 540                    state.env_var_name_used = state.allowed_env_vars.iter().find_map(|var| {
 541                        if let Ok(value) = std::env::var(var) {
 542                            if !value.is_empty() {
 543                                return Some(var.clone());
 544                            }
 545                        }
 546                        None
 547                    });
 548                    if state.env_var_name_used.is_none() {
 549                        // No env var auth available, need to check keychain
 550                        state.is_authenticated = false;
 551                    }
 552                }
 553            }
 554            cx.notify();
 555        });
 556
 557        // If all env vars are being disabled, reload credentials from keychain
 558        if !new_allowed && self.state.read(cx).allowed_env_vars.is_empty() {
 559            self.reload_keychain_credentials(cx);
 560        }
 561
 562        cx.notify();
 563    }
 564
 565    fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
 566        let credential_key = self.credential_key.clone();
 567        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 568        let state = self.state.clone();
 569
 570        cx.spawn(async move |_this, cx| {
 571            let credentials = credentials_provider
 572                .read_credentials(&credential_key, cx)
 573                .await
 574                .log_err()
 575                .flatten();
 576
 577            let has_credentials = credentials.is_some();
 578
 579            cx.update(|cx| {
 580                state.update(cx, |state, cx| {
 581                    state.is_authenticated = has_credentials;
 582                    cx.notify();
 583                });
 584            })
 585            .log_err();
 586        })
 587        .detach();
 588    }
 589
 590    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 591        let api_key = self.api_key_editor.read(cx).text(cx);
 592        if api_key.is_empty() {
 593            return;
 594        }
 595
 596        // Clear the editor
 597        self.api_key_editor
 598            .update(cx, |input, cx| input.clear(window, cx));
 599
 600        let credential_key = self.credential_key.clone();
 601        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 602        let state = self.state.clone();
 603
 604        cx.spawn(async move |_this, cx| {
 605            // Store in system keychain
 606            credentials_provider
 607                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
 608                .await
 609                .log_err();
 610
 611            // Update state to authenticated
 612            cx.update(|cx| {
 613                state.update(cx, |state, cx| {
 614                    state.is_authenticated = true;
 615                    cx.notify();
 616                });
 617            })
 618            .log_err();
 619        })
 620        .detach();
 621    }
 622
 623    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 624        // Clear the editor
 625        self.api_key_editor
 626            .update(cx, |input, cx| input.clear(window, cx));
 627
 628        let credential_key = self.credential_key.clone();
 629        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 630        let state = self.state.clone();
 631
 632        cx.spawn(async move |_this, cx| {
 633            // Delete from system keychain
 634            credentials_provider
 635                .delete_credentials(&credential_key, cx)
 636                .await
 637                .log_err();
 638
 639            // Update state to unauthenticated
 640            cx.update(|cx| {
 641                state.update(cx, |state, cx| {
 642                    state.is_authenticated = false;
 643                    cx.notify();
 644                });
 645            })
 646            .log_err();
 647        })
 648        .detach();
 649    }
 650
 651    fn start_oauth_sign_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 652        if self.oauth_in_progress {
 653            return;
 654        }
 655
 656        self.oauth_in_progress = true;
 657        self.oauth_error = None;
 658        cx.notify();
 659
 660        let extension = self.extension.clone();
 661        let provider_id = self.extension_provider_id.clone();
 662        let state = self.state.clone();
 663        let icon_path = self.icon_path.clone();
 664        let this_handle = cx.weak_entity();
 665        let use_popup_window = self.is_edit_prediction_mode();
 666
 667        // Get current window bounds for positioning popup
 668        let current_window_center = window.bounds().center();
 669
 670        // For workspace modal mode, find the workspace window
 671        let workspace_window = if !use_popup_window {
 672            log::info!("OAuth: Looking for workspace window");
 673            let ws = window.window_handle().downcast::<Workspace>().or_else(|| {
 674                log::info!("OAuth: Current window is not a workspace, searching other windows");
 675                cx.windows()
 676                    .into_iter()
 677                    .find_map(|window_handle| window_handle.downcast::<Workspace>())
 678            });
 679
 680            if ws.is_none() {
 681                log::error!("OAuth: Could not find any workspace window");
 682                self.oauth_in_progress = false;
 683                self.oauth_error =
 684                    Some("Could not access workspace to show sign-in modal".to_string());
 685                cx.notify();
 686                return;
 687            }
 688            ws
 689        } else {
 690            None
 691        };
 692
 693        log::info!(
 694            "OAuth: Using {} mode",
 695            if use_popup_window {
 696                "popup window"
 697            } else {
 698                "workspace modal"
 699            }
 700        );
 701        let state = state.downgrade();
 702        cx.spawn(async move |_this, cx| {
 703            // Step 1: Start device flow - get prompt info from extension
 704            let start_result = extension
 705                .call({
 706                    let provider_id = provider_id.clone();
 707                    |ext, store| {
 708                        async move {
 709                            ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
 710                                .await
 711                        }
 712                        .boxed()
 713                    }
 714                })
 715                .await;
 716
 717            log::info!(
 718                "OAuth: Device flow start result: {:?}",
 719                start_result.is_ok()
 720            );
 721            let prompt_info: LlmDeviceFlowPromptInfo = match start_result {
 722                Ok(Ok(Ok(info))) => {
 723                    log::info!(
 724                        "OAuth: Got device flow prompt info, user_code: {}",
 725                        info.user_code
 726                    );
 727                    info
 728                }
 729                Ok(Ok(Err(e))) => {
 730                    log::error!("OAuth: Device flow start failed: {}", e);
 731                    this_handle
 732                        .update(cx, |this, cx| {
 733                            this.oauth_in_progress = false;
 734                            this.oauth_error = Some(e);
 735                            cx.notify();
 736                        })
 737                        .log_err();
 738                    return;
 739                }
 740                Ok(Err(e)) | Err(e) => {
 741                    log::error!("OAuth: Device flow start error: {}", e);
 742                    this_handle
 743                        .update(cx, |this, cx| {
 744                            this.oauth_in_progress = false;
 745                            this.oauth_error = Some(e.to_string());
 746                            cx.notify();
 747                        })
 748                        .log_err();
 749                    return;
 750                }
 751            };
 752
 753            // Step 2: Create state entity and show the modal/window
 754            let modal_config = OAuthDeviceFlowModalConfig {
 755                user_code: prompt_info.user_code,
 756                verification_url: prompt_info.verification_url,
 757                headline: prompt_info.headline,
 758                description: prompt_info.description,
 759                connect_button_label: prompt_info.connect_button_label,
 760                success_headline: prompt_info.success_headline,
 761                success_message: prompt_info.success_message,
 762                icon_path,
 763            };
 764
 765            let flow_state: Option<Entity<OAuthDeviceFlowState>> = if use_popup_window {
 766                // Open a popup window like Copilot does
 767                log::info!("OAuth: Opening popup window");
 768                cx.update(|cx| {
 769                    let height = px(450.);
 770                    let width = px(350.);
 771                    let window_bounds = WindowBounds::Windowed(gpui::bounds(
 772                        current_window_center - point(height / 2.0, width / 2.0),
 773                        gpui::size(height, width),
 774                    ));
 775
 776                    let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config.clone()));
 777                    let flow_state_for_window = flow_state.clone();
 778
 779                    cx.open_window(
 780                        WindowOptions {
 781                            kind: gpui::WindowKind::PopUp,
 782                            window_bounds: Some(window_bounds),
 783                            is_resizable: false,
 784                            is_movable: true,
 785                            titlebar: Some(gpui::TitlebarOptions {
 786                                appears_transparent: true,
 787                                ..Default::default()
 788                            }),
 789                            ..Default::default()
 790                        },
 791                        |window, cx| {
 792                            cx.new(|cx| {
 793                                OAuthCodeVerificationWindow::new(
 794                                    modal_config,
 795                                    flow_state_for_window,
 796                                    window,
 797                                    cx,
 798                                )
 799                            })
 800                        },
 801                    )
 802                    .log_err();
 803
 804                    Some(flow_state)
 805                })
 806                .ok()
 807                .flatten()
 808            } else {
 809                // Use workspace modal
 810                log::info!("OAuth: Attempting to show modal in workspace window");
 811                workspace_window.as_ref().and_then(|ws| {
 812                    ws.update(cx, |workspace, window, cx| {
 813                        log::info!("OAuth: Inside workspace.update, creating modal");
 814                        window.activate_window();
 815                        let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config));
 816                        let flow_state_clone = flow_state.clone();
 817                        workspace.toggle_modal(window, cx, |_window, cx| {
 818                            log::info!("OAuth: Inside toggle_modal callback");
 819                            OAuthDeviceFlowModal::new(flow_state_clone, cx)
 820                        });
 821                        flow_state
 822                    })
 823                    .ok()
 824                })
 825            };
 826
 827            log::info!("OAuth: flow_state created: {:?}", flow_state.is_some());
 828            let Some(flow_state) = flow_state else {
 829                log::error!("OAuth: Failed to show sign-in modal/window");
 830                this_handle
 831                    .update(cx, |this, cx| {
 832                        this.oauth_in_progress = false;
 833                        this.oauth_error = Some("Failed to show sign-in modal".to_string());
 834                        cx.notify();
 835                    })
 836                    .log_err();
 837                return;
 838            };
 839            log::info!("OAuth: Modal/window shown successfully, starting poll");
 840
 841            // Step 3: Poll for authentication completion
 842            let poll_result = extension
 843                .call({
 844                    let provider_id = provider_id.clone();
 845                    |ext, store| {
 846                        async move {
 847                            ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
 848                                .await
 849                        }
 850                        .boxed()
 851                    }
 852                })
 853                .await;
 854
 855            match poll_result {
 856                Ok(Ok(Ok(()))) => {
 857                    // After successful auth, refresh the models list
 858                    let models_result = extension
 859                        .call({
 860                            let provider_id = provider_id.clone();
 861                            |ext, store| {
 862                                async move {
 863                                    ext.call_llm_provider_models(store, &provider_id).await
 864                                }
 865                                .boxed()
 866                            }
 867                        })
 868                        .await;
 869
 870                    let new_models: Vec<LlmModelInfo> = match models_result {
 871                        Ok(Ok(Ok(models))) => models,
 872                        _ => Vec::new(),
 873                    };
 874
 875                    state
 876                        .update(cx, |state, cx| {
 877                            state.is_authenticated = true;
 878                            state.available_models = new_models;
 879                            cx.notify();
 880                        })
 881                        .log_err();
 882
 883                    // Update flow state to show success
 884                    flow_state
 885                        .update(cx, |state, cx| {
 886                            state.set_status(OAuthDeviceFlowStatus::Authorized, cx);
 887                        })
 888                        .log_err();
 889                }
 890                Ok(Ok(Err(e))) => {
 891                    log::error!("Device flow poll failed: {}", e);
 892                    flow_state
 893                        .update(cx, |state, cx| {
 894                            state.set_status(OAuthDeviceFlowStatus::Failed(e.clone()), cx);
 895                        })
 896                        .log_err();
 897                    this_handle
 898                        .update(cx, |this, cx| {
 899                            this.oauth_error = Some(e);
 900                            cx.notify();
 901                        })
 902                        .log_err();
 903                }
 904                Ok(Err(e)) | Err(e) => {
 905                    log::error!("Device flow poll error: {}", e);
 906                    let error_string = e.to_string();
 907                    flow_state
 908                        .update(cx, |state, cx| {
 909                            state.set_status(
 910                                OAuthDeviceFlowStatus::Failed(error_string.clone()),
 911                                cx,
 912                            );
 913                        })
 914                        .log_err();
 915                    this_handle
 916                        .update(cx, |this, cx| {
 917                            this.oauth_error = Some(error_string);
 918                            cx.notify();
 919                        })
 920                        .log_err();
 921                }
 922            };
 923
 924            this_handle
 925                .update(cx, |this, cx| {
 926                    this.oauth_in_progress = false;
 927                    cx.notify();
 928                })
 929                .log_err();
 930        })
 931        .detach();
 932    }
 933
 934    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
 935        self.state.read(cx).is_authenticated
 936    }
 937
 938    fn has_oauth_config(&self) -> bool {
 939        self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
 940    }
 941
 942    fn oauth_config(&self) -> Option<&OAuthConfig> {
 943        self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
 944    }
 945
 946    fn has_api_key_config(&self) -> bool {
 947        // API key is available if there's a credential_label or no oauth-only config
 948        self.auth_config
 949            .as_ref()
 950            .map(|c| c.credential_label.is_some() || c.oauth.is_none())
 951            .unwrap_or(true)
 952    }
 953
 954    fn is_edit_prediction_mode(&self) -> bool {
 955        self.target_agent == ConfigurationViewTargetAgent::EditPrediction
 956    }
 957
 958    fn render_for_edit_prediction(
 959        &mut self,
 960        _window: &mut Window,
 961        cx: &mut Context<Self>,
 962    ) -> impl IntoElement {
 963        let is_loading = self.loading_settings || self.loading_credentials;
 964        let is_authenticated = self.is_authenticated(cx);
 965        let has_oauth = self.has_oauth_config();
 966
 967        // Helper to create the horizontal container layout matching Copilot
 968        let container = |description: SharedString, action: AnyElement| {
 969            h_flex()
 970                .pt_2p5()
 971                .w_full()
 972                .justify_between()
 973                .child(
 974                    v_flex()
 975                        .w_full()
 976                        .max_w_1_2()
 977                        .child(Label::new("Authenticate To Use"))
 978                        .child(
 979                            Label::new(description)
 980                                .color(Color::Muted)
 981                                .size(LabelSize::Small),
 982                        ),
 983                )
 984                .child(action)
 985        };
 986
 987        // Get the description from OAuth config or use a default
 988        let oauth_config = self.oauth_config();
 989        let description: SharedString = oauth_config
 990            .and_then(|c| c.sign_in_description.clone())
 991            .unwrap_or_else(|| "Sign in to authenticate with this provider.".to_string())
 992            .into();
 993
 994        if is_loading {
 995            return container(
 996                description,
 997                Button::new("loading", "Loading...")
 998                    .style(ButtonStyle::Outlined)
 999                    .disabled(true)
1000                    .into_any_element(),
1001            )
1002            .into_any_element();
1003        }
1004
1005        // If authenticated, show the configured card
1006        if is_authenticated {
1007            let (status_label, button_label) = if has_oauth {
1008                ("Authorized", "Sign Out")
1009            } else {
1010                ("API key configured", "Reset Key")
1011            };
1012
1013            return ConfiguredApiCard::new(status_label)
1014                .button_label(button_label)
1015                .on_click(cx.listener(|this, _, window, cx| {
1016                    this.reset_api_key(window, cx);
1017                }))
1018                .into_any_element();
1019        }
1020
1021        // Not authenticated - show sign in button
1022        if has_oauth {
1023            let button_label = oauth_config
1024                .and_then(|c| c.sign_in_button_label.clone())
1025                .unwrap_or_else(|| "Sign In".to_string());
1026            let button_icon = oauth_config
1027                .and_then(|c| c.sign_in_button_icon.as_ref())
1028                .and_then(|icon_name| match icon_name.as_str() {
1029                    "github" => Some(ui::IconName::Github),
1030                    _ => None,
1031                });
1032
1033            let oauth_in_progress = self.oauth_in_progress;
1034
1035            let mut button = Button::new("oauth-sign-in", button_label)
1036                .size(ButtonSize::Medium)
1037                .style(ButtonStyle::Outlined)
1038                .disabled(oauth_in_progress)
1039                .on_click(cx.listener(|this, _, window, cx| {
1040                    this.start_oauth_sign_in(window, cx);
1041                }));
1042
1043            if let Some(icon) = button_icon {
1044                button = button
1045                    .icon(icon)
1046                    .icon_position(ui::IconPosition::Start)
1047                    .icon_size(ui::IconSize::Small)
1048                    .icon_color(Color::Muted);
1049            }
1050
1051            return container(description, button.into_any_element()).into_any_element();
1052        }
1053
1054        // Fallback for API key only providers - show a simple message
1055        container(
1056            description,
1057            Button::new("configure", "Configure")
1058                .size(ButtonSize::Medium)
1059                .style(ButtonStyle::Outlined)
1060                .disabled(true)
1061                .into_any_element(),
1062        )
1063        .into_any_element()
1064    }
1065}
1066
1067impl Render for ExtensionProviderConfigurationView {
1068    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1069        if self.is_edit_prediction_mode() {
1070            return self
1071                .render_for_edit_prediction(window, cx)
1072                .into_any_element();
1073        }
1074
1075        let is_loading = self.loading_settings || self.loading_credentials;
1076        let is_authenticated = self.is_authenticated(cx);
1077        let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone();
1078        let env_var_name_used = self.state.read(cx).env_var_name_used.clone();
1079        let has_oauth = self.has_oauth_config();
1080        let has_api_key = self.has_api_key_config();
1081
1082        if is_loading {
1083            return h_flex()
1084                .gap_2()
1085                .child(
1086                    h_flex()
1087                        .w_2()
1088                        .child(SpinnerLabel::sand().size(LabelSize::Small)),
1089                )
1090                .child(LoadingLabel::new("Loading").size(LabelSize::Small))
1091                .into_any_element();
1092        }
1093
1094        let mut content = v_flex().size_full().gap_2();
1095
1096        if let Some(markdown) = &self.settings_markdown {
1097            content = content.text_sm().child(MarkdownElement::new(
1098                markdown.clone(),
1099                markdown_styles(window, cx),
1100            ));
1101        }
1102
1103        if let Some(auth_config) = &self.auth_config {
1104            if let Some(env_vars) = &auth_config.env_vars {
1105                for env_var_name in env_vars {
1106                    let is_allowed = allowed_env_vars.contains(env_var_name);
1107                    let checkbox_label =
1108                        format!("Read API key from {} environment variable.", env_var_name);
1109                    let env_var_for_click = env_var_name.clone();
1110
1111                    content = content.child(
1112                        Checkbox::new(
1113                            SharedString::from(format!("env-var-{}", env_var_name)),
1114                            if is_allowed {
1115                                ToggleState::Selected
1116                            } else {
1117                                ToggleState::Unselected
1118                            },
1119                        )
1120                        .label(checkbox_label)
1121                        .on_click(cx.listener(
1122                            move |this, _, _window, cx| {
1123                                this.toggle_env_var_permission(env_var_for_click.clone(), cx);
1124                            },
1125                        )),
1126                    );
1127                }
1128
1129                if let Some(used_var) = &env_var_name_used {
1130                    content = content.child(
1131                        ConfiguredApiCard::new(format!(
1132                            "API key set in {} environment variable",
1133                            used_var
1134                        ))
1135                        .tooltip_label(format!(
1136                            "To reset this API key, unset the {} environment variable.",
1137                            used_var
1138                        ))
1139                        .disabled(true),
1140                    );
1141
1142                    return content.into_any_element();
1143                }
1144            }
1145        }
1146
1147        if is_authenticated && env_var_name_used.is_none() {
1148            let (status_label, button_label) = if has_oauth && !has_api_key {
1149                ("Signed in", "Sign Out")
1150            } else {
1151                ("API key configured", "Reset Key")
1152            };
1153
1154            content = content.child(
1155                ConfiguredApiCard::new(status_label)
1156                    .button_label(button_label)
1157                    .on_click(cx.listener(|this, _, window, cx| {
1158                        this.reset_api_key(window, cx);
1159                    })),
1160            );
1161
1162            return content.into_any_element();
1163        }
1164
1165        // Not authenticated - show available auth options
1166        if env_var_name_used.is_none() {
1167            // Render OAuth sign-in button if configured
1168            if has_oauth {
1169                let oauth_config = self.oauth_config();
1170                let button_label = oauth_config
1171                    .and_then(|c| c.sign_in_button_label.clone())
1172                    .unwrap_or_else(|| "Sign In".to_string());
1173                let button_icon = oauth_config
1174                    .and_then(|c| c.sign_in_button_icon.as_ref())
1175                    .and_then(|icon_name| match icon_name.as_str() {
1176                        "github" => Some(ui::IconName::Github),
1177                        _ => None,
1178                    });
1179
1180                let oauth_in_progress = self.oauth_in_progress;
1181
1182                let oauth_error = self.oauth_error.clone();
1183
1184                let mut button = Button::new("oauth-sign-in", button_label)
1185                    .full_width()
1186                    .style(ButtonStyle::Outlined)
1187                    .disabled(oauth_in_progress)
1188                    .on_click(cx.listener(|this, _, window, cx| {
1189                        this.start_oauth_sign_in(window, cx);
1190                    }));
1191                if let Some(icon) = button_icon {
1192                    button = button
1193                        .icon(icon)
1194                        .icon_position(IconPosition::Start)
1195                        .icon_size(IconSize::Small)
1196                        .icon_color(Color::Muted);
1197                }
1198
1199                content = content.child(
1200                    v_flex()
1201                        .gap_2()
1202                        .child(button)
1203                        .when(oauth_in_progress, |this| {
1204                            this.child(
1205                                Label::new("Sign-in in progress...")
1206                                    .size(LabelSize::Small)
1207                                    .color(Color::Muted),
1208                            )
1209                        })
1210                        .when_some(oauth_error, |this, error| {
1211                            this.child(
1212                                v_flex()
1213                                    .gap_1()
1214                                    .child(
1215                                        h_flex()
1216                                            .gap_2()
1217                                            .child(
1218                                                Icon::new(IconName::Warning)
1219                                                    .color(Color::Error)
1220                                                    .size(IconSize::Small),
1221                                            )
1222                                            .child(
1223                                                Label::new("Authentication failed")
1224                                                    .color(Color::Error)
1225                                                    .size(LabelSize::Small),
1226                                            ),
1227                                    )
1228                                    .child(
1229                                        div().pl_6().child(
1230                                            Label::new(error)
1231                                                .color(Color::Error)
1232                                                .size(LabelSize::Small),
1233                                        ),
1234                                    ),
1235                            )
1236                        }),
1237                );
1238            }
1239
1240            // Render API key input if configured (and we have both options, show a separator)
1241            if has_api_key {
1242                if has_oauth {
1243                    content = content.child(
1244                        h_flex()
1245                            .gap_2()
1246                            .items_center()
1247                            .child(div().h_px().flex_1().bg(cx.theme().colors().border_variant))
1248                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
1249                            .child(div().h_px().flex_1().bg(cx.theme().colors().border_variant)),
1250                    );
1251                }
1252
1253                content = content.child(
1254                    div()
1255                        .on_action(cx.listener(Self::save_api_key))
1256                        .child(self.api_key_editor.clone()),
1257                );
1258            }
1259        }
1260
1261        if self.extension_provider_id == "openai" {
1262            content = content.child(
1263                h_flex()
1264                    .gap_1()
1265                    .child(
1266                        Icon::new(IconName::Info)
1267                            .size(IconSize::XSmall)
1268                            .color(Color::Muted),
1269                    )
1270                    .child(
1271                        Label::new("Zed also supports OpenAI-compatible models.")
1272                            .size(LabelSize::Small)
1273                            .color(Color::Muted),
1274                    )
1275                    .child(
1276                        ButtonLink::new(
1277                            "Learn More",
1278                            "https://zed.dev/docs/configuring-llm-providers#openai-compatible-providers",
1279                        )
1280                        .label_size(LabelSize::Small),
1281                    ),
1282            );
1283        }
1284
1285        content.into_any_element()
1286    }
1287}
1288
1289impl Focusable for ExtensionProviderConfigurationView {
1290    fn focus_handle(&self, cx: &App) -> FocusHandle {
1291        self.api_key_editor.read(cx).focus_handle(cx)
1292    }
1293}
1294
/// A popup window for OAuth device flow, similar to CopilotCodeVerification.
/// This is used when in edit prediction mode to avoid moving the settings panel behind.
pub struct OAuthCodeVerificationWindow {
    /// Text, user code, verification URL, and optional icon path displayed by the window.
    config: OAuthDeviceFlowModalConfig,
    /// Local copy of the shared device-flow status; kept in sync by `_subscription`
    /// and used to pick which body (prompt / authorized / failed) to render.
    status: OAuthDeviceFlowStatus,
    /// Set once the user clicks the connect button; switches that button's label to
    /// a waiting message.
    connect_clicked: bool,
    /// Focus handle tracked by the window's root element.
    focus_handle: FocusHandle,
    /// Observation of the shared `OAuthDeviceFlowState`. Held only to keep it alive
    /// for the lifetime of this view (always `Some` after construction).
    _subscription: Option<Subscription>,
}
1304
1305impl Focusable for OAuthCodeVerificationWindow {
1306    fn focus_handle(&self, _: &App) -> FocusHandle {
1307        self.focus_handle.clone()
1308    }
1309}
1310
1311impl EventEmitter<DismissEvent> for OAuthCodeVerificationWindow {}
1312
1313impl OAuthCodeVerificationWindow {
1314    pub fn new(
1315        config: OAuthDeviceFlowModalConfig,
1316        state: Entity<OAuthDeviceFlowState>,
1317        window: &mut Window,
1318        cx: &mut Context<Self>,
1319    ) -> Self {
1320        window.on_window_should_close(cx, |window, cx| {
1321            if let Some(this) = window.root::<OAuthCodeVerificationWindow>().flatten() {
1322                this.update(cx, |_, cx| {
1323                    cx.emit(DismissEvent);
1324                });
1325            }
1326            true
1327        });
1328        cx.subscribe_in(
1329            &cx.entity(),
1330            window,
1331            |_, _, _: &DismissEvent, window, _cx| {
1332                window.remove_window();
1333            },
1334        )
1335        .detach();
1336
1337        let subscription = cx.observe(&state, |this, state, cx| {
1338            let status = state.read(cx).status.clone();
1339            this.status = status;
1340            cx.notify();
1341        });
1342
1343        Self {
1344            config,
1345            status: state.read(cx).status.clone(),
1346            connect_clicked: false,
1347            focus_handle: cx.focus_handle(),
1348            _subscription: Some(subscription),
1349        }
1350    }
1351
1352    fn render_icon(&self, cx: &mut Context<Self>) -> impl IntoElement {
1353        let icon_color = Color::Custom(cx.theme().colors().icon);
1354        let icon_size = rems(2.5);
1355        let plus_size = rems(0.875);
1356        let plus_color = cx.theme().colors().icon.opacity(0.5);
1357
1358        if let Some(icon_path) = &self.config.icon_path {
1359            h_flex()
1360                .gap_2()
1361                .items_center()
1362                .child(
1363                    Icon::from_external_svg(icon_path.clone())
1364                        .size(IconSize::Custom(icon_size))
1365                        .color(icon_color),
1366                )
1367                .child(
1368                    gpui::svg()
1369                        .size(plus_size)
1370                        .path("icons/plus.svg")
1371                        .text_color(plus_color),
1372                )
1373                .child(Vector::new(VectorName::ZedLogo, icon_size, icon_size).color(icon_color))
1374                .into_any_element()
1375        } else {
1376            Vector::new(VectorName::ZedLogo, icon_size, icon_size)
1377                .color(icon_color)
1378                .into_any_element()
1379        }
1380    }
1381
1382    fn render_device_code(&self, cx: &mut Context<Self>) -> impl IntoElement {
1383        let user_code = self.config.user_code.clone();
1384        let copied = cx
1385            .read_from_clipboard()
1386            .map(|item| item.text().as_ref() == Some(&user_code))
1387            .unwrap_or(false);
1388        let user_code_for_click = user_code.clone();
1389
1390        ButtonLike::new("copy-button")
1391            .full_width()
1392            .style(ButtonStyle::Tinted(ui::TintColor::Accent))
1393            .size(ButtonSize::Medium)
1394            .child(
1395                h_flex()
1396                    .w_full()
1397                    .p_1()
1398                    .justify_between()
1399                    .child(Label::new(user_code))
1400                    .child(Label::new(if copied { "Copied!" } else { "Copy" })),
1401            )
1402            .on_click(move |_, window, cx| {
1403                cx.write_to_clipboard(ClipboardItem::new_string(user_code_for_click.clone()));
1404                window.refresh();
1405            })
1406    }
1407
1408    fn render_prompting_modal(&self, cx: &mut Context<Self>) -> impl IntoElement {
1409        let connect_button_label: String = if self.connect_clicked {
1410            "Waiting for connection…".to_string()
1411        } else {
1412            self.config.connect_button_label.clone()
1413        };
1414        let verification_url = self.config.verification_url.clone();
1415
1416        v_flex()
1417            .flex_1()
1418            .gap_2p5()
1419            .items_center()
1420            .text_center()
1421            .child(Headline::new(self.config.headline.clone()).size(HeadlineSize::Large))
1422            .child(Label::new(self.config.description.clone()).color(Color::Muted))
1423            .child(self.render_device_code(cx))
1424            .child(
1425                Label::new("Paste this code after clicking the button below.").color(Color::Muted),
1426            )
1427            .child(
1428                v_flex()
1429                    .w_full()
1430                    .gap_1()
1431                    .child(
1432                        Button::new("connect-button", connect_button_label)
1433                            .full_width()
1434                            .style(ButtonStyle::Outlined)
1435                            .size(ButtonSize::Medium)
1436                            .on_click(cx.listener(move |this, _, _window, cx| {
1437                                cx.open_url(&verification_url);
1438                                this.connect_clicked = true;
1439                            })),
1440                    )
1441                    .child(
1442                        Button::new("cancel-button", "Cancel")
1443                            .full_width()
1444                            .size(ButtonSize::Medium)
1445                            .on_click(cx.listener(|_, _, _, cx| {
1446                                cx.emit(DismissEvent);
1447                            })),
1448                    ),
1449            )
1450    }
1451
1452    fn render_authorized_modal(&self, cx: &mut Context<Self>) -> impl IntoElement {
1453        v_flex()
1454            .gap_2()
1455            .text_center()
1456            .justify_center()
1457            .child(Headline::new(self.config.success_headline.clone()).size(HeadlineSize::Large))
1458            .child(Label::new(self.config.success_message.clone()).color(Color::Muted))
1459            .child(
1460                Button::new("done-button", "Done")
1461                    .full_width()
1462                    .style(ButtonStyle::Outlined)
1463                    .size(ButtonSize::Medium)
1464                    .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
1465            )
1466    }
1467
1468    fn render_failed_modal(&self, error: &str, cx: &mut Context<Self>) -> impl IntoElement {
1469        v_flex()
1470            .gap_2()
1471            .text_center()
1472            .justify_center()
1473            .child(Headline::new("Authorization Failed").size(HeadlineSize::Large))
1474            .child(Label::new(error.to_string()).color(Color::Error))
1475            .child(
1476                Button::new("close-button", "Close")
1477                    .full_width()
1478                    .size(ButtonSize::Medium)
1479                    .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
1480            )
1481    }
1482}
1483
1484impl Render for OAuthCodeVerificationWindow {
1485    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1486        let prompt = match &self.status {
1487            OAuthDeviceFlowStatus::Prompting | OAuthDeviceFlowStatus::WaitingForAuthorization => {
1488                self.render_prompting_modal(cx).into_any_element()
1489            }
1490            OAuthDeviceFlowStatus::Authorized => {
1491                self.render_authorized_modal(cx).into_any_element()
1492            }
1493            OAuthDeviceFlowStatus::Failed(error) => {
1494                self.render_failed_modal(error, cx).into_any_element()
1495            }
1496        };
1497
1498        v_flex()
1499            .id("oauth_code_verification")
1500            .track_focus(&self.focus_handle(cx))
1501            .size_full()
1502            .px_4()
1503            .py_8()
1504            .gap_2()
1505            .items_center()
1506            .justify_center()
1507            .elevation_3(cx)
1508            .on_action(cx.listener(|_, _: &menu::Cancel, _, cx| {
1509                cx.emit(DismissEvent);
1510            }))
1511            .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _| {
1512                window.focus(&this.focus_handle);
1513            }))
1514            .child(self.render_icon(cx))
1515            .child(prompt)
1516    }
1517}
1518
1519fn markdown_styles(window: &Window, cx: &App) -> MarkdownStyle {
1520    let settings = ThemeSettings::get_global(cx);
1521    let colors = cx.theme().colors();
1522
1523    let mut text_style = window.text_style();
1524    text_style.refine(&TextStyleRefinement {
1525        font_family: Some(settings.ui_font.family.clone()),
1526        font_fallbacks: settings.ui_font.fallbacks.clone(),
1527        font_features: Some(settings.ui_font.features.clone()),
1528        font_size: Some(settings.ui_font_size(cx).into()),
1529        line_height: Some(relative(1.5)),
1530        color: Some(colors.text_muted),
1531        ..Default::default()
1532    });
1533
1534    MarkdownStyle {
1535        base_text_style: text_style.clone(),
1536        syntax: cx.theme().syntax().clone(),
1537        selection_background_color: colors.element_selection_background,
1538        heading_level_styles: Some(HeadingLevelStyles {
1539            h1: Some(TextStyleRefinement {
1540                font_size: Some(rems(1.15).into()),
1541                ..Default::default()
1542            }),
1543            h2: Some(TextStyleRefinement {
1544                font_size: Some(rems(1.1).into()),
1545                ..Default::default()
1546            }),
1547            h3: Some(TextStyleRefinement {
1548                font_size: Some(rems(1.05).into()),
1549                ..Default::default()
1550            }),
1551            h4: Some(TextStyleRefinement {
1552                font_size: Some(rems(1.).into()),
1553                ..Default::default()
1554            }),
1555            h5: Some(TextStyleRefinement {
1556                font_size: Some(rems(0.95).into()),
1557                ..Default::default()
1558            }),
1559            h6: Some(TextStyleRefinement {
1560                font_size: Some(rems(0.875).into()),
1561                ..Default::default()
1562            }),
1563        }),
1564        inline_code: TextStyleRefinement {
1565            font_family: Some(settings.buffer_font.family.clone()),
1566            font_fallbacks: settings.buffer_font.fallbacks.clone(),
1567            font_features: Some(settings.buffer_font.features.clone()),
1568            font_size: Some(settings.buffer_font_size(cx).into()),
1569            background_color: Some(colors.editor_foreground.opacity(0.08)),
1570            ..Default::default()
1571        },
1572        link: TextStyleRefinement {
1573            background_color: Some(colors.editor_foreground.opacity(0.025)),
1574            color: Some(colors.text_accent),
1575            underline: Some(UnderlineStyle {
1576                color: Some(colors.text_accent.opacity(0.5)),
1577                thickness: px(1.),
1578                ..Default::default()
1579            }),
1580            ..Default::default()
1581        },
1582        ..Default::default()
1583    }
1584}
1585
1586/// An extension-based language model.
1587pub struct ExtensionLanguageModel {
1588    extension: WasmExtension,
1589    model_info: LlmModelInfo,
1590    provider_id: LanguageModelProviderId,
1591    provider_name: LanguageModelProviderName,
1592    provider_info: LlmProviderInfo,
1593}
1594
1595impl LanguageModel for ExtensionLanguageModel {
1596    fn id(&self) -> LanguageModelId {
1597        LanguageModelId::from(self.model_info.id.clone())
1598    }
1599
1600    fn name(&self) -> LanguageModelName {
1601        LanguageModelName::from(self.model_info.name.clone())
1602    }
1603
1604    fn provider_id(&self) -> LanguageModelProviderId {
1605        self.provider_id.clone()
1606    }
1607
1608    fn provider_name(&self) -> LanguageModelProviderName {
1609        self.provider_name.clone()
1610    }
1611
1612    fn telemetry_id(&self) -> String {
1613        format!("extension-{}", self.model_info.id)
1614    }
1615
1616    fn supports_images(&self) -> bool {
1617        self.model_info.capabilities.supports_images
1618    }
1619
1620    fn supports_tools(&self) -> bool {
1621        self.model_info.capabilities.supports_tools
1622    }
1623
1624    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1625        match choice {
1626            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1627            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1628            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1629        }
1630    }
1631
1632    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1633        match self.model_info.capabilities.tool_input_format {
1634            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1635            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1636        }
1637    }
1638
1639    fn max_token_count(&self) -> u64 {
1640        self.model_info.max_token_count
1641    }
1642
1643    fn max_output_tokens(&self) -> Option<u64> {
1644        self.model_info.max_output_tokens
1645    }
1646
1647    fn count_tokens(
1648        &self,
1649        request: LanguageModelRequest,
1650        cx: &App,
1651    ) -> BoxFuture<'static, Result<u64>> {
1652        let extension = self.extension.clone();
1653        let provider_id = self.provider_info.id.clone();
1654        let model_id = self.model_info.id.clone();
1655
1656        let wit_request = convert_request_to_wit(request);
1657
1658        cx.background_spawn(async move {
1659            extension
1660                .call({
1661                    let provider_id = provider_id.clone();
1662                    let model_id = model_id.clone();
1663                    let wit_request = wit_request.clone();
1664                    |ext, store| {
1665                        async move {
1666                            let count = ext
1667                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1668                                .await?
1669                                .map_err(|e| anyhow!("{}", e))?;
1670                            Ok(count)
1671                        }
1672                        .boxed()
1673                    }
1674                })
1675                .await?
1676        })
1677        .boxed()
1678    }
1679
1680    fn stream_completion(
1681        &self,
1682        request: LanguageModelRequest,
1683        _cx: &AsyncApp,
1684    ) -> BoxFuture<
1685        'static,
1686        Result<
1687            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1688            LanguageModelCompletionError,
1689        >,
1690    > {
1691        let extension = self.extension.clone();
1692        let provider_id = self.provider_info.id.clone();
1693        let model_id = self.model_info.id.clone();
1694
1695        let wit_request = convert_request_to_wit(request);
1696
1697        async move {
1698            // Start the stream
1699            let stream_id_result = extension
1700                .call({
1701                    let provider_id = provider_id.clone();
1702                    let model_id = model_id.clone();
1703                    let wit_request = wit_request.clone();
1704                    |ext, store| {
1705                        async move {
1706                            let id = ext
1707                                .call_llm_stream_completion_start(
1708                                    store,
1709                                    &provider_id,
1710                                    &model_id,
1711                                    &wit_request,
1712                                )
1713                                .await?
1714                                .map_err(|e| anyhow!("{}", e))?;
1715                            Ok(id)
1716                        }
1717                        .boxed()
1718                    }
1719                })
1720                .await;
1721
1722            let stream_id = stream_id_result
1723                .map_err(LanguageModelCompletionError::Other)?
1724                .map_err(LanguageModelCompletionError::Other)?;
1725
1726            // Create a stream that polls for events
1727            let stream = futures::stream::unfold(
1728                (extension.clone(), stream_id, false),
1729                move |(extension, stream_id, done)| async move {
1730                    if done {
1731                        return None;
1732                    }
1733
1734                    let result = extension
1735                        .call({
1736                            let stream_id = stream_id.clone();
1737                            |ext, store| {
1738                                async move {
1739                                    let event = ext
1740                                        .call_llm_stream_completion_next(store, &stream_id)
1741                                        .await?
1742                                        .map_err(|e| anyhow!("{}", e))?;
1743                                    Ok(event)
1744                                }
1745                                .boxed()
1746                            }
1747                        })
1748                        .await
1749                        .and_then(|inner| inner);
1750
1751                    match result {
1752                        Ok(Some(event)) => {
1753                            let converted = convert_completion_event(event);
1754                            let is_done =
1755                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1756                            Some((converted, (extension, stream_id, is_done)))
1757                        }
1758                        Ok(None) => {
1759                            // Stream complete, close it
1760                            let _ = extension
1761                                .call({
1762                                    let stream_id = stream_id.clone();
1763                                    |ext, store| {
1764                                        async move {
1765                                            ext.call_llm_stream_completion_close(store, &stream_id)
1766                                                .await?;
1767                                            Ok::<(), anyhow::Error>(())
1768                                        }
1769                                        .boxed()
1770                                    }
1771                                })
1772                                .await;
1773                            None
1774                        }
1775                        Err(e) => Some((
1776                            Err(LanguageModelCompletionError::Other(e)),
1777                            (extension, stream_id, true),
1778                        )),
1779                    }
1780                },
1781            );
1782
1783            Ok(stream.boxed())
1784        }
1785        .boxed()
1786    }
1787
1788    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1789        // Extensions can implement this via llm_cache_configuration
1790        None
1791    }
1792}
1793
1794fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1795    use language_model::{MessageContent, Role};
1796
1797    let messages: Vec<LlmRequestMessage> = request
1798        .messages
1799        .into_iter()
1800        .map(|msg| {
1801            let role = match msg.role {
1802                Role::User => LlmMessageRole::User,
1803                Role::Assistant => LlmMessageRole::Assistant,
1804                Role::System => LlmMessageRole::System,
1805            };
1806
1807            let content: Vec<LlmMessageContent> = msg
1808                .content
1809                .into_iter()
1810                .map(|c| match c {
1811                    MessageContent::Text(text) => LlmMessageContent::Text(text),
1812                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1813                        source: image.source.to_string(),
1814                        width: Some(image.size.width.0 as u32),
1815                        height: Some(image.size.height.0 as u32),
1816                    }),
1817                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1818                        id: tool_use.id.to_string(),
1819                        name: tool_use.name.to_string(),
1820                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1821                        is_input_complete: tool_use.is_input_complete,
1822                        thought_signature: tool_use.thought_signature,
1823                    }),
1824                    MessageContent::ToolResult(tool_result) => {
1825                        let content = match tool_result.content {
1826                            language_model::LanguageModelToolResultContent::Text(text) => {
1827                                LlmToolResultContent::Text(text.to_string())
1828                            }
1829                            language_model::LanguageModelToolResultContent::Image(image) => {
1830                                LlmToolResultContent::Image(LlmImageData {
1831                                    source: image.source.to_string(),
1832                                    width: Some(image.size.width.0 as u32),
1833                                    height: Some(image.size.height.0 as u32),
1834                                })
1835                            }
1836                        };
1837                        LlmMessageContent::ToolResult(LlmToolResult {
1838                            tool_use_id: tool_result.tool_use_id.to_string(),
1839                            tool_name: tool_result.tool_name.to_string(),
1840                            is_error: tool_result.is_error,
1841                            content,
1842                        })
1843                    }
1844                    MessageContent::Thinking { text, signature } => {
1845                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1846                    }
1847                    MessageContent::RedactedThinking(data) => {
1848                        LlmMessageContent::RedactedThinking(data)
1849                    }
1850                })
1851                .collect();
1852
1853            LlmRequestMessage {
1854                role,
1855                content,
1856                cache: msg.cache,
1857            }
1858        })
1859        .collect();
1860
1861    let tools: Vec<LlmToolDefinition> = request
1862        .tools
1863        .into_iter()
1864        .map(|tool| LlmToolDefinition {
1865            name: tool.name,
1866            description: tool.description,
1867            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1868        })
1869        .collect();
1870
1871    let tool_choice = request.tool_choice.map(|tc| match tc {
1872        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1873        LanguageModelToolChoice::Any => LlmToolChoice::Any,
1874        LanguageModelToolChoice::None => LlmToolChoice::None,
1875    });
1876
1877    LlmCompletionRequest {
1878        messages,
1879        tools,
1880        tool_choice,
1881        stop_sequences: request.stop,
1882        temperature: request.temperature,
1883        thinking_allowed: false,
1884        max_tokens: None,
1885    }
1886}
1887
1888fn convert_completion_event(
1889    event: LlmCompletionEvent,
1890) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1891    match event {
1892        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1893            message_id: String::new(),
1894        }),
1895        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1896        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1897            text: thinking.text,
1898            signature: thinking.signature,
1899        }),
1900        LlmCompletionEvent::RedactedThinking(data) => {
1901            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1902        }
1903        LlmCompletionEvent::ToolUse(tool_use) => {
1904            let raw_input = tool_use.input.clone();
1905            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1906            Ok(LanguageModelCompletionEvent::ToolUse(
1907                LanguageModelToolUse {
1908                    id: LanguageModelToolUseId::from(tool_use.id),
1909                    name: tool_use.name.into(),
1910                    raw_input,
1911                    input,
1912                    is_input_complete: tool_use.is_input_complete,
1913                    thought_signature: tool_use.thought_signature,
1914                },
1915            ))
1916        }
1917        LlmCompletionEvent::ToolUseJsonParseError(error) => {
1918            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1919                id: LanguageModelToolUseId::from(error.id),
1920                tool_name: error.tool_name.into(),
1921                raw_input: error.raw_input.into(),
1922                json_parse_error: error.error,
1923            })
1924        }
1925        LlmCompletionEvent::Stop(reason) => {
1926            let stop_reason = match reason {
1927                LlmStopReason::EndTurn => StopReason::EndTurn,
1928                LlmStopReason::MaxTokens => StopReason::MaxTokens,
1929                LlmStopReason::ToolUse => StopReason::ToolUse,
1930                LlmStopReason::Refusal => StopReason::Refusal,
1931            };
1932            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1933        }
1934        LlmCompletionEvent::Usage(usage) => {
1935            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1936                input_tokens: usage.input_tokens,
1937                output_tokens: usage.output_tokens,
1938                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1939                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1940            }))
1941        }
1942        LlmCompletionEvent::ReasoningDetails(json) => {
1943            Ok(LanguageModelCompletionEvent::ReasoningDetails(
1944                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1945            ))
1946        }
1947    }
1948}