llm_provider.rs

   1use crate::ExtensionSettings;
   2use crate::wasm_host::WasmExtension;
   3
   4use crate::wasm_host::wit::{
   5    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
   6    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
   7    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
   8    LlmToolUse,
   9};
  10use anyhow::{Result, anyhow};
  11use credentials_provider::CredentialsProvider;
  12use editor::Editor;
  13use extension::{LanguageModelAuthConfig, OAuthConfig};
  14use futures::future::BoxFuture;
  15use futures::stream::BoxStream;
  16use futures::{FutureExt, StreamExt};
  17use gpui::Focusable;
  18use gpui::{
  19    AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
  20    MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
  21};
  22use language_model::tool_schema::LanguageModelToolSchemaFormat;
  23use language_model::{
  24    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
  25    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  27    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  28    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
  29};
  30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
  31use settings::Settings;
  32use std::sync::Arc;
  33use theme::ThemeSettings;
  34use ui::{Label, LabelSize, prelude::*};
  35use util::ResultExt as _;
  36
/// An extension-based language model provider.
///
/// Adapts one LLM provider exposed by a WASM extension to Zed's [`LanguageModelProvider`]
/// interface. The provider is identified by `<extension id>:<provider id>`.
pub struct ExtensionLanguageModelProvider {
    /// The extension that implements this provider.
    pub extension: WasmExtension,
    /// Provider metadata from the extension; its `id` and `name` feed the provider ID and
    /// display name.
    pub provider_info: LlmProviderInfo,
    /// Optional custom icon path, returned verbatim by `icon_path()`.
    icon_path: Option<SharedString>,
    /// Authentication options declared by the extension (e.g. an environment variable name or an
    /// OAuth configuration), if any.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Observable state shared with this provider's configuration view.
    state: Entity<ExtensionLlmProviderState>,
}
  45
/// Mutable state shared between an [`ExtensionLanguageModelProvider`] and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether the provider is currently considered to have usable credentials.
    is_authenticated: bool,
    /// Models supplied when the provider was created.
    available_models: Vec<LlmModelInfo>,
    /// Whether the user has allowed this provider to read its API key from an environment
    /// variable (mirrors `allowed_env_var_providers` in the extension settings).
    env_var_allowed: bool,
    /// Whether the API key is currently being taken from that environment variable.
    api_key_from_env: bool,
}
  52
