llm_provider.rs

   1use crate::ExtensionSettings;
   2use crate::wasm_host::WasmExtension;
   3
   4use crate::wasm_host::wit::{
   5    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
   6    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
   7    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
   8    LlmToolUse,
   9};
  10use anyhow::{Result, anyhow};
  11use credentials_provider::CredentialsProvider;
  12use editor::Editor;
  13use extension::{LanguageModelAuthConfig, OAuthConfig};
  14use futures::future::BoxFuture;
  15use futures::stream::BoxStream;
  16use futures::{FutureExt, StreamExt};
  17use gpui::Focusable;
  18use gpui::{
  19    AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
  20    MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
  21};
  22use language_model::tool_schema::LanguageModelToolSchemaFormat;
  23use language_model::{
  24    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
  25    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  27    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  28    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
  29};
  30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
  31use settings::Settings;
  32use std::sync::Arc;
  33use theme::ThemeSettings;
  34use ui::{Label, LabelSize, prelude::*};
  35use util::ResultExt as _;
  36
  37/// An extension-based language model provider.
  38pub struct ExtensionLanguageModelProvider {
  39    pub extension: WasmExtension,
  40    pub provider_info: LlmProviderInfo,
  41    icon_path: Option<SharedString>,
  42    auth_config: Option<LanguageModelAuthConfig>,
  43    state: Entity<ExtensionLlmProviderState>,
  44}
  45
  46pub struct ExtensionLlmProviderState {
  47    is_authenticated: bool,
  48    available_models: Vec<LlmModelInfo>,
  49    env_var_allowed: bool,
  50    api_key_from_env: bool,
  51}
  52