// Lets other entities `cx.subscribe` to the provider state.
// NOTE(review): nothing visible here calls `cx.emit(())` on this entity; state changes are
// signalled with `cx.notify()`, so listeners that need to react should use `cx.observe`.
impl EventEmitter<()> for ExtensionLlmProviderState {}
  54
  55impl ExtensionLanguageModelProvider {
  56    pub fn new(
  57        extension: WasmExtension,
  58        provider_info: LlmProviderInfo,
  59        models: Vec<LlmModelInfo>,
  60        is_authenticated: bool,
  61        icon_path: Option<SharedString>,
  62        auth_config: Option<LanguageModelAuthConfig>,
  63        cx: &mut App,
  64    ) -> Self {
  65        let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
  66        let env_var_allowed = ExtensionSettings::get_global(cx)
  67            .allowed_env_var_providers
  68            .contains(provider_id_string.as_str());
  69
  70        let (is_authenticated, api_key_from_env) =
  71            if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
  72                let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
  73                if let Ok(value) = std::env::var(env_var_name) {
  74                    if !value.is_empty() {
  75                        (true, true)
  76                    } else {
  77                        (is_authenticated, false)
  78                    }
  79                } else {
  80                    (is_authenticated, false)
  81                }
  82            } else {
  83                (is_authenticated, false)
  84            };
  85
  86        let state = cx.new(|_| ExtensionLlmProviderState {
  87            is_authenticated,
  88            available_models: models,
  89            env_var_allowed,
  90            api_key_from_env,
  91        });
  92
  93        Self {
  94            extension,
  95            provider_info,
  96            icon_path,
  97            auth_config,
  98            state,
  99        }
 100    }
 101
 102    fn provider_id_string(&self) -> String {
 103        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
 104    }
 105
 106    /// The credential key used for storing the API key in the system keychain.
 107    fn credential_key(&self) -> String {
 108        format!("extension-llm-{}", self.provider_id_string())
 109    }
 110}
 111
 112impl LanguageModelProvider for ExtensionLanguageModelProvider {
 113    fn id(&self) -> LanguageModelProviderId {
 114        LanguageModelProviderId::from(self.provider_id_string())
 115    }
 116
 117    fn name(&self) -> LanguageModelProviderName {
 118        LanguageModelProviderName::from(format!("(Extension) {}", self.provider_info.name))
 119    }
 120
 121    fn icon(&self) -> ui::IconName {
 122        ui::IconName::ZedAssistant
 123    }
 124
 125    fn icon_path(&self) -> Option<SharedString> {
 126        self.icon_path.clone()
 127    }
 128
 129    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 130        let state = self.state.read(cx);
 131        state
 132            .available_models
 133            .iter()
 134            .find(|m| m.is_default)
 135            .or_else(|| state.available_models.first())
 136            .map(|model_info| {
 137                Arc::new(ExtensionLanguageModel {
 138                    extension: self.extension.clone(),
 139                    model_info: model_info.clone(),
 140                    provider_id: self.id(),
 141                    provider_name: self.name(),
 142                    provider_info: self.provider_info.clone(),
 143                }) as Arc<dyn LanguageModel>
 144            })
 145    }
 146
 147    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 148        let state = self.state.read(cx);
 149        state
 150            .available_models
 151            .iter()
 152            .find(|m| m.is_default_fast)
 153            .map(|model_info| {
 154                Arc::new(ExtensionLanguageModel {
 155                    extension: self.extension.clone(),
 156                    model_info: model_info.clone(),
 157                    provider_id: self.id(),
 158                    provider_name: self.name(),
 159                    provider_info: self.provider_info.clone(),
 160                }) as Arc<dyn LanguageModel>
 161            })
 162    }
 163
 164    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 165        let state = self.state.read(cx);
 166        state
 167            .available_models
 168            .iter()
 169            .map(|model_info| {
 170                Arc::new(ExtensionLanguageModel {
 171                    extension: self.extension.clone(),
 172                    model_info: model_info.clone(),
 173                    provider_id: self.id(),
 174                    provider_name: self.name(),
 175                    provider_info: self.provider_info.clone(),
 176                }) as Arc<dyn LanguageModel>
 177            })
 178            .collect()
 179    }
 180
 181    fn is_authenticated(&self, cx: &App) -> bool {
 182        // First check cached state
 183        if self.state.read(cx).is_authenticated {
 184            return true;
 185        }
 186
 187        // Also check env var dynamically (in case migration happened after provider creation)
 188        if let Some(ref auth_config) = self.auth_config {
 189            if let Some(ref env_var_name) = auth_config.env_var {
 190                let provider_id_string = self.provider_id_string();
 191                let env_var_allowed = ExtensionSettings::get_global(cx)
 192                    .allowed_env_var_providers
 193                    .contains(provider_id_string.as_str());
 194
 195                if env_var_allowed {
 196                    if let Ok(value) = std::env::var(env_var_name) {
 197                        if !value.is_empty() {
 198                            return true;
 199                        }
 200                    }
 201                }
 202            }
 203        }
 204
 205        false
 206    }
 207
 208    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 209        let extension = self.extension.clone();
 210        let provider_id = self.provider_info.id.clone();
 211        let state = self.state.clone();
 212
 213        cx.spawn(async move |cx| {
 214            let result = extension
 215                .call(|extension, store| {
 216                    async move {
 217                        extension
 218                            .call_llm_provider_authenticate(store, &provider_id)
 219                            .await
 220                    }
 221                    .boxed()
 222                })
 223                .await;
 224
 225            match result {
 226                Ok(Ok(Ok(()))) => {
 227                    cx.update(|cx| {
 228                        state.update(cx, |state, _| {
 229                            state.is_authenticated = true;
 230                        });
 231                    })?;
 232                    Ok(())
 233                }
 234                Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
 235                Ok(Err(e)) => Err(AuthenticateError::Other(e)),
 236                Err(e) => Err(AuthenticateError::Other(e)),
 237            }
 238        })
 239    }
 240
 241    fn configuration_view(
 242        &self,
 243        _target_agent: ConfigurationViewTargetAgent,
 244        window: &mut Window,
 245        cx: &mut App,
 246    ) -> AnyView {
 247        let credential_key = self.credential_key();
 248        let extension = self.extension.clone();
 249        let extension_provider_id = self.provider_info.id.clone();
 250        let full_provider_id = self.provider_id_string();
 251        let state = self.state.clone();
 252        let auth_config = self.auth_config.clone();
 253
 254        cx.new(|cx| {
 255            ExtensionProviderConfigurationView::new(
 256                credential_key,
 257                extension,
 258                extension_provider_id,
 259                full_provider_id,
 260                auth_config,
 261                state,
 262                window,
 263                cx,
 264            )
 265        })
 266        .into()
 267    }
 268
 269    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 270        let extension = self.extension.clone();
 271        let provider_id = self.provider_info.id.clone();
 272        let state = self.state.clone();
 273        let credential_key = self.credential_key();
 274
 275        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 276
 277        cx.spawn(async move |cx| {
 278            // Delete from system keychain
 279            credentials_provider
 280                .delete_credentials(&credential_key, cx)
 281                .await
 282                .log_err();
 283
 284            // Call extension's reset_credentials
 285            let result = extension
 286                .call(|extension, store| {
 287                    async move {
 288                        extension
 289                            .call_llm_provider_reset_credentials(store, &provider_id)
 290                            .await
 291                    }
 292                    .boxed()
 293                })
 294                .await;
 295
 296            // Update state
 297            cx.update(|cx| {
 298                state.update(cx, |state, _| {
 299                    state.is_authenticated = false;
 300                });
 301            })?;
 302
 303            match result {
 304                Ok(Ok(Ok(()))) => Ok(()),
 305                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
 306                Ok(Err(e)) => Err(e),
 307                Err(e) => Err(e),
 308            }
 309        })
 310    }
 311}
 312
 313impl LanguageModelProviderState for ExtensionLanguageModelProvider {
 314    type ObservableEntity = ExtensionLlmProviderState;
 315
 316    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 317        Some(self.state.clone())
 318    }
 319
 320    fn subscribe<T: 'static>(
 321        &self,
 322        cx: &mut Context<T>,
 323        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
 324    ) -> Option<Subscription> {
 325        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
 326    }
 327}
 328