// Allows `cx.subscribe` on the state entity. No code in this section emits `()`
// events on it; state changes are signalled with `cx.notify()` instead.
impl EventEmitter<()> for ExtensionLlmProviderState {}
  54
  55impl ExtensionLanguageModelProvider {
  56    pub fn new(
  57        extension: WasmExtension,
  58        provider_info: LlmProviderInfo,
  59        models: Vec<LlmModelInfo>,
  60        is_authenticated: bool,
  61        icon_path: Option<SharedString>,
  62        auth_config: Option<LanguageModelAuthConfig>,
  63        cx: &mut App,
  64    ) -> Self {
  65        let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
  66        let env_var_allowed = ExtensionSettings::get_global(cx)
  67            .allowed_env_var_providers
  68            .contains(provider_id_string.as_str());
  69
  70        let (is_authenticated, api_key_from_env) =
  71            if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
  72                let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
  73                if let Ok(value) = std::env::var(env_var_name) {
  74                    if !value.is_empty() {
  75                        (true, true)
  76                    } else {
  77                        (is_authenticated, false)
  78                    }
  79                } else {
  80                    (is_authenticated, false)
  81                }
  82            } else {
  83                (is_authenticated, false)
  84            };
  85
  86        let state = cx.new(|_| ExtensionLlmProviderState {
  87            is_authenticated,
  88            available_models: models,
  89            env_var_allowed,
  90            api_key_from_env,
  91        });
  92
  93        Self {
  94            extension,
  95            provider_info,
  96            icon_path,
  97            auth_config,
  98            state,
  99        }
 100    }
 101
 102    fn provider_id_string(&self) -> String {
 103        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
 104    }
 105
 106    /// The credential key used for storing the API key in the system keychain.
 107    fn credential_key(&self) -> String {
 108        format!("extension-llm-{}", self.provider_id_string())
 109    }
 110}
 111
 112impl LanguageModelProvider for ExtensionLanguageModelProvider {
 113    fn id(&self) -> LanguageModelProviderId {
 114        LanguageModelProviderId::from(self.provider_id_string())
 115    }
 116
 117    fn name(&self) -> LanguageModelProviderName {
 118        LanguageModelProviderName::from(self.provider_info.name.clone())
 119    }
 120
 121    fn icon(&self) -> ui::IconName {
 122        ui::IconName::ZedAssistant
 123    }
 124
 125    fn icon_path(&self) -> Option<SharedString> {
 126        self.icon_path.clone()
 127    }
 128
 129    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 130        let state = self.state.read(cx);
 131        state
 132            .available_models
 133            .iter()
 134            .find(|m| m.is_default)
 135            .or_else(|| state.available_models.first())
 136            .map(|model_info| {
 137                Arc::new(ExtensionLanguageModel {
 138                    extension: self.extension.clone(),
 139                    model_info: model_info.clone(),
 140                    provider_id: self.id(),
 141                    provider_name: self.name(),
 142                    provider_info: self.provider_info.clone(),
 143                }) as Arc<dyn LanguageModel>
 144            })
 145    }
 146
 147    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 148        let state = self.state.read(cx);
 149        state
 150            .available_models
 151            .iter()
 152            .find(|m| m.is_default_fast)
 153            .map(|model_info| {
 154                Arc::new(ExtensionLanguageModel {
 155                    extension: self.extension.clone(),
 156                    model_info: model_info.clone(),
 157                    provider_id: self.id(),
 158                    provider_name: self.name(),
 159                    provider_info: self.provider_info.clone(),
 160                }) as Arc<dyn LanguageModel>
 161            })
 162    }
 163
 164    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 165        let state = self.state.read(cx);
 166        state
 167            .available_models
 168            .iter()
 169            .map(|model_info| {
 170                Arc::new(ExtensionLanguageModel {
 171                    extension: self.extension.clone(),
 172                    model_info: model_info.clone(),
 173                    provider_id: self.id(),
 174                    provider_name: self.name(),
 175                    provider_info: self.provider_info.clone(),
 176                }) as Arc<dyn LanguageModel>
 177            })
 178            .collect()
 179    }
 180
 181    fn is_authenticated(&self, cx: &App) -> bool {
 182        self.state.read(cx).is_authenticated
 183    }
 184
 185    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 186        let extension = self.extension.clone();
 187        let provider_id = self.provider_info.id.clone();
 188        let state = self.state.clone();
 189
 190        cx.spawn(async move |cx| {
 191            let result = extension
 192                .call(|extension, store| {
 193                    async move {
 194                        extension
 195                            .call_llm_provider_authenticate(store, &provider_id)
 196                            .await
 197                    }
 198                    .boxed()
 199                })
 200                .await;
 201
 202            match result {
 203                Ok(Ok(Ok(()))) => {
 204                    cx.update(|cx| {
 205                        state.update(cx, |state, _| {
 206                            state.is_authenticated = true;
 207                        });
 208                    })?;
 209                    Ok(())
 210                }
 211                Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
 212                Ok(Err(e)) => Err(AuthenticateError::Other(e)),
 213                Err(e) => Err(AuthenticateError::Other(e)),
 214            }
 215        })
 216    }
 217
 218    fn configuration_view(
 219        &self,
 220        _target_agent: ConfigurationViewTargetAgent,
 221        window: &mut Window,
 222        cx: &mut App,
 223    ) -> AnyView {
 224        let credential_key = self.credential_key();
 225        let extension = self.extension.clone();
 226        let extension_provider_id = self.provider_info.id.clone();
 227        let full_provider_id = self.provider_id_string();
 228        let state = self.state.clone();
 229        let auth_config = self.auth_config.clone();
 230
 231        cx.new(|cx| {
 232            ExtensionProviderConfigurationView::new(
 233                credential_key,
 234                extension,
 235                extension_provider_id,
 236                full_provider_id,
 237                auth_config,
 238                state,
 239                window,
 240                cx,
 241            )
 242        })
 243        .into()
 244    }
 245
 246    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 247        let extension = self.extension.clone();
 248        let provider_id = self.provider_info.id.clone();
 249        let state = self.state.clone();
 250        let credential_key = self.credential_key();
 251
 252        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 253
 254        cx.spawn(async move |cx| {
 255            // Delete from system keychain
 256            credentials_provider
 257                .delete_credentials(&credential_key, cx)
 258                .await
 259                .log_err();
 260
 261            // Call extension's reset_credentials
 262            let result = extension
 263                .call(|extension, store| {
 264                    async move {
 265                        extension
 266                            .call_llm_provider_reset_credentials(store, &provider_id)
 267                            .await
 268                    }
 269                    .boxed()
 270                })
 271                .await;
 272
 273            // Update state
 274            cx.update(|cx| {
 275                state.update(cx, |state, _| {
 276                    state.is_authenticated = false;
 277                });
 278            })?;
 279
 280            match result {
 281                Ok(Ok(Ok(()))) => Ok(()),
 282                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
 283                Ok(Err(e)) => Err(e),
 284                Err(e) => Err(e),
 285            }
 286        })
 287    }
 288}
 289
 290impl LanguageModelProviderState for ExtensionLanguageModelProvider {
 291    type ObservableEntity = ExtensionLlmProviderState;
 292
 293    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 294        Some(self.state.clone())
 295    }
 296
 297    fn subscribe<T: 'static>(
 298        &self,
 299        cx: &mut Context<T>,
 300        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
 301    ) -> Option<Subscription> {
 302        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
 303    }
 304}
 305
 306/// Configuration view for extension-based LLM providers.
 307struct ExtensionProviderConfigurationView {
 308    credential_key: String,
 309    extension: WasmExtension,
 310    extension_provider_id: String,
 311    full_provider_id: String,
 312    auth_config: Option<LanguageModelAuthConfig>,
 313    state: Entity<ExtensionLlmProviderState>,
 314    settings_markdown: Option<Entity<Markdown>>,
 315    api_key_editor: Entity<Editor>,
 316    loading_settings: bool,
 317    loading_credentials: bool,
 318    oauth_in_progress: bool,
 319    oauth_error: Option<String>,
 320    device_user_code: Option<String>,
 321    _subscriptions: Vec<Subscription>,
 322}
 323
 324impl ExtensionProviderConfigurationView {
 325    fn new(
 326        credential_key: String,
 327        extension: WasmExtension,
 328        extension_provider_id: String,
 329        full_provider_id: String,
 330        auth_config: Option<LanguageModelAuthConfig>,
 331        state: Entity<ExtensionLlmProviderState>,
 332        window: &mut Window,
 333        cx: &mut Context<Self>,
 334    ) -> Self {
 335        // Subscribe to state changes
 336        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
 337            cx.notify();
 338        });
 339
 340        // Create API key editor
 341        let api_key_editor = cx.new(|cx| {
 342            let mut editor = Editor::single_line(window, cx);
 343            editor.set_placeholder_text("Enter API key...", window, cx);
 344            editor
 345        });
 346
 347        let mut this = Self {
 348            credential_key,
 349            extension,
 350            extension_provider_id,
 351            full_provider_id,
 352            auth_config,
 353            state,
 354            settings_markdown: None,
 355            api_key_editor,
 356            loading_settings: true,
 357            loading_credentials: true,
 358            oauth_in_progress: false,
 359            oauth_error: None,
 360            device_user_code: None,
 361            _subscriptions: vec![state_subscription],
 362        };
 363
 364        // Load settings text from extension
 365        this.load_settings_text(cx);
 366
 367        // Load existing credentials
 368        this.load_credentials(cx);
 369
 370        this
 371    }
 372
 373    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
 374        let extension = self.extension.clone();
 375        let provider_id = self.extension_provider_id.clone();
 376
 377        cx.spawn(async move |this, cx| {
 378            let result = extension
 379                .call({
 380                    let provider_id = provider_id.clone();
 381                    |ext, store| {
 382                        async move {
 383                            ext.call_llm_provider_settings_markdown(store, &provider_id)
 384                                .await
 385                        }
 386                        .boxed()
 387                    }
 388                })
 389                .await;
 390
 391            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
 392
 393            this.update(cx, |this, cx| {
 394                this.loading_settings = false;
 395                if let Some(text) = settings_text {
 396                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
 397                    this.settings_markdown = Some(markdown);
 398                }
 399                cx.notify();
 400            })
 401            .log_err();
 402        })
 403        .detach();
 404    }
 405
 406    fn load_credentials(&mut self, cx: &mut Context<Self>) {
 407        let credential_key = self.credential_key.clone();
 408        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 409        let state = self.state.clone();
 410
 411        // Check if we should use env var (already set in state during provider construction)
 412        let api_key_from_env = self.state.read(cx).api_key_from_env;
 413
 414        cx.spawn(async move |this, cx| {
 415            // If using env var, we're already authenticated
 416            if api_key_from_env {
 417                this.update(cx, |this, cx| {
 418                    this.loading_credentials = false;
 419                    cx.notify();
 420                })
 421                .log_err();
 422                return;
 423            }
 424
 425            let credentials = credentials_provider
 426                .read_credentials(&credential_key, cx)
 427                .await
 428                .log_err()
 429                .flatten();
 430
 431            let has_credentials = credentials.is_some();
 432
 433            // Update authentication state based on stored credentials
 434            let _ = cx.update(|cx| {
 435                state.update(cx, |state, cx| {
 436                    state.is_authenticated = has_credentials;
 437                    cx.notify();
 438                });
 439            });
 440
 441            this.update(cx, |this, cx| {
 442                this.loading_credentials = false;
 443                cx.notify();
 444            })
 445            .log_err();
 446        })
 447        .detach();
 448    }
 449
 450    fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
 451        let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
 452        let env_var_name = match &self.auth_config {
 453            Some(config) => config.env_var.clone(),
 454            None => return,
 455        };
 456
 457        let state = self.state.clone();
 458        let currently_allowed = self.state.read(cx).env_var_allowed;
 459
 460        // Update settings file
 461        settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
 462            let providers = settings
 463                .extension
 464                .allowed_env_var_providers
 465                .get_or_insert_with(Vec::new);
 466
 467            if currently_allowed {
 468                providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
 469            } else {
 470                if !providers
 471                    .iter()
 472                    .any(|id| id.as_ref() == full_provider_id.as_ref())
 473                {
 474                    providers.push(full_provider_id.clone());
 475                }
 476            }
 477        });
 478
 479        // Update local state
 480        let new_allowed = !currently_allowed;
 481        let new_from_env = if new_allowed {
 482            if let Some(var_name) = &env_var_name {
 483                if let Ok(value) = std::env::var(var_name) {
 484                    !value.is_empty()
 485                } else {
 486                    false
 487                }
 488            } else {
 489                false
 490            }
 491        } else {
 492            false
 493        };
 494
 495        state.update(cx, |state, cx| {
 496            state.env_var_allowed = new_allowed;
 497            state.api_key_from_env = new_from_env;
 498            if new_from_env {
 499                state.is_authenticated = true;
 500            }
 501            cx.notify();
 502        });
 503
 504        // If env var is being enabled, clear any stored keychain credentials
 505        // so there's only one source of truth for the API key
 506        if new_allowed {
 507            let credential_key = self.credential_key.clone();
 508            let credentials_provider = <dyn CredentialsProvider>::global(cx);
 509            cx.spawn(async move |_this, cx| {
 510                credentials_provider
 511                    .delete_credentials(&credential_key, cx)
 512                    .await
 513                    .log_err();
 514            })
 515            .detach();
 516        }
 517
 518        // If env var is being disabled, reload credentials from keychain
 519        if !new_allowed {
 520            self.reload_keychain_credentials(cx);
 521        }
 522
 523        cx.notify();
 524    }
 525
 526    fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
 527        let credential_key = self.credential_key.clone();
 528        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 529        let state = self.state.clone();
 530
 531        cx.spawn(async move |_this, cx| {
 532            let credentials = credentials_provider
 533                .read_credentials(&credential_key, cx)
 534                .await
 535                .log_err()
 536                .flatten();
 537
 538            let has_credentials = credentials.is_some();
 539
 540            let _ = cx.update(|cx| {
 541                state.update(cx, |state, cx| {
 542                    state.is_authenticated = has_credentials;
 543                    cx.notify();
 544                });
 545            });
 546        })
 547        .detach();
 548    }
 549
 550    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 551        let api_key = self.api_key_editor.read(cx).text(cx);
 552        if api_key.is_empty() {
 553            return;
 554        }
 555
 556        // Clear the editor
 557        self.api_key_editor
 558            .update(cx, |editor, cx| editor.set_text("", window, cx));
 559
 560        let credential_key = self.credential_key.clone();
 561        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 562        let state = self.state.clone();
 563
 564        cx.spawn(async move |_this, cx| {
 565            // Store in system keychain
 566            credentials_provider
 567                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
 568                .await
 569                .log_err();
 570
 571            // Update state to authenticated
 572            let _ = cx.update(|cx| {
 573                state.update(cx, |state, cx| {
 574                    state.is_authenticated = true;
 575                    cx.notify();
 576                });
 577            });
 578        })
 579        .detach();
 580    }
 581
 582    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 583        // Clear the editor
 584        self.api_key_editor
 585            .update(cx, |editor, cx| editor.set_text("", window, cx));
 586
 587        let credential_key = self.credential_key.clone();
 588        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 589        let state = self.state.clone();
 590
 591        cx.spawn(async move |_this, cx| {
 592            // Delete from system keychain
 593            credentials_provider
 594                .delete_credentials(&credential_key, cx)
 595                .await
 596                .log_err();
 597
 598            // Update state to unauthenticated
 599            let _ = cx.update(|cx| {
 600                state.update(cx, |state, cx| {
 601                    state.is_authenticated = false;
 602                    cx.notify();
 603                });
 604            });
 605        })
 606        .detach();
 607    }
 608
 609    fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
 610        if self.oauth_in_progress {
 611            return;
 612        }
 613
 614        self.oauth_in_progress = true;
 615        self.oauth_error = None;
 616        self.device_user_code = None;
 617        cx.notify();
 618
 619        let extension = self.extension.clone();
 620        let provider_id = self.extension_provider_id.clone();
 621        let state = self.state.clone();
 622
 623        cx.spawn(async move |this, cx| {
 624            // Step 1: Start device flow - opens browser and returns user code
 625            let start_result = extension
 626                .call({
 627                    let provider_id = provider_id.clone();
 628                    |ext, store| {
 629                        async move {
 630                            ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
 631                                .await
 632                        }
 633                        .boxed()
 634                    }
 635                })
 636                .await;
 637
 638            let user_code = match start_result {
 639                Ok(Ok(Ok(code))) => code,
 640                Ok(Ok(Err(e))) => {
 641                    log::error!("Device flow start failed: {}", e);
 642                    this.update(cx, |this, cx| {
 643                        this.oauth_in_progress = false;
 644                        this.oauth_error = Some(e);
 645                        cx.notify();
 646                    })
 647                    .log_err();
 648                    return;
 649                }
 650                Ok(Err(e)) | Err(e) => {
 651                    log::error!("Device flow start error: {}", e);
 652                    this.update(cx, |this, cx| {
 653                        this.oauth_in_progress = false;
 654                        this.oauth_error = Some(e.to_string());
 655                        cx.notify();
 656                    })
 657                    .log_err();
 658                    return;
 659                }
 660            };
 661
 662            // Update UI to show the user code before polling
 663            this.update(cx, |this, cx| {
 664                this.device_user_code = Some(user_code);
 665                cx.notify();
 666            })
 667            .log_err();
 668
 669            // Step 2: Poll for authentication completion
 670            let poll_result = extension
 671                .call({
 672                    let provider_id = provider_id.clone();
 673                    |ext, store| {
 674                        async move {
 675                            ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
 676                                .await
 677                        }
 678                        .boxed()
 679                    }
 680                })
 681                .await;
 682
 683            let error_message = match poll_result {
 684                Ok(Ok(Ok(()))) => {
 685                    let _ = cx.update(|cx| {
 686                        state.update(cx, |state, cx| {
 687                            state.is_authenticated = true;
 688                            cx.notify();
 689                        });
 690                    });
 691                    None
 692                }
 693                Ok(Ok(Err(e))) => {
 694                    log::error!("Device flow poll failed: {}", e);
 695                    Some(e)
 696                }
 697                Ok(Err(e)) | Err(e) => {
 698                    log::error!("Device flow poll error: {}", e);
 699                    Some(e.to_string())
 700                }
 701            };
 702
 703            this.update(cx, |this, cx| {
 704                this.oauth_in_progress = false;
 705                this.oauth_error = error_message;
 706                this.device_user_code = None;
 707                cx.notify();
 708            })
 709            .log_err();
 710        })
 711        .detach();
 712    }
 713
 714    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
 715        self.state.read(cx).is_authenticated
 716    }
 717
 718    fn has_oauth_config(&self) -> bool {
 719        self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
 720    }
 721
 722    fn oauth_config(&self) -> Option<&OAuthConfig> {
 723        self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
 724    }
 725
 726    fn has_api_key_config(&self) -> bool {
 727        // API key is available if there's a credential_label or no oauth-only config
 728        self.auth_config
 729            .as_ref()
 730            .map(|c| c.credential_label.is_some() || c.oauth.is_none())
 731            .unwrap_or(true)
 732    }
 733}
 734
 735impl gpui::Render for ExtensionProviderConfigurationView {
 736    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 737        let is_loading = self.loading_settings || self.loading_credentials;
 738        let is_authenticated = self.is_authenticated(cx);
 739        let env_var_allowed = self.state.read(cx).env_var_allowed;
 740        let api_key_from_env = self.state.read(cx).api_key_from_env;
 741        let has_oauth = self.has_oauth_config();
 742        let has_api_key = self.has_api_key_config();
 743
 744        if is_loading {
 745            return v_flex()
 746                .gap_2()
 747                .child(Label::new("Loading...").color(Color::Muted))
 748                .into_any_element();
 749        }
 750
 751        let mut content = v_flex().gap_4().size_full();
 752
 753        // Render settings markdown if available
 754        if let Some(markdown) = &self.settings_markdown {
 755            let style = settings_markdown_style(_window, cx);
 756            content = content.child(
 757                div()
 758                    .p_2()
 759                    .rounded_md()
 760                    .bg(cx.theme().colors().surface_background)
 761                    .child(MarkdownElement::new(markdown.clone(), style)),
 762            );
 763        }
 764
 765        // Render env var checkbox if the extension specifies an env var
 766        if let Some(auth_config) = &self.auth_config {
 767            if let Some(env_var_name) = &auth_config.env_var {
 768                let env_var_name = env_var_name.clone();
 769                let checkbox_label =
 770                    format!("Read API key from {} environment variable", env_var_name);
 771
 772                content = content.child(
 773                    h_flex()
 774                        .gap_2()
 775                        .child(
 776                            ui::Checkbox::new("env-var-permission", env_var_allowed.into())
 777                                .on_click(cx.listener(|this, _, _window, cx| {
 778                                    this.toggle_env_var_permission(cx);
 779                                })),
 780                        )
 781                        .child(Label::new(checkbox_label).size(LabelSize::Small)),
 782                );
 783
 784                // Show status if env var is allowed
 785                if env_var_allowed {
 786                    if api_key_from_env {
 787                        content = content.child(
 788                            h_flex()
 789                                .gap_2()
 790                                .child(
 791                                    ui::Icon::new(ui::IconName::Check)
 792                                        .color(Color::Success)
 793                                        .size(ui::IconSize::Small),
 794                                )
 795                                .child(
 796                                    Label::new(format!("API key loaded from {}", env_var_name))
 797                                        .color(Color::Success),
 798                                ),
 799                        );
 800                        return content.into_any_element();
 801                    } else {
 802                        content = content.child(
 803                            h_flex()
 804                                .gap_2()
 805                                .child(
 806                                    ui::Icon::new(ui::IconName::Warning)
 807                                        .color(Color::Warning)
 808                                        .size(ui::IconSize::Small),
 809                                )
 810                                .child(
 811                                    Label::new(format!(
 812                                        "{} is not set or empty. You can set it and restart Zed, or use another authentication method below.",
 813                                        env_var_name
 814                                    ))
 815                                    .color(Color::Warning)
 816                                    .size(LabelSize::Small),
 817                                ),
 818                        );
 819                    }
 820                }
 821            }
 822        }
 823
 824        // If authenticated, show success state with sign out option
 825        if is_authenticated && !api_key_from_env {
 826            let reset_label = if has_oauth && !has_api_key {
 827                "Sign Out"
 828            } else {
 829                "Reset Credentials"
 830            };
 831
 832            let status_label = if has_oauth && !has_api_key {
 833                "Signed in"
 834            } else {
 835                "Authenticated"
 836            };
 837
 838            content = content.child(
 839                v_flex()
 840                    .gap_2()
 841                    .child(
 842                        h_flex()
 843                            .gap_2()
 844                            .child(
 845                                ui::Icon::new(ui::IconName::Check)
 846                                    .color(Color::Success)
 847                                    .size(ui::IconSize::Small),
 848                            )
 849                            .child(Label::new(status_label).color(Color::Success)),
 850                    )
 851                    .child(
 852                        ui::Button::new("reset-credentials", reset_label)
 853                            .style(ui::ButtonStyle::Subtle)
 854                            .on_click(cx.listener(|this, _, window, cx| {
 855                                this.reset_api_key(window, cx);
 856                            })),
 857                    ),
 858            );
 859
 860            return content.into_any_element();
 861        }
 862
 863        // Not authenticated - show available auth options
 864        if !api_key_from_env {
 865            // Render OAuth sign-in button if configured
 866            if has_oauth {
 867                let oauth_config = self.oauth_config();
 868                let button_label = oauth_config
 869                    .and_then(|c| c.sign_in_button_label.clone())
 870                    .unwrap_or_else(|| "Sign In".to_string());
 871
 872                let oauth_in_progress = self.oauth_in_progress;
 873
 874                let oauth_error = self.oauth_error.clone();
 875
 876                content = content.child(
 877                    v_flex()
 878                        .gap_2()
 879                        .child(
 880                            ui::Button::new("oauth-sign-in", button_label)
 881                                .style(ui::ButtonStyle::Filled)
 882                                .disabled(oauth_in_progress)
 883                                .on_click(cx.listener(|this, _, _window, cx| {
 884                                    this.start_oauth_sign_in(cx);
 885                                })),
 886                        )
 887                        .when(oauth_in_progress, |this| {
 888                            let user_code = self.device_user_code.clone();
 889                            this.child(
 890                                v_flex()
 891                                    .gap_1()
 892                                    .when_some(user_code, |this, code| {
 893                                        let copied = cx
 894                                            .read_from_clipboard()
 895                                            .map(|item| item.text().as_ref() == Some(&code))
 896                                            .unwrap_or(false);
 897                                        let code_for_click = code.clone();
 898                                        this.child(
 899                                            h_flex()
 900                                                .gap_1()
 901                                                .child(
 902                                                    Label::new("Enter code:")
 903                                                        .size(LabelSize::Small)
 904                                                        .color(Color::Muted),
 905                                                )
 906                                                .child(
 907                                                    h_flex()
 908                                                        .gap_1()
 909                                                        .px_1()
 910                                                        .border_1()
 911                                                        .border_color(cx.theme().colors().border)
 912                                                        .rounded_sm()
 913                                                        .cursor_pointer()
 914                                                        .on_mouse_down(
 915                                                            MouseButton::Left,
 916                                                            move |_, window, cx| {
 917                                                                cx.write_to_clipboard(
 918                                                                    ClipboardItem::new_string(
 919                                                                        code_for_click.clone(),
 920                                                                    ),
 921                                                                );
 922                                                                window.refresh();
 923                                                            },
 924                                                        )
 925                                                        .child(
 926                                                            Label::new(code)
 927                                                                .size(LabelSize::Small)
 928                                                                .color(Color::Accent),
 929                                                        )
 930                                                        .child(
 931                                                            ui::Icon::new(if copied {
 932                                                                ui::IconName::Check
 933                                                            } else {
 934                                                                ui::IconName::Copy
 935                                                            })
 936                                                            .size(ui::IconSize::Small)
 937                                                            .color(if copied {
 938                                                                Color::Success
 939                                                            } else {
 940                                                                Color::Muted
 941                                                            }),
 942                                                        ),
 943                                                ),
 944                                        )
 945                                    })
 946                                    .child(
 947                                        Label::new("Waiting for authorization in browser...")
 948                                            .size(LabelSize::Small)
 949                                            .color(Color::Muted),
 950                                    ),
 951                            )
 952                        })
 953                        .when_some(oauth_error, |this, error| {
 954                            this.child(
 955                                v_flex()
 956                                    .gap_1()
 957                                    .child(
 958                                        h_flex()
 959                                            .gap_2()
 960                                            .child(
 961                                                ui::Icon::new(ui::IconName::Warning)
 962                                                    .color(Color::Error)
 963                                                    .size(ui::IconSize::Small),
 964                                            )
 965                                            .child(
 966                                                Label::new("Authentication failed")
 967                                                    .color(Color::Error)
 968                                                    .size(LabelSize::Small),
 969                                            ),
 970                                    )
 971                                    .child(
 972                                        div().pl_6().child(
 973                                            Label::new(error)
 974                                                .color(Color::Error)
 975                                                .size(LabelSize::Small),
 976                                        ),
 977                                    ),
 978                            )
 979                        }),
 980                );
 981            }
 982
 983            // Render API key input if configured (and we have both options, show a separator)
 984            if has_api_key {
 985                if has_oauth {
 986                    content = content.child(
 987                        h_flex()
 988                            .gap_2()
 989                            .items_center()
 990                            .child(div().h_px().flex_1().bg(cx.theme().colors().border))
 991                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
 992                            .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
 993                    );
 994                }
 995
 996                let credential_label = self
 997                    .auth_config
 998                    .as_ref()
 999                    .and_then(|c| c.credential_label.clone())
1000                    .unwrap_or_else(|| "API Key".to_string());
1001
1002                content = content.child(
1003                    v_flex()
1004                        .gap_2()
1005                        .on_action(cx.listener(Self::save_api_key))
1006                        .child(
1007                            Label::new(credential_label)
1008                                .size(LabelSize::Small)
1009                                .color(Color::Muted),
1010                        )
1011                        .child(self.api_key_editor.clone())
1012                        .child(
1013                            Label::new("Enter your API key and press Enter to save")
1014                                .size(LabelSize::Small)
1015                                .color(Color::Muted),
1016                        ),
1017                );
1018            }
1019        }
1020
1021        content.into_any_element()
1022    }
1023}
1024
1025impl Focusable for ExtensionProviderConfigurationView {
1026    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1027        self.api_key_editor.focus_handle(cx)
1028    }
1029}
1030
1031fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1032    let theme_settings = ThemeSettings::get_global(cx);
1033    let colors = cx.theme().colors();
1034    let mut text_style = window.text_style();
1035    text_style.refine(&TextStyleRefinement {
1036        font_family: Some(theme_settings.ui_font.family.clone()),
1037        font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1038        font_features: Some(theme_settings.ui_font.features.clone()),
1039        color: Some(colors.text),
1040        ..Default::default()
1041    });
1042
1043    MarkdownStyle {
1044        base_text_style: text_style,
1045        selection_background_color: colors.element_selection_background,
1046        inline_code: TextStyleRefinement {
1047            background_color: Some(colors.editor_background),
1048            ..Default::default()
1049        },
1050        link: TextStyleRefinement {
1051            color: Some(colors.text_accent),
1052            underline: Some(UnderlineStyle {
1053                color: Some(colors.text_accent.opacity(0.5)),
1054                thickness: px(1.),
1055                ..Default::default()
1056            }),
1057            ..Default::default()
1058        },
1059        syntax: cx.theme().syntax().clone(),
1060        ..Default::default()
1061    }
1062}
1063
/// An extension-based language model.
///
/// Pairs one model advertised by an extension's LLM provider with the handle
/// needed to call back into that extension for token counting and streaming.
pub struct ExtensionLanguageModel {
    /// Handle to the WASM extension that implements the provider; every
    /// request for this model is dispatched through it.
    extension: WasmExtension,
    /// Model metadata (id, name, token limits, capabilities) as reported by the
    /// extension.
    model_info: LlmModelInfo,
    /// Identifier returned from `LanguageModel::provider_id`.
    provider_id: LanguageModelProviderId,
    /// Name returned from `LanguageModel::provider_name`.
    provider_name: LanguageModelProviderName,
    /// Provider metadata as reported by the extension; its `id` is passed back
    /// to the extension when counting tokens or starting completions.
    provider_info: LlmProviderInfo,
}
1072
1073impl LanguageModel for ExtensionLanguageModel {
1074    fn id(&self) -> LanguageModelId {
1075        LanguageModelId::from(self.model_info.id.clone())
1076    }
1077
1078    fn name(&self) -> LanguageModelName {
1079        LanguageModelName::from(self.model_info.name.clone())
1080    }
1081
1082    fn provider_id(&self) -> LanguageModelProviderId {
1083        self.provider_id.clone()
1084    }
1085
1086    fn provider_name(&self) -> LanguageModelProviderName {
1087        self.provider_name.clone()
1088    }
1089
1090    fn telemetry_id(&self) -> String {
1091        format!("extension-{}", self.model_info.id)
1092    }
1093
1094    fn supports_images(&self) -> bool {
1095        self.model_info.capabilities.supports_images
1096    }
1097
1098    fn supports_tools(&self) -> bool {
1099        self.model_info.capabilities.supports_tools
1100    }
1101
1102    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1103        match choice {
1104            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1105            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1106            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1107        }
1108    }
1109
1110    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1111        match self.model_info.capabilities.tool_input_format {
1112            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1113            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1114        }
1115    }
1116
1117    fn max_token_count(&self) -> u64 {
1118        self.model_info.max_token_count
1119    }
1120
1121    fn max_output_tokens(&self) -> Option<u64> {
1122        self.model_info.max_output_tokens
1123    }
1124
1125    fn count_tokens(
1126        &self,
1127        request: LanguageModelRequest,
1128        cx: &App,
1129    ) -> BoxFuture<'static, Result<u64>> {
1130        let extension = self.extension.clone();
1131        let provider_id = self.provider_info.id.clone();
1132        let model_id = self.model_info.id.clone();
1133
1134        let wit_request = convert_request_to_wit(request);
1135
1136        cx.background_spawn(async move {
1137            extension
1138                .call({
1139                    let provider_id = provider_id.clone();
1140                    let model_id = model_id.clone();
1141                    let wit_request = wit_request.clone();
1142                    |ext, store| {
1143                        async move {
1144                            let count = ext
1145                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1146                                .await?
1147                                .map_err(|e| anyhow!("{}", e))?;
1148                            Ok(count)
1149                        }
1150                        .boxed()
1151                    }
1152                })
1153                .await?
1154        })
1155        .boxed()
1156    }
1157
1158    fn stream_completion(
1159        &self,
1160        request: LanguageModelRequest,
1161        _cx: &AsyncApp,
1162    ) -> BoxFuture<
1163        'static,
1164        Result<
1165            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1166            LanguageModelCompletionError,
1167        >,
1168    > {
1169        let extension = self.extension.clone();
1170        let provider_id = self.provider_info.id.clone();
1171        let model_id = self.model_info.id.clone();
1172
1173        let wit_request = convert_request_to_wit(request);
1174
1175        async move {
1176            // Start the stream
1177            let stream_id_result = extension
1178                .call({
1179                    let provider_id = provider_id.clone();
1180                    let model_id = model_id.clone();
1181                    let wit_request = wit_request.clone();
1182                    |ext, store| {
1183                        async move {
1184                            let id = ext
1185                                .call_llm_stream_completion_start(
1186                                    store,
1187                                    &provider_id,
1188                                    &model_id,
1189                                    &wit_request,
1190                                )
1191                                .await?
1192                                .map_err(|e| anyhow!("{}", e))?;
1193                            Ok(id)
1194                        }
1195                        .boxed()
1196                    }
1197                })
1198                .await;
1199
1200            let stream_id = stream_id_result
1201                .map_err(LanguageModelCompletionError::Other)?
1202                .map_err(LanguageModelCompletionError::Other)?;
1203
1204            // Create a stream that polls for events
1205            let stream = futures::stream::unfold(
1206                (extension.clone(), stream_id, false),
1207                move |(extension, stream_id, done)| async move {
1208                    if done {
1209                        return None;
1210                    }
1211
1212                    let result = extension
1213                        .call({
1214                            let stream_id = stream_id.clone();
1215                            |ext, store| {
1216                                async move {
1217                                    let event = ext
1218                                        .call_llm_stream_completion_next(store, &stream_id)
1219                                        .await?
1220                                        .map_err(|e| anyhow!("{}", e))?;
1221                                    Ok(event)
1222                                }
1223                                .boxed()
1224                            }
1225                        })
1226                        .await
1227                        .and_then(|inner| inner);
1228
1229                    match result {
1230                        Ok(Some(event)) => {
1231                            let converted = convert_completion_event(event);
1232                            let is_done =
1233                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1234                            Some((converted, (extension, stream_id, is_done)))
1235                        }
1236                        Ok(None) => {
1237                            // Stream complete, close it
1238                            let _ = extension
1239                                .call({
1240                                    let stream_id = stream_id.clone();
1241                                    |ext, store| {
1242                                        async move {
1243                                            ext.call_llm_stream_completion_close(store, &stream_id)
1244                                                .await?;
1245                                            Ok::<(), anyhow::Error>(())
1246                                        }
1247                                        .boxed()
1248                                    }
1249                                })
1250                                .await;
1251                            None
1252                        }
1253                        Err(e) => Some((
1254                            Err(LanguageModelCompletionError::Other(e)),
1255                            (extension, stream_id, true),
1256                        )),
1257                    }
1258                },
1259            );
1260
1261            Ok(stream.boxed())
1262        }
1263        .boxed()
1264    }
1265
1266    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1267        // Extensions can implement this via llm_cache_configuration
1268        None
1269    }
1270}
1271
1272fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1273    use language_model::{MessageContent, Role};
1274
1275    let messages: Vec<LlmRequestMessage> = request
1276        .messages
1277        .into_iter()
1278        .map(|msg| {
1279            let role = match msg.role {
1280                Role::User => LlmMessageRole::User,
1281                Role::Assistant => LlmMessageRole::Assistant,
1282                Role::System => LlmMessageRole::System,
1283            };
1284
1285            let content: Vec<LlmMessageContent> = msg
1286                .content
1287                .into_iter()
1288                .map(|c| match c {
1289                    MessageContent::Text(text) => LlmMessageContent::Text(text),
1290                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1291                        source: image.source.to_string(),
1292                        width: Some(image.size.width.0 as u32),
1293                        height: Some(image.size.height.0 as u32),
1294                    }),
1295                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1296                        id: tool_use.id.to_string(),
1297                        name: tool_use.name.to_string(),
1298                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1299                        thought_signature: tool_use.thought_signature,
1300                    }),
1301                    MessageContent::ToolResult(tool_result) => {
1302                        let content = match tool_result.content {
1303                            language_model::LanguageModelToolResultContent::Text(text) => {
1304                                LlmToolResultContent::Text(text.to_string())
1305                            }
1306                            language_model::LanguageModelToolResultContent::Image(image) => {
1307                                LlmToolResultContent::Image(LlmImageData {
1308                                    source: image.source.to_string(),
1309                                    width: Some(image.size.width.0 as u32),
1310                                    height: Some(image.size.height.0 as u32),
1311                                })
1312                            }
1313                        };
1314                        LlmMessageContent::ToolResult(LlmToolResult {
1315                            tool_use_id: tool_result.tool_use_id.to_string(),
1316                            tool_name: tool_result.tool_name.to_string(),
1317                            is_error: tool_result.is_error,
1318                            content,
1319                        })
1320                    }
1321                    MessageContent::Thinking { text, signature } => {
1322                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1323                    }
1324                    MessageContent::RedactedThinking(data) => {
1325                        LlmMessageContent::RedactedThinking(data)
1326                    }
1327                })
1328                .collect();
1329
1330            LlmRequestMessage {
1331                role,
1332                content,
1333                cache: msg.cache,
1334            }
1335        })
1336        .collect();
1337
1338    let tools: Vec<LlmToolDefinition> = request
1339        .tools
1340        .into_iter()
1341        .map(|tool| LlmToolDefinition {
1342            name: tool.name,
1343            description: tool.description,
1344            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1345        })
1346        .collect();
1347
1348    let tool_choice = request.tool_choice.map(|tc| match tc {
1349        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1350        LanguageModelToolChoice::Any => LlmToolChoice::Any,
1351        LanguageModelToolChoice::None => LlmToolChoice::None,
1352    });
1353
1354    LlmCompletionRequest {
1355        messages,
1356        tools,
1357        tool_choice,
1358        stop_sequences: request.stop,
1359        temperature: request.temperature,
1360        thinking_allowed: false,
1361        max_tokens: None,
1362    }
1363}
1364
1365fn convert_completion_event(
1366    event: LlmCompletionEvent,
1367) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1368    match event {
1369        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1370            message_id: String::new(),
1371        }),
1372        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1373        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1374            text: thinking.text,
1375            signature: thinking.signature,
1376        }),
1377        LlmCompletionEvent::RedactedThinking(data) => {
1378            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1379        }
1380        LlmCompletionEvent::ToolUse(tool_use) => {
1381            let raw_input = tool_use.input.clone();
1382            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1383            Ok(LanguageModelCompletionEvent::ToolUse(
1384                LanguageModelToolUse {
1385                    id: LanguageModelToolUseId::from(tool_use.id),
1386                    name: tool_use.name.into(),
1387                    raw_input,
1388                    input,
1389                    is_input_complete: true,
1390                    thought_signature: tool_use.thought_signature,
1391                },
1392            ))
1393        }
1394        LlmCompletionEvent::ToolUseJsonParseError(error) => {
1395            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1396                id: LanguageModelToolUseId::from(error.id),
1397                tool_name: error.tool_name.into(),
1398                raw_input: error.raw_input.into(),
1399                json_parse_error: error.error,
1400            })
1401        }
1402        LlmCompletionEvent::Stop(reason) => {
1403            let stop_reason = match reason {
1404                LlmStopReason::EndTurn => StopReason::EndTurn,
1405                LlmStopReason::MaxTokens => StopReason::MaxTokens,
1406                LlmStopReason::ToolUse => StopReason::ToolUse,
1407                LlmStopReason::Refusal => StopReason::Refusal,
1408            };
1409            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1410        }
1411        LlmCompletionEvent::Usage(usage) => {
1412            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1413                input_tokens: usage.input_tokens,
1414                output_tokens: usage.output_tokens,
1415                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1416                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1417            }))
1418        }
1419        LlmCompletionEvent::ReasoningDetails(json) => {
1420            Ok(LanguageModelCompletionEvent::ReasoningDetails(
1421                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1422            ))
1423        }
1424    }
1425}