/// Configuration view for extension-based LLM providers.
///
/// Shows the extension's settings markdown and the authentication controls (environment-variable
/// opt-in, API key entry, and OAuth device-flow sign-in when configured).
struct ExtensionProviderConfigurationView {
    /// Keychain key under which this provider's API key is stored.
    credential_key: String,
    extension: WasmExtension,
    /// The provider ID as passed to the extension (without the extension-ID prefix).
    extension_provider_id: String,
    /// `<extension id>:<provider id>`, as stored in `allowed_env_var_providers`.
    full_provider_id: String,
    auth_config: Option<LanguageModelAuthConfig>,
    /// State shared with the owning provider.
    state: Entity<ExtensionLlmProviderState>,
    /// Settings text from the extension, once loaded (stays `None` if there is none).
    settings_markdown: Option<Entity<Markdown>>,
    api_key_editor: Entity<Editor>,
    /// True until the settings-markdown request completes.
    loading_settings: bool,
    /// True until the stored-credentials check completes.
    loading_credentials: bool,
    /// True while an OAuth device flow is running (guards against starting a second one).
    oauth_in_progress: bool,
    /// Error message from the most recent OAuth attempt, if it failed.
    oauth_error: Option<String>,
    /// User code returned by the extension while the device flow is being polled.
    device_user_code: Option<String>,
    /// Keeps subscriptions alive for the lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
 346
 347impl ExtensionProviderConfigurationView {
 348    fn new(
 349        credential_key: String,
 350        extension: WasmExtension,
 351        extension_provider_id: String,
 352        full_provider_id: String,
 353        auth_config: Option<LanguageModelAuthConfig>,
 354        state: Entity<ExtensionLlmProviderState>,
 355        window: &mut Window,
 356        cx: &mut Context<Self>,
 357    ) -> Self {
 358        // Subscribe to state changes
 359        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
 360            cx.notify();
 361        });
 362
 363        // Create API key editor
 364        let api_key_editor = cx.new(|cx| {
 365            let mut editor = Editor::single_line(window, cx);
 366            editor.set_placeholder_text("Enter API key...", window, cx);
 367            editor
 368        });
 369
 370        let mut this = Self {
 371            credential_key,
 372            extension,
 373            extension_provider_id,
 374            full_provider_id,
 375            auth_config,
 376            state,
 377            settings_markdown: None,
 378            api_key_editor,
 379            loading_settings: true,
 380            loading_credentials: true,
 381            oauth_in_progress: false,
 382            oauth_error: None,
 383            device_user_code: None,
 384            _subscriptions: vec![state_subscription],
 385        };
 386
 387        // Load settings text from extension
 388        this.load_settings_text(cx);
 389
 390        // Load existing credentials
 391        this.load_credentials(cx);
 392
 393        this
 394    }
 395
 396    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
 397        let extension = self.extension.clone();
 398        let provider_id = self.extension_provider_id.clone();
 399
 400        cx.spawn(async move |this, cx| {
 401            let result = extension
 402                .call({
 403                    let provider_id = provider_id.clone();
 404                    |ext, store| {
 405                        async move {
 406                            ext.call_llm_provider_settings_markdown(store, &provider_id)
 407                                .await
 408                        }
 409                        .boxed()
 410                    }
 411                })
 412                .await;
 413
 414            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
 415
 416            this.update(cx, |this, cx| {
 417                this.loading_settings = false;
 418                if let Some(text) = settings_text {
 419                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
 420                    this.settings_markdown = Some(markdown);
 421                }
 422                cx.notify();
 423            })
 424            .log_err();
 425        })
 426        .detach();
 427    }
 428
 429    fn load_credentials(&mut self, cx: &mut Context<Self>) {
 430        let credential_key = self.credential_key.clone();
 431        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 432        let state = self.state.clone();
 433
 434        // Check if we should use env var (already set in state during provider construction)
 435        let api_key_from_env = self.state.read(cx).api_key_from_env;
 436
 437        cx.spawn(async move |this, cx| {
 438            // If using env var, we're already authenticated
 439            if api_key_from_env {
 440                this.update(cx, |this, cx| {
 441                    this.loading_credentials = false;
 442                    cx.notify();
 443                })
 444                .log_err();
 445                return;
 446            }
 447
 448            let credentials = credentials_provider
 449                .read_credentials(&credential_key, cx)
 450                .await
 451                .log_err()
 452                .flatten();
 453
 454            let has_credentials = credentials.is_some();
 455
 456            // Update authentication state based on stored credentials
 457            let _ = cx.update(|cx| {
 458                state.update(cx, |state, cx| {
 459                    state.is_authenticated = has_credentials;
 460                    cx.notify();
 461                });
 462            });
 463
 464            this.update(cx, |this, cx| {
 465                this.loading_credentials = false;
 466                cx.notify();
 467            })
 468            .log_err();
 469        })
 470        .detach();
 471    }
 472
 473    fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
 474        let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
 475        let env_var_name = match &self.auth_config {
 476            Some(config) => config.env_var.clone(),
 477            None => return,
 478        };
 479
 480        let state = self.state.clone();
 481        let currently_allowed = self.state.read(cx).env_var_allowed;
 482
 483        // Update settings file
 484        settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
 485            let providers = settings
 486                .extension
 487                .allowed_env_var_providers
 488                .get_or_insert_with(Vec::new);
 489
 490            if currently_allowed {
 491                providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
 492            } else {
 493                if !providers
 494                    .iter()
 495                    .any(|id| id.as_ref() == full_provider_id.as_ref())
 496                {
 497                    providers.push(full_provider_id.clone());
 498                }
 499            }
 500        });
 501
 502        // Update local state
 503        let new_allowed = !currently_allowed;
 504        let new_from_env = if new_allowed {
 505            if let Some(var_name) = &env_var_name {
 506                if let Ok(value) = std::env::var(var_name) {
 507                    !value.is_empty()
 508                } else {
 509                    false
 510                }
 511            } else {
 512                false
 513            }
 514        } else {
 515            false
 516        };
 517
 518        state.update(cx, |state, cx| {
 519            state.env_var_allowed = new_allowed;
 520            state.api_key_from_env = new_from_env;
 521            if new_from_env {
 522                state.is_authenticated = true;
 523            }
 524            cx.notify();
 525        });
 526
 527        // If env var is being enabled, clear any stored keychain credentials
 528        // so there's only one source of truth for the API key
 529        if new_allowed {
 530            let credential_key = self.credential_key.clone();
 531            let credentials_provider = <dyn CredentialsProvider>::global(cx);
 532            cx.spawn(async move |_this, cx| {
 533                credentials_provider
 534                    .delete_credentials(&credential_key, cx)
 535                    .await
 536                    .log_err();
 537            })
 538            .detach();
 539        }
 540
 541        // If env var is being disabled, reload credentials from keychain
 542        if !new_allowed {
 543            self.reload_keychain_credentials(cx);
 544        }
 545
 546        cx.notify();
 547    }
 548
 549    fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
 550        let credential_key = self.credential_key.clone();
 551        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 552        let state = self.state.clone();
 553
 554        cx.spawn(async move |_this, cx| {
 555            let credentials = credentials_provider
 556                .read_credentials(&credential_key, cx)
 557                .await
 558                .log_err()
 559                .flatten();
 560
 561            let has_credentials = credentials.is_some();
 562
 563            let _ = cx.update(|cx| {
 564                state.update(cx, |state, cx| {
 565                    state.is_authenticated = has_credentials;
 566                    cx.notify();
 567                });
 568            });
 569        })
 570        .detach();
 571    }
 572
 573    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 574        let api_key = self.api_key_editor.read(cx).text(cx);
 575        if api_key.is_empty() {
 576            return;
 577        }
 578
 579        // Clear the editor
 580        self.api_key_editor
 581            .update(cx, |editor, cx| editor.set_text("", window, cx));
 582
 583        let credential_key = self.credential_key.clone();
 584        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 585        let state = self.state.clone();
 586
 587        cx.spawn(async move |_this, cx| {
 588            // Store in system keychain
 589            credentials_provider
 590                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
 591                .await
 592                .log_err();
 593
 594            // Update state to authenticated
 595            let _ = cx.update(|cx| {
 596                state.update(cx, |state, cx| {
 597                    state.is_authenticated = true;
 598                    cx.notify();
 599                });
 600            });
 601        })
 602        .detach();
 603    }
 604
 605    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 606        // Clear the editor
 607        self.api_key_editor
 608            .update(cx, |editor, cx| editor.set_text("", window, cx));
 609
 610        let credential_key = self.credential_key.clone();
 611        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 612        let state = self.state.clone();
 613
 614        cx.spawn(async move |_this, cx| {
 615            // Delete from system keychain
 616            credentials_provider
 617                .delete_credentials(&credential_key, cx)
 618                .await
 619                .log_err();
 620
 621            // Update state to unauthenticated
 622            let _ = cx.update(|cx| {
 623                state.update(cx, |state, cx| {
 624                    state.is_authenticated = false;
 625                    cx.notify();
 626                });
 627            });
 628        })
 629        .detach();
 630    }
 631
 632    fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
 633        if self.oauth_in_progress {
 634            return;
 635        }
 636
 637        self.oauth_in_progress = true;
 638        self.oauth_error = None;
 639        self.device_user_code = None;
 640        cx.notify();
 641
 642        let extension = self.extension.clone();
 643        let provider_id = self.extension_provider_id.clone();
 644        let state = self.state.clone();
 645
 646        cx.spawn(async move |this, cx| {
 647            // Step 1: Start device flow - opens browser and returns user code
 648            let start_result = extension
 649                .call({
 650                    let provider_id = provider_id.clone();
 651                    |ext, store| {
 652                        async move {
 653                            ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
 654                                .await
 655                        }
 656                        .boxed()
 657                    }
 658                })
 659                .await;
 660
 661            let user_code = match start_result {
 662                Ok(Ok(Ok(code))) => code,
 663                Ok(Ok(Err(e))) => {
 664                    log::error!("Device flow start failed: {}", e);
 665                    this.update(cx, |this, cx| {
 666                        this.oauth_in_progress = false;
 667                        this.oauth_error = Some(e);
 668                        cx.notify();
 669                    })
 670                    .log_err();
 671                    return;
 672                }
 673                Ok(Err(e)) | Err(e) => {
 674                    log::error!("Device flow start error: {}", e);
 675                    this.update(cx, |this, cx| {
 676                        this.oauth_in_progress = false;
 677                        this.oauth_error = Some(e.to_string());
 678                        cx.notify();
 679                    })
 680                    .log_err();
 681                    return;
 682                }
 683            };
 684
 685            // Update UI to show the user code before polling
 686            this.update(cx, |this, cx| {
 687                this.device_user_code = Some(user_code);
 688                cx.notify();
 689            })
 690            .log_err();
 691
 692            // Step 2: Poll for authentication completion
 693            let poll_result = extension
 694                .call({
 695                    let provider_id = provider_id.clone();
 696                    |ext, store| {
 697                        async move {
 698                            ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
 699                                .await
 700                        }
 701                        .boxed()
 702                    }
 703                })
 704                .await;
 705
 706            let error_message = match poll_result {
 707                Ok(Ok(Ok(()))) => {
 708                    let _ = cx.update(|cx| {
 709                        state.update(cx, |state, cx| {
 710                            state.is_authenticated = true;
 711                            cx.notify();
 712                        });
 713                    });
 714                    None
 715                }
 716                Ok(Ok(Err(e))) => {
 717                    log::error!("Device flow poll failed: {}", e);
 718                    Some(e)
 719                }
 720                Ok(Err(e)) | Err(e) => {
 721                    log::error!("Device flow poll error: {}", e);
 722                    Some(e.to_string())
 723                }
 724            };
 725
 726            this.update(cx, |this, cx| {
 727                this.oauth_in_progress = false;
 728                this.oauth_error = error_message;
 729                this.device_user_code = None;
 730                cx.notify();
 731            })
 732            .log_err();
 733        })
 734        .detach();
 735    }
 736
 737    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
 738        self.state.read(cx).is_authenticated
 739    }
 740
 741    fn has_oauth_config(&self) -> bool {
 742        self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
 743    }
 744
 745    fn oauth_config(&self) -> Option<&OAuthConfig> {
 746        self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
 747    }
 748
 749    fn has_api_key_config(&self) -> bool {
 750        // API key is available if there's a credential_label or no oauth-only config
 751        self.auth_config
 752            .as_ref()
 753            .map(|c| c.credential_label.is_some() || c.oauth.is_none())
 754            .unwrap_or(true)
 755    }
 756}
 757
 758impl gpui::Render for ExtensionProviderConfigurationView {
 759    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 760        let is_loading = self.loading_settings || self.loading_credentials;
 761        let is_authenticated = self.is_authenticated(cx);
 762        let env_var_allowed = self.state.read(cx).env_var_allowed;
 763        let api_key_from_env = self.state.read(cx).api_key_from_env;
 764        let has_oauth = self.has_oauth_config();
 765        let has_api_key = self.has_api_key_config();
 766
 767        if is_loading {
 768            return v_flex()
 769                .gap_2()
 770                .child(Label::new("Loading...").color(Color::Muted))
 771                .into_any_element();
 772        }
 773
 774        let mut content = v_flex().gap_4().size_full();
 775
 776        // Render settings markdown if available
 777        if let Some(markdown) = &self.settings_markdown {
 778            let style = settings_markdown_style(_window, cx);
 779            content = content.child(
 780                div()
 781                    .p_2()
 782                    .rounded_md()
 783                    .bg(cx.theme().colors().surface_background)
 784                    .child(MarkdownElement::new(markdown.clone(), style)),
 785            );
 786        }
 787
 788        // Render env var checkbox if the extension specifies an env var
 789        if let Some(auth_config) = &self.auth_config {
 790            if let Some(env_var_name) = &auth_config.env_var {
 791                let env_var_name = env_var_name.clone();
 792                let checkbox_label =
 793                    format!("Read API key from {} environment variable", env_var_name);
 794
 795                content = content.child(
 796                    h_flex()
 797                        .gap_2()
 798                        .child(
 799                            ui::Checkbox::new("env-var-permission", env_var_allowed.into())
 800                                .on_click(cx.listener(|this, _, _window, cx| {
 801                                    this.toggle_env_var_permission(cx);
 802                                })),
 803                        )
 804                        .child(Label::new(checkbox_label).size(LabelSize::Small)),
 805                );
 806
 807                // Show status if env var is allowed
 808                if env_var_allowed {
 809                    if api_key_from_env {
 810                        content = content.child(
 811                            h_flex()
 812                                .gap_2()
 813                                .child(
 814                                    ui::Icon::new(ui::IconName::Check)
 815                                        .color(Color::Success)
 816                                        .size(ui::IconSize::Small),
 817                                )
 818                                .child(
 819                                    Label::new(format!("API key loaded from {}", env_var_name))
 820                                        .color(Color::Success),
 821                                ),
 822                        );
 823                        return content.into_any_element();
 824                    } else {
 825                        content = content.child(
 826                            h_flex()
 827                                .gap_2()
 828                                .child(
 829                                    ui::Icon::new(ui::IconName::Warning)
 830                                        .color(Color::Warning)
 831                                        .size(ui::IconSize::Small),
 832                                )
 833                                .child(
 834                                    Label::new(format!(
 835                                        "{} is not set or empty. You can set it and restart Zed, or use another authentication method below.",
 836                                        env_var_name
 837                                    ))
 838                                    .color(Color::Warning)
 839                                    .size(LabelSize::Small),
 840                                ),
 841                        );
 842                    }
 843                }
 844            }
 845        }
 846
 847        // If authenticated, show success state with sign out option
 848        if is_authenticated && !api_key_from_env {
 849            let reset_label = if has_oauth && !has_api_key {
 850                "Sign Out"
 851            } else {
 852                "Reset Credentials"
 853            };
 854
 855            let status_label = if has_oauth && !has_api_key {
 856                "Signed in"
 857            } else {
 858                "Authenticated"
 859            };
 860
 861            content = content.child(
 862                v_flex()
 863                    .gap_2()
 864                    .child(
 865                        h_flex()
 866                            .gap_2()
 867                            .child(
 868                                ui::Icon::new(ui::IconName::Check)
 869                                    .color(Color::Success)
 870                                    .size(ui::IconSize::Small),
 871                            )
 872                            .child(Label::new(status_label).color(Color::Success)),
 873                    )
 874                    .child(
 875                        ui::Button::new("reset-credentials", reset_label)
 876                            .style(ui::ButtonStyle::Subtle)
 877                            .on_click(cx.listener(|this, _, window, cx| {
 878                                this.reset_api_key(window, cx);
 879                            })),
 880                    ),
 881            );
 882
 883            return content.into_any_element();
 884        }
 885
 886        // Not authenticated - show available auth options
 887        if !api_key_from_env {
 888            // Render OAuth sign-in button if configured
 889            if has_oauth {
 890                let oauth_config = self.oauth_config();
 891                let button_label = oauth_config
 892                    .and_then(|c| c.sign_in_button_label.clone())
 893                    .unwrap_or_else(|| "Sign In".to_string());
 894
 895                let oauth_in_progress = self.oauth_in_progress;
 896
 897                let oauth_error = self.oauth_error.clone();
 898
 899                content = content.child(
 900                    v_flex()
 901                        .gap_2()
 902                        .child(
 903                            ui::Button::new("oauth-sign-in", button_label)
 904                                .style(ui::ButtonStyle::Filled)
 905                                .disabled(oauth_in_progress)
 906                                .on_click(cx.listener(|this, _, _window, cx| {
 907                                    this.start_oauth_sign_in(cx);
 908                                })),
 909                        )
 910                        .when(oauth_in_progress, |this| {
 911                            let user_code = self.device_user_code.clone();
 912                            this.child(
 913                                v_flex()
 914                                    .gap_1()
 915                                    .when_some(user_code, |this, code| {
 916                                        let copied = cx
 917                                            .read_from_clipboard()
 918                                            .map(|item| item.text().as_ref() == Some(&code))
 919                                            .unwrap_or(false);
 920                                        let code_for_click = code.clone();
 921                                        this.child(
 922                                            h_flex()
 923                                                .gap_1()
 924                                                .child(
 925                                                    Label::new("Enter code:")
 926                                                        .size(LabelSize::Small)
 927                                                        .color(Color::Muted),
 928                                                )
 929                                                .child(
 930                                                    h_flex()
 931                                                        .gap_1()
 932                                                        .px_1()
 933                                                        .border_1()
 934                                                        .border_color(cx.theme().colors().border)
 935                                                        .rounded_sm()
 936                                                        .cursor_pointer()
 937                                                        .on_mouse_down(
 938                                                            MouseButton::Left,
 939                                                            move |_, window, cx| {
 940                                                                cx.write_to_clipboard(
 941                                                                    ClipboardItem::new_string(
 942                                                                        code_for_click.clone(),
 943                                                                    ),
 944                                                                );
 945                                                                window.refresh();
 946                                                            },
 947                                                        )
 948                                                        .child(
 949                                                            Label::new(code)
 950                                                                .size(LabelSize::Small)
 951                                                                .color(Color::Accent),
 952                                                        )
 953                                                        .child(
 954                                                            ui::Icon::new(if copied {
 955                                                                ui::IconName::Check
 956                                                            } else {
 957                                                                ui::IconName::Copy
 958                                                            })
 959                                                            .size(ui::IconSize::Small)
 960                                                            .color(if copied {
 961                                                                Color::Success
 962                                                            } else {
 963                                                                Color::Muted
 964                                                            }),
 965                                                        ),
 966                                                ),
 967                                        )
 968                                    })
 969                                    .child(
 970                                        Label::new("Waiting for authorization in browser...")
 971                                            .size(LabelSize::Small)
 972                                            .color(Color::Muted),
 973                                    ),
 974                            )
 975                        })
 976                        .when_some(oauth_error, |this, error| {
 977                            this.child(
 978                                v_flex()
 979                                    .gap_1()
 980                                    .child(
 981                                        h_flex()
 982                                            .gap_2()
 983                                            .child(
 984                                                ui::Icon::new(ui::IconName::Warning)
 985                                                    .color(Color::Error)
 986                                                    .size(ui::IconSize::Small),
 987                                            )
 988                                            .child(
 989                                                Label::new("Authentication failed")
 990                                                    .color(Color::Error)
 991                                                    .size(LabelSize::Small),
 992                                            ),
 993                                    )
 994                                    .child(
 995                                        div().pl_6().child(
 996                                            Label::new(error)
 997                                                .color(Color::Error)
 998                                                .size(LabelSize::Small),
 999                                        ),
1000                                    ),
1001                            )
1002                        }),
1003                );
1004            }
1005
1006            // Render API key input if configured (and we have both options, show a separator)
1007            if has_api_key {
1008                if has_oauth {
1009                    content = content.child(
1010                        h_flex()
1011                            .gap_2()
1012                            .items_center()
1013                            .child(div().h_px().flex_1().bg(cx.theme().colors().border))
1014                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
1015                            .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
1016                    );
1017                }
1018
1019                let credential_label = self
1020                    .auth_config
1021                    .as_ref()
1022                    .and_then(|c| c.credential_label.clone())
1023                    .unwrap_or_else(|| "API Key".to_string());
1024
1025                content = content.child(
1026                    v_flex()
1027                        .gap_2()
1028                        .on_action(cx.listener(Self::save_api_key))
1029                        .child(
1030                            Label::new(credential_label)
1031                                .size(LabelSize::Small)
1032                                .color(Color::Muted),
1033                        )
1034                        .child(self.api_key_editor.clone())
1035                        .child(
1036                            Label::new("Enter your API key and press Enter to save")
1037                                .size(LabelSize::Small)
1038                                .color(Color::Muted),
1039                        ),
1040                );
1041            }
1042        }
1043
1044        content.into_any_element()
1045    }
1046}
1047
1048impl Focusable for ExtensionProviderConfigurationView {
1049    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1050        self.api_key_editor.focus_handle(cx)
1051    }
1052}
1053
1054fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1055    let theme_settings = ThemeSettings::get_global(cx);
1056    let colors = cx.theme().colors();
1057    let mut text_style = window.text_style();
1058    text_style.refine(&TextStyleRefinement {
1059        font_family: Some(theme_settings.ui_font.family.clone()),
1060        font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1061        font_features: Some(theme_settings.ui_font.features.clone()),
1062        color: Some(colors.text),
1063        ..Default::default()
1064    });
1065
1066    MarkdownStyle {
1067        base_text_style: text_style,
1068        selection_background_color: colors.element_selection_background,
1069        inline_code: TextStyleRefinement {
1070            background_color: Some(colors.editor_background),
1071            ..Default::default()
1072        },
1073        link: TextStyleRefinement {
1074            color: Some(colors.text_accent),
1075            underline: Some(UnderlineStyle {
1076                color: Some(colors.text_accent.opacity(0.5)),
1077                thickness: px(1.),
1078                ..Default::default()
1079            }),
1080            ..Default::default()
1081        },
1082        syntax: cx.theme().syntax().clone(),
1083        ..Default::default()
1084    }
1085}
1086
1087/// An extension-based language model.
1088pub struct ExtensionLanguageModel {
1089    extension: WasmExtension,
1090    model_info: LlmModelInfo,
1091    provider_id: LanguageModelProviderId,
1092    provider_name: LanguageModelProviderName,
1093    provider_info: LlmProviderInfo,
1094}
1095
1096impl LanguageModel for ExtensionLanguageModel {
1097    fn id(&self) -> LanguageModelId {
1098        LanguageModelId::from(self.model_info.id.clone())
1099    }
1100
1101    fn name(&self) -> LanguageModelName {
1102        LanguageModelName::from(self.model_info.name.clone())
1103    }
1104
1105    fn provider_id(&self) -> LanguageModelProviderId {
1106        self.provider_id.clone()
1107    }
1108
1109    fn provider_name(&self) -> LanguageModelProviderName {
1110        self.provider_name.clone()
1111    }
1112
1113    fn telemetry_id(&self) -> String {
1114        format!("extension-{}", self.model_info.id)
1115    }
1116
1117    fn supports_images(&self) -> bool {
1118        self.model_info.capabilities.supports_images
1119    }
1120
1121    fn supports_tools(&self) -> bool {
1122        self.model_info.capabilities.supports_tools
1123    }
1124
1125    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1126        match choice {
1127            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1128            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1129            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1130        }
1131    }
1132
1133    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1134        match self.model_info.capabilities.tool_input_format {
1135            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1136            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1137        }
1138    }
1139
1140    fn max_token_count(&self) -> u64 {
1141        self.model_info.max_token_count
1142    }
1143
1144    fn max_output_tokens(&self) -> Option<u64> {
1145        self.model_info.max_output_tokens
1146    }
1147
1148    fn count_tokens(
1149        &self,
1150        request: LanguageModelRequest,
1151        cx: &App,
1152    ) -> BoxFuture<'static, Result<u64>> {
1153        let extension = self.extension.clone();
1154        let provider_id = self.provider_info.id.clone();
1155        let model_id = self.model_info.id.clone();
1156
1157        let wit_request = convert_request_to_wit(request);
1158
1159        cx.background_spawn(async move {
1160            extension
1161                .call({
1162                    let provider_id = provider_id.clone();
1163                    let model_id = model_id.clone();
1164                    let wit_request = wit_request.clone();
1165                    |ext, store| {
1166                        async move {
1167                            let count = ext
1168                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1169                                .await?
1170                                .map_err(|e| anyhow!("{}", e))?;
1171                            Ok(count)
1172                        }
1173                        .boxed()
1174                    }
1175                })
1176                .await?
1177        })
1178        .boxed()
1179    }
1180
1181    fn stream_completion(
1182        &self,
1183        request: LanguageModelRequest,
1184        _cx: &AsyncApp,
1185    ) -> BoxFuture<
1186        'static,
1187        Result<
1188            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1189            LanguageModelCompletionError,
1190        >,
1191    > {
1192        let extension = self.extension.clone();
1193        let provider_id = self.provider_info.id.clone();
1194        let model_id = self.model_info.id.clone();
1195
1196        let wit_request = convert_request_to_wit(request);
1197
1198        async move {
1199            // Start the stream
1200            let stream_id_result = extension
1201                .call({
1202                    let provider_id = provider_id.clone();
1203                    let model_id = model_id.clone();
1204                    let wit_request = wit_request.clone();
1205                    |ext, store| {
1206                        async move {
1207                            let id = ext
1208                                .call_llm_stream_completion_start(
1209                                    store,
1210                                    &provider_id,
1211                                    &model_id,
1212                                    &wit_request,
1213                                )
1214                                .await?
1215                                .map_err(|e| anyhow!("{}", e))?;
1216                            Ok(id)
1217                        }
1218                        .boxed()
1219                    }
1220                })
1221                .await;
1222
1223            let stream_id = stream_id_result
1224                .map_err(LanguageModelCompletionError::Other)?
1225                .map_err(LanguageModelCompletionError::Other)?;
1226
1227            // Create a stream that polls for events
1228            let stream = futures::stream::unfold(
1229                (extension.clone(), stream_id, false),
1230                move |(extension, stream_id, done)| async move {
1231                    if done {
1232                        return None;
1233                    }
1234
1235                    let result = extension
1236                        .call({
1237                            let stream_id = stream_id.clone();
1238                            |ext, store| {
1239                                async move {
1240                                    let event = ext
1241                                        .call_llm_stream_completion_next(store, &stream_id)
1242                                        .await?
1243                                        .map_err(|e| anyhow!("{}", e))?;
1244                                    Ok(event)
1245                                }
1246                                .boxed()
1247                            }
1248                        })
1249                        .await
1250                        .and_then(|inner| inner);
1251
1252                    match result {
1253                        Ok(Some(event)) => {
1254                            let converted = convert_completion_event(event);
1255                            let is_done =
1256                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1257                            Some((converted, (extension, stream_id, is_done)))
1258                        }
1259                        Ok(None) => {
1260                            // Stream complete, close it
1261                            let _ = extension
1262                                .call({
1263                                    let stream_id = stream_id.clone();
1264                                    |ext, store| {
1265                                        async move {
1266                                            ext.call_llm_stream_completion_close(store, &stream_id)
1267                                                .await?;
1268                                            Ok::<(), anyhow::Error>(())
1269                                        }
1270                                        .boxed()
1271                                    }
1272                                })
1273                                .await;
1274                            None
1275                        }
1276                        Err(e) => Some((
1277                            Err(LanguageModelCompletionError::Other(e)),
1278                            (extension, stream_id, true),
1279                        )),
1280                    }
1281                },
1282            );
1283
1284            Ok(stream.boxed())
1285        }
1286        .boxed()
1287    }
1288
1289    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1290        // Extensions can implement this via llm_cache_configuration
1291        None
1292    }
1293}
1294
1295fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1296    use language_model::{MessageContent, Role};
1297
1298    let messages: Vec<LlmRequestMessage> = request
1299        .messages
1300        .into_iter()
1301        .map(|msg| {
1302            let role = match msg.role {
1303                Role::User => LlmMessageRole::User,
1304                Role::Assistant => LlmMessageRole::Assistant,
1305                Role::System => LlmMessageRole::System,
1306            };
1307
1308            let content: Vec<LlmMessageContent> = msg
1309                .content
1310                .into_iter()
1311                .map(|c| match c {
1312                    MessageContent::Text(text) => LlmMessageContent::Text(text),
1313                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1314                        source: image.source.to_string(),
1315                        width: Some(image.size.width.0 as u32),
1316                        height: Some(image.size.height.0 as u32),
1317                    }),
1318                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1319                        id: tool_use.id.to_string(),
1320                        name: tool_use.name.to_string(),
1321                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1322                        thought_signature: tool_use.thought_signature,
1323                    }),
1324                    MessageContent::ToolResult(tool_result) => {
1325                        let content = match tool_result.content {
1326                            language_model::LanguageModelToolResultContent::Text(text) => {
1327                                LlmToolResultContent::Text(text.to_string())
1328                            }
1329                            language_model::LanguageModelToolResultContent::Image(image) => {
1330                                LlmToolResultContent::Image(LlmImageData {
1331                                    source: image.source.to_string(),
1332                                    width: Some(image.size.width.0 as u32),
1333                                    height: Some(image.size.height.0 as u32),
1334                                })
1335                            }
1336                        };
1337                        LlmMessageContent::ToolResult(LlmToolResult {
1338                            tool_use_id: tool_result.tool_use_id.to_string(),
1339                            tool_name: tool_result.tool_name.to_string(),
1340                            is_error: tool_result.is_error,
1341                            content,
1342                        })
1343                    }
1344                    MessageContent::Thinking { text, signature } => {
1345                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1346                    }
1347                    MessageContent::RedactedThinking(data) => {
1348                        LlmMessageContent::RedactedThinking(data)
1349                    }
1350                })
1351                .collect();
1352
1353            LlmRequestMessage {
1354                role,
1355                content,
1356                cache: msg.cache,
1357            }
1358        })
1359        .collect();
1360
1361    let tools: Vec<LlmToolDefinition> = request
1362        .tools
1363        .into_iter()
1364        .map(|tool| LlmToolDefinition {
1365            name: tool.name,
1366            description: tool.description,
1367            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1368        })
1369        .collect();
1370
1371    let tool_choice = request.tool_choice.map(|tc| match tc {
1372        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1373        LanguageModelToolChoice::Any => LlmToolChoice::Any,
1374        LanguageModelToolChoice::None => LlmToolChoice::None,
1375    });
1376
1377    LlmCompletionRequest {
1378        messages,
1379        tools,
1380        tool_choice,
1381        stop_sequences: request.stop,
1382        temperature: request.temperature,
1383        thinking_allowed: false,
1384        max_tokens: None,
1385    }
1386}
1387
1388fn convert_completion_event(
1389    event: LlmCompletionEvent,
1390) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1391    match event {
1392        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1393            message_id: String::new(),
1394        }),
1395        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1396        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1397            text: thinking.text,
1398            signature: thinking.signature,
1399        }),
1400        LlmCompletionEvent::RedactedThinking(data) => {
1401            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1402        }
1403        LlmCompletionEvent::ToolUse(tool_use) => {
1404            let raw_input = tool_use.input.clone();
1405            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1406            Ok(LanguageModelCompletionEvent::ToolUse(
1407                LanguageModelToolUse {
1408                    id: LanguageModelToolUseId::from(tool_use.id),
1409                    name: tool_use.name.into(),
1410                    raw_input,
1411                    input,
1412                    is_input_complete: true,
1413                    thought_signature: tool_use.thought_signature,
1414                },
1415            ))
1416        }
1417        LlmCompletionEvent::ToolUseJsonParseError(error) => {
1418            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1419                id: LanguageModelToolUseId::from(error.id),
1420                tool_name: error.tool_name.into(),
1421                raw_input: error.raw_input.into(),
1422                json_parse_error: error.error,
1423            })
1424        }
1425        LlmCompletionEvent::Stop(reason) => {
1426            let stop_reason = match reason {
1427                LlmStopReason::EndTurn => StopReason::EndTurn,
1428                LlmStopReason::MaxTokens => StopReason::MaxTokens,
1429                LlmStopReason::ToolUse => StopReason::ToolUse,
1430                LlmStopReason::Refusal => StopReason::Refusal,
1431            };
1432            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1433        }
1434        LlmCompletionEvent::Usage(usage) => {
1435            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1436                input_tokens: usage.input_tokens,
1437                output_tokens: usage.output_tokens,
1438                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1439                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1440            }))
1441        }
1442        LlmCompletionEvent::ReasoningDetails(json) => {
1443            Ok(LanguageModelCompletionEvent::ReasoningDetails(
1444                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1445            ))
1446        }
1447    }
1448}