llm_provider.rs

   1use crate::ExtensionSettings;
   2use crate::LEGACY_LLM_EXTENSION_IDS;
   3use crate::wasm_host::WasmExtension;
   4use crate::wasm_host::wit::LlmDeviceFlowPromptInfo;
   5use collections::HashSet;
   6
   7use crate::wasm_host::wit::{
   8    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
   9    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
  10    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
  11    LlmToolUse,
  12};
  13use anyhow::{Result, anyhow};
  14use credentials_provider::CredentialsProvider;
  15use editor::Editor;
  16use extension::{LanguageModelAuthConfig, OAuthConfig};
  17use futures::future::BoxFuture;
  18use futures::stream::BoxStream;
  19use futures::{FutureExt, StreamExt};
  20use gpui::Focusable;
  21use gpui::{
  22    AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
  23    TextStyleRefinement, UnderlineStyle, Window, px,
  24};
  25use language_model::tool_schema::LanguageModelToolSchemaFormat;
  26use language_model::{
  27    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
  28    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
  29    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  30    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  31    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
  32};
  33use markdown::{Markdown, MarkdownElement, MarkdownStyle};
  34use settings::Settings;
  35use std::sync::Arc;
  36use theme::ThemeSettings;
  37use ui::{Label, LabelSize, prelude::*};
  38use util::ResultExt as _;
  39use workspace::Workspace;
  40use workspace::oauth_device_flow_modal::{
  41    OAuthDeviceFlowModal, OAuthDeviceFlowModalConfig, OAuthDeviceFlowState, OAuthDeviceFlowStatus,
  42};
  43
/// An extension-based language model provider.
///
/// Wraps a WASM extension together with the provider metadata it declared. The provider ID
/// exposed to the rest of the app is `"{extension_id}:{provider_id}"`; mutable auth and model
/// state lives in a shared [`ExtensionLlmProviderState`] entity.
pub struct ExtensionLanguageModelProvider {
    /// The extension instance that implements this provider.
    pub extension: WasmExtension,
    /// Provider metadata reported by the extension (its `id` and `name` are used here).
    pub provider_info: LlmProviderInfo,
    /// Optional custom icon; returned from `icon_path()` and passed to the configuration view.
    icon_path: Option<SharedString>,
    /// Authentication configuration (env vars and/or OAuth) declared by the extension, if any.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared state, also handed to configuration views.
    state: Entity<ExtensionLlmProviderState>,
}
  52
/// Mutable state shared between an [`ExtensionLanguageModelProvider`] and its
/// configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether the provider currently has usable credentials (env var, keychain, or OAuth).
    is_authenticated: bool,
    /// Models exposed by the provider; replaced after a successful OAuth sign-in.
    available_models: Vec<LlmModelInfo>,
    /// Set of env var names that are allowed to be read for this provider.
    allowed_env_vars: HashSet<String>,
    /// If authenticated via env var, which one was used.
    env_var_name_used: Option<String>,
}
  61
// Lets other entities `cx.subscribe` to the provider state.
// NOTE(review): no `cx.emit(())` is visible in this file; state changes are signalled with
// `cx.notify()`, which only `cx.observe` callbacks receive.
impl EventEmitter<()> for ExtensionLlmProviderState {}
  63
  64impl ExtensionLanguageModelProvider {
  65    pub fn new(
  66        extension: WasmExtension,
  67        provider_info: LlmProviderInfo,
  68        models: Vec<LlmModelInfo>,
  69        is_authenticated: bool,
  70        icon_path: Option<SharedString>,
  71        auth_config: Option<LanguageModelAuthConfig>,
  72        cx: &mut App,
  73    ) -> Self {
  74        let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
  75
  76        // Build set of allowed env vars for this provider
  77        let settings = ExtensionSettings::get_global(cx);
  78        let is_legacy_extension =
  79            LEGACY_LLM_EXTENSION_IDS.contains(&extension.manifest.id.as_ref());
  80
  81        let mut allowed_env_vars = HashSet::default();
  82        if let Some(env_vars) = auth_config.as_ref().and_then(|c| c.env_vars.as_ref()) {
  83            for env_var_name in env_vars {
  84                let key = format!("{}:{}", provider_id_string, env_var_name);
  85                // For legacy extensions, auto-allow if env var is set (migration will persist this)
  86                let env_var_is_set = std::env::var(env_var_name)
  87                    .map(|v| !v.is_empty())
  88                    .unwrap_or(false);
  89                if settings.allowed_env_var_providers.contains(key.as_str())
  90                    || (is_legacy_extension && env_var_is_set)
  91                {
  92                    allowed_env_vars.insert(env_var_name.clone());
  93                }
  94            }
  95        }
  96
  97        // Check if any allowed env var is set
  98        let env_var_name_used = allowed_env_vars.iter().find_map(|env_var_name| {
  99            if let Ok(value) = std::env::var(env_var_name) {
 100                if !value.is_empty() {
 101                    return Some(env_var_name.clone());
 102                }
 103            }
 104            None
 105        });
 106
 107        let is_authenticated = if env_var_name_used.is_some() {
 108            true
 109        } else {
 110            is_authenticated
 111        };
 112
 113        let state = cx.new(|_| ExtensionLlmProviderState {
 114            is_authenticated,
 115            available_models: models,
 116            allowed_env_vars,
 117            env_var_name_used,
 118        });
 119
 120        Self {
 121            extension,
 122            provider_info,
 123            icon_path,
 124            auth_config,
 125            state,
 126        }
 127    }
 128
 129    fn provider_id_string(&self) -> String {
 130        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
 131    }
 132
 133    /// The credential key used for storing the API key in the system keychain.
 134    fn credential_key(&self) -> String {
 135        format!("extension-llm-{}", self.provider_id_string())
 136    }
 137}
 138
 139impl LanguageModelProvider for ExtensionLanguageModelProvider {
 140    fn id(&self) -> LanguageModelProviderId {
 141        LanguageModelProviderId::from(self.provider_id_string())
 142    }
 143
 144    fn name(&self) -> LanguageModelProviderName {
 145        LanguageModelProviderName::from(format!("(Extension) {}", self.provider_info.name))
 146    }
 147
 148    fn icon(&self) -> ui::IconName {
 149        ui::IconName::ZedAssistant
 150    }
 151
 152    fn icon_path(&self) -> Option<SharedString> {
 153        self.icon_path.clone()
 154    }
 155
 156    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 157        let state = self.state.read(cx);
 158        state
 159            .available_models
 160            .iter()
 161            .find(|m| m.is_default)
 162            .or_else(|| state.available_models.first())
 163            .map(|model_info| {
 164                Arc::new(ExtensionLanguageModel {
 165                    extension: self.extension.clone(),
 166                    model_info: model_info.clone(),
 167                    provider_id: self.id(),
 168                    provider_name: self.name(),
 169                    provider_info: self.provider_info.clone(),
 170                }) as Arc<dyn LanguageModel>
 171            })
 172    }
 173
 174    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 175        let state = self.state.read(cx);
 176        state
 177            .available_models
 178            .iter()
 179            .find(|m| m.is_default_fast)
 180            .map(|model_info| {
 181                Arc::new(ExtensionLanguageModel {
 182                    extension: self.extension.clone(),
 183                    model_info: model_info.clone(),
 184                    provider_id: self.id(),
 185                    provider_name: self.name(),
 186                    provider_info: self.provider_info.clone(),
 187                }) as Arc<dyn LanguageModel>
 188            })
 189    }
 190
 191    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 192        let state = self.state.read(cx);
 193        state
 194            .available_models
 195            .iter()
 196            .map(|model_info| {
 197                Arc::new(ExtensionLanguageModel {
 198                    extension: self.extension.clone(),
 199                    model_info: model_info.clone(),
 200                    provider_id: self.id(),
 201                    provider_name: self.name(),
 202                    provider_info: self.provider_info.clone(),
 203                }) as Arc<dyn LanguageModel>
 204            })
 205            .collect()
 206    }
 207
 208    fn is_authenticated(&self, cx: &App) -> bool {
 209        // First check cached state
 210        if self.state.read(cx).is_authenticated {
 211            return true;
 212        }
 213
 214        // Also check env var dynamically (in case settings changed after provider creation)
 215        if let Some(ref auth_config) = self.auth_config {
 216            if let Some(ref env_vars) = auth_config.env_vars {
 217                let provider_id_string = self.provider_id_string();
 218                let settings = ExtensionSettings::get_global(cx);
 219                let is_legacy_extension =
 220                    LEGACY_LLM_EXTENSION_IDS.contains(&self.extension.manifest.id.as_ref());
 221
 222                for env_var_name in env_vars {
 223                    let key = format!("{}:{}", provider_id_string, env_var_name);
 224                    // For legacy extensions, auto-allow if env var is set
 225                    let env_var_is_set = std::env::var(env_var_name)
 226                        .map(|v| !v.is_empty())
 227                        .unwrap_or(false);
 228                    if settings.allowed_env_var_providers.contains(key.as_str())
 229                        || (is_legacy_extension && env_var_is_set)
 230                    {
 231                        if let Ok(value) = std::env::var(env_var_name) {
 232                            if !value.is_empty() {
 233                                return true;
 234                            }
 235                        }
 236                    }
 237                }
 238            }
 239        }
 240
 241        false
 242    }
 243
 244    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 245        // Check if already authenticated via is_authenticated
 246        if self.is_authenticated(cx) {
 247            return Task::ready(Ok(()));
 248        }
 249
 250        // Not authenticated - return error indicating credentials not found
 251        Task::ready(Err(AuthenticateError::CredentialsNotFound))
 252    }
 253
 254    fn configuration_view(
 255        &self,
 256        _target_agent: ConfigurationViewTargetAgent,
 257        window: &mut Window,
 258        cx: &mut App,
 259    ) -> AnyView {
 260        let credential_key = self.credential_key();
 261        let extension = self.extension.clone();
 262        let extension_provider_id = self.provider_info.id.clone();
 263        let full_provider_id = self.provider_id_string();
 264        let state = self.state.clone();
 265        let auth_config = self.auth_config.clone();
 266
 267        let icon_path = self.icon_path.clone();
 268        cx.new(|cx| {
 269            ExtensionProviderConfigurationView::new(
 270                credential_key,
 271                extension,
 272                extension_provider_id,
 273                full_provider_id,
 274                auth_config,
 275                state,
 276                icon_path,
 277                window,
 278                cx,
 279            )
 280        })
 281        .into()
 282    }
 283
 284    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 285        let extension = self.extension.clone();
 286        let provider_id = self.provider_info.id.clone();
 287        let state = self.state.clone();
 288        let credential_key = self.credential_key();
 289
 290        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 291
 292        cx.spawn(async move |cx| {
 293            // Delete from system keychain
 294            credentials_provider
 295                .delete_credentials(&credential_key, cx)
 296                .await
 297                .log_err();
 298
 299            // Call extension's reset_credentials
 300            let result = extension
 301                .call(|extension, store| {
 302                    async move {
 303                        extension
 304                            .call_llm_provider_reset_credentials(store, &provider_id)
 305                            .await
 306                    }
 307                    .boxed()
 308                })
 309                .await;
 310
 311            // Update state
 312            cx.update(|cx| {
 313                state.update(cx, |state, _| {
 314                    state.is_authenticated = false;
 315                });
 316            })?;
 317
 318            match result {
 319                Ok(Ok(Ok(()))) => Ok(()),
 320                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
 321                Ok(Err(e)) => Err(e),
 322                Err(e) => Err(e),
 323            }
 324        })
 325    }
 326}
 327
 328impl LanguageModelProviderState for ExtensionLanguageModelProvider {
 329    type ObservableEntity = ExtensionLlmProviderState;
 330
 331    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 332        Some(self.state.clone())
 333    }
 334
 335    fn subscribe<T: 'static>(
 336        &self,
 337        cx: &mut Context<T>,
 338        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
 339    ) -> Option<Subscription> {
 340        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
 341    }
 342}
 343
/// Configuration view for extension-based LLM providers.
///
/// Holds the UI state for configuring an extension provider:
/// - the extension's settings markdown;
/// - an API-key editor backed by the system keychain;
/// - per-env-var read permissions;
/// - OAuth device-flow sign-in.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is stored.
    credential_key: String,
    extension: WasmExtension,
    /// Provider ID as known to the extension (without the extension-ID prefix).
    extension_provider_id: String,
    /// `"{extension_id}:{provider_id}"`; prefix of env-var permission keys in settings.
    full_provider_id: String,
    auth_config: Option<LanguageModelAuthConfig>,
    state: Entity<ExtensionLlmProviderState>,
    /// Settings markdown from the extension, once loaded (stays `None` if there is none).
    settings_markdown: Option<Entity<Markdown>>,
    api_key_editor: Entity<Editor>,
    /// True until the settings-markdown request finishes.
    loading_settings: bool,
    /// True until the initial keychain lookup finishes.
    loading_credentials: bool,
    /// Prevents starting a second OAuth flow while one is running.
    oauth_in_progress: bool,
    /// Last OAuth error to display, if any.
    oauth_error: Option<String>,
    /// Icon passed to the OAuth device-flow modal.
    icon_path: Option<SharedString>,
    /// Keeps the state subscription alive for the lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
 361
 362impl ExtensionProviderConfigurationView {
 363    fn new(
 364        credential_key: String,
 365        extension: WasmExtension,
 366        extension_provider_id: String,
 367        full_provider_id: String,
 368        auth_config: Option<LanguageModelAuthConfig>,
 369        state: Entity<ExtensionLlmProviderState>,
 370        icon_path: Option<SharedString>,
 371        window: &mut Window,
 372        cx: &mut Context<Self>,
 373    ) -> Self {
 374        // Subscribe to state changes
 375        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
 376            cx.notify();
 377        });
 378
 379        // Create API key editor
 380        let api_key_editor = cx.new(|cx| {
 381            let mut editor = Editor::single_line(window, cx);
 382            editor.set_placeholder_text("Enter API key...", window, cx);
 383            editor
 384        });
 385
 386        let mut this = Self {
 387            credential_key,
 388            extension,
 389            extension_provider_id,
 390            full_provider_id,
 391            auth_config,
 392            state,
 393            settings_markdown: None,
 394            api_key_editor,
 395            loading_settings: true,
 396            loading_credentials: true,
 397            oauth_in_progress: false,
 398            oauth_error: None,
 399            icon_path,
 400            _subscriptions: vec![state_subscription],
 401        };
 402
 403        // Load settings text from extension
 404        this.load_settings_text(cx);
 405
 406        // Load existing credentials
 407        this.load_credentials(cx);
 408
 409        this
 410    }
 411
 412    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
 413        let extension = self.extension.clone();
 414        let provider_id = self.extension_provider_id.clone();
 415
 416        cx.spawn(async move |this, cx| {
 417            let result = extension
 418                .call({
 419                    let provider_id = provider_id.clone();
 420                    |ext, store| {
 421                        async move {
 422                            ext.call_llm_provider_settings_markdown(store, &provider_id)
 423                                .await
 424                        }
 425                        .boxed()
 426                    }
 427                })
 428                .await;
 429
 430            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
 431
 432            this.update(cx, |this, cx| {
 433                this.loading_settings = false;
 434                if let Some(text) = settings_text {
 435                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
 436                    this.settings_markdown = Some(markdown);
 437                }
 438                cx.notify();
 439            })
 440            .log_err();
 441        })
 442        .detach();
 443    }
 444
 445    fn load_credentials(&mut self, cx: &mut Context<Self>) {
 446        let credential_key = self.credential_key.clone();
 447        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 448        let state = self.state.clone();
 449
 450        // Check if we should use env var (already set in state during provider construction)
 451        let using_env_var = self.state.read(cx).env_var_name_used.is_some();
 452
 453        cx.spawn(async move |this, cx| {
 454            // If using env var, we're already authenticated
 455            if using_env_var {
 456                this.update(cx, |this, cx| {
 457                    this.loading_credentials = false;
 458                    cx.notify();
 459                })
 460                .log_err();
 461                return;
 462            }
 463
 464            let credentials = credentials_provider
 465                .read_credentials(&credential_key, cx)
 466                .await
 467                .log_err()
 468                .flatten();
 469
 470            let has_credentials = credentials.is_some();
 471
 472            // Update authentication state based on stored credentials
 473            cx.update(|cx| {
 474                state.update(cx, |state, cx| {
 475                    state.is_authenticated = has_credentials;
 476                    cx.notify();
 477                });
 478            })
 479            .log_err();
 480
 481            this.update(cx, |this, cx| {
 482                this.loading_credentials = false;
 483                cx.notify();
 484            })
 485            .log_err();
 486        })
 487        .detach();
 488    }
 489
 490    fn toggle_env_var_permission(&mut self, env_var_name: String, cx: &mut Context<Self>) {
 491        let full_provider_id = self.full_provider_id.clone();
 492        let settings_key: Arc<str> = format!("{}:{}", full_provider_id, env_var_name).into();
 493
 494        let state = self.state.clone();
 495        let currently_allowed = self.state.read(cx).allowed_env_vars.contains(&env_var_name);
 496
 497        // Update settings file
 498        settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
 499            move |settings, _| {
 500                let allowed = settings
 501                    .extension
 502                    .allowed_env_var_providers
 503                    .get_or_insert_with(Vec::new);
 504
 505                if currently_allowed {
 506                    allowed.retain(|id| id.as_ref() != settings_key.as_ref());
 507                } else {
 508                    if !allowed
 509                        .iter()
 510                        .any(|id| id.as_ref() == settings_key.as_ref())
 511                    {
 512                        allowed.push(settings_key.clone());
 513                    }
 514                }
 515            }
 516        });
 517
 518        // Update local state
 519        let new_allowed = !currently_allowed;
 520
 521        state.update(cx, |state, cx| {
 522            if new_allowed {
 523                state.allowed_env_vars.insert(env_var_name.clone());
 524                // Check if this env var is set and update env_var_name_used
 525                if let Ok(value) = std::env::var(&env_var_name) {
 526                    if !value.is_empty() && state.env_var_name_used.is_none() {
 527                        state.env_var_name_used = Some(env_var_name.clone());
 528                        state.is_authenticated = true;
 529                    }
 530                }
 531            } else {
 532                state.allowed_env_vars.remove(&env_var_name);
 533                // If this was the env var being used, clear it and find another
 534                if state.env_var_name_used.as_ref() == Some(&env_var_name) {
 535                    state.env_var_name_used = state.allowed_env_vars.iter().find_map(|var| {
 536                        if let Ok(value) = std::env::var(var) {
 537                            if !value.is_empty() {
 538                                return Some(var.clone());
 539                            }
 540                        }
 541                        None
 542                    });
 543                    if state.env_var_name_used.is_none() {
 544                        // No env var auth available, need to check keychain
 545                        state.is_authenticated = false;
 546                    }
 547                }
 548            }
 549            cx.notify();
 550        });
 551
 552        // If all env vars are being disabled, reload credentials from keychain
 553        if !new_allowed && self.state.read(cx).allowed_env_vars.is_empty() {
 554            self.reload_keychain_credentials(cx);
 555        }
 556
 557        cx.notify();
 558    }
 559
 560    fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
 561        let credential_key = self.credential_key.clone();
 562        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 563        let state = self.state.clone();
 564
 565        cx.spawn(async move |_this, cx| {
 566            let credentials = credentials_provider
 567                .read_credentials(&credential_key, cx)
 568                .await
 569                .log_err()
 570                .flatten();
 571
 572            let has_credentials = credentials.is_some();
 573
 574            cx.update(|cx| {
 575                state.update(cx, |state, cx| {
 576                    state.is_authenticated = has_credentials;
 577                    cx.notify();
 578                });
 579            })
 580            .log_err();
 581        })
 582        .detach();
 583    }
 584
 585    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 586        let api_key = self.api_key_editor.read(cx).text(cx);
 587        if api_key.is_empty() {
 588            return;
 589        }
 590
 591        // Clear the editor
 592        self.api_key_editor
 593            .update(cx, |editor, cx| editor.set_text("", window, cx));
 594
 595        let credential_key = self.credential_key.clone();
 596        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 597        let state = self.state.clone();
 598
 599        cx.spawn(async move |_this, cx| {
 600            // Store in system keychain
 601            credentials_provider
 602                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
 603                .await
 604                .log_err();
 605
 606            // Update state to authenticated
 607            cx.update(|cx| {
 608                state.update(cx, |state, cx| {
 609                    state.is_authenticated = true;
 610                    cx.notify();
 611                });
 612            })
 613            .log_err();
 614        })
 615        .detach();
 616    }
 617
 618    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 619        // Clear the editor
 620        self.api_key_editor
 621            .update(cx, |editor, cx| editor.set_text("", window, cx));
 622
 623        let credential_key = self.credential_key.clone();
 624        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 625        let state = self.state.clone();
 626
 627        cx.spawn(async move |_this, cx| {
 628            // Delete from system keychain
 629            credentials_provider
 630                .delete_credentials(&credential_key, cx)
 631                .await
 632                .log_err();
 633
 634            // Update state to unauthenticated
 635            cx.update(|cx| {
 636                state.update(cx, |state, cx| {
 637                    state.is_authenticated = false;
 638                    cx.notify();
 639                });
 640            })
 641            .log_err();
 642        })
 643        .detach();
 644    }
 645
 646    fn start_oauth_sign_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 647        if self.oauth_in_progress {
 648            return;
 649        }
 650
 651        self.oauth_in_progress = true;
 652        self.oauth_error = None;
 653        cx.notify();
 654
 655        let extension = self.extension.clone();
 656        let provider_id = self.extension_provider_id.clone();
 657        let state = self.state.clone();
 658        let icon_path = self.icon_path.clone();
 659        let this_handle = cx.weak_entity();
 660
 661        // Get workspace to show modal
 662        let Some(workspace) = window.root::<Workspace>().flatten() else {
 663            self.oauth_in_progress = false;
 664            self.oauth_error = Some("Could not access workspace to show sign-in modal".to_string());
 665            cx.notify();
 666            return;
 667        };
 668
 669        let workspace = workspace.downgrade();
 670        let state = state.downgrade();
 671        cx.spawn_in(window, async move |_this, cx| {
 672            // Step 1: Start device flow - get prompt info from extension
 673            let start_result = extension
 674                .call({
 675                    let provider_id = provider_id.clone();
 676                    |ext, store| {
 677                        async move {
 678                            ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
 679                                .await
 680                        }
 681                        .boxed()
 682                    }
 683                })
 684                .await;
 685
 686            let prompt_info: LlmDeviceFlowPromptInfo = match start_result {
 687                Ok(Ok(Ok(info))) => info,
 688                Ok(Ok(Err(e))) => {
 689                    log::error!("Device flow start failed: {}", e);
 690                    this_handle
 691                        .update_in(cx, |this, _window, cx| {
 692                            this.oauth_in_progress = false;
 693                            this.oauth_error = Some(e);
 694                            cx.notify();
 695                        })
 696                        .log_err();
 697                    return;
 698                }
 699                Ok(Err(e)) | Err(e) => {
 700                    log::error!("Device flow start error: {}", e);
 701                    this_handle
 702                        .update_in(cx, |this, _window, cx| {
 703                            this.oauth_in_progress = false;
 704                            this.oauth_error = Some(e.to_string());
 705                            cx.notify();
 706                        })
 707                        .log_err();
 708                    return;
 709                }
 710            };
 711
 712            // Step 2: Create state entity and show the modal
 713            let modal_config = OAuthDeviceFlowModalConfig {
 714                user_code: prompt_info.user_code,
 715                verification_url: prompt_info.verification_url,
 716                headline: prompt_info.headline,
 717                description: prompt_info.description,
 718                connect_button_label: prompt_info.connect_button_label,
 719                success_headline: prompt_info.success_headline,
 720                success_message: prompt_info.success_message,
 721                icon_path,
 722            };
 723
 724            let flow_state: Option<Entity<OAuthDeviceFlowState>> = workspace
 725                .update_in(cx, |workspace, window, cx| {
 726                    let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config));
 727                    let flow_state_clone = flow_state.clone();
 728                    workspace.toggle_modal(window, cx, |_window, cx| {
 729                        OAuthDeviceFlowModal::new(flow_state_clone, cx)
 730                    });
 731                    flow_state
 732                })
 733                .ok();
 734
 735            let Some(flow_state) = flow_state else {
 736                this_handle
 737                    .update_in(cx, |this, _window, cx| {
 738                        this.oauth_in_progress = false;
 739                        this.oauth_error = Some("Failed to show sign-in modal".to_string());
 740                        cx.notify();
 741                    })
 742                    .log_err();
 743                return;
 744            };
 745
 746            // Step 3: Poll for authentication completion
 747            let poll_result = extension
 748                .call({
 749                    let provider_id = provider_id.clone();
 750                    |ext, store| {
 751                        async move {
 752                            ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
 753                                .await
 754                        }
 755                        .boxed()
 756                    }
 757                })
 758                .await;
 759
 760            match poll_result {
 761                Ok(Ok(Ok(()))) => {
 762                    // After successful auth, refresh the models list
 763                    let models_result = extension
 764                        .call({
 765                            let provider_id = provider_id.clone();
 766                            |ext, store| {
 767                                async move {
 768                                    ext.call_llm_provider_models(store, &provider_id).await
 769                                }
 770                                .boxed()
 771                            }
 772                        })
 773                        .await;
 774
 775                    let new_models: Vec<LlmModelInfo> = match models_result {
 776                        Ok(Ok(Ok(models))) => models,
 777                        _ => Vec::new(),
 778                    };
 779
 780                    state
 781                        .update_in(cx, |state, _window, cx| {
 782                            state.is_authenticated = true;
 783                            state.available_models = new_models;
 784                            cx.notify();
 785                        })
 786                        .log_err();
 787
 788                    // Update flow state to show success
 789                    flow_state
 790                        .update_in(cx, |state, _window, cx| {
 791                            state.set_status(OAuthDeviceFlowStatus::Authorized, cx);
 792                        })
 793                        .log_err();
 794                }
 795                Ok(Ok(Err(e))) => {
 796                    log::error!("Device flow poll failed: {}", e);
 797                    flow_state
 798                        .update_in(cx, |state, _window, cx| {
 799                            state.set_status(OAuthDeviceFlowStatus::Failed(e.clone()), cx);
 800                        })
 801                        .log_err();
 802                    this_handle
 803                        .update_in(cx, |this, _window, cx| {
 804                            this.oauth_error = Some(e);
 805                            cx.notify();
 806                        })
 807                        .log_err();
 808                }
 809                Ok(Err(e)) | Err(e) => {
 810                    log::error!("Device flow poll error: {}", e);
 811                    let error_string = e.to_string();
 812                    flow_state
 813                        .update_in(cx, |state, _window, cx| {
 814                            state.set_status(
 815                                OAuthDeviceFlowStatus::Failed(error_string.clone()),
 816                                cx,
 817                            );
 818                        })
 819                        .log_err();
 820                    this_handle
 821                        .update_in(cx, |this, _window, cx| {
 822                            this.oauth_error = Some(error_string);
 823                            cx.notify();
 824                        })
 825                        .log_err();
 826                }
 827            };
 828
 829            this_handle
 830                .update_in(cx, |this, _window, cx| {
 831                    this.oauth_in_progress = false;
 832                    cx.notify();
 833                })
 834                .log_err();
 835        })
 836        .detach();
 837    }
 838
 839    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
 840        self.state.read(cx).is_authenticated
 841    }
 842
 843    fn has_oauth_config(&self) -> bool {
 844        self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
 845    }
 846
 847    fn oauth_config(&self) -> Option<&OAuthConfig> {
 848        self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
 849    }
 850
 851    fn has_api_key_config(&self) -> bool {
 852        // API key is available if there's a credential_label or no oauth-only config
 853        self.auth_config
 854            .as_ref()
 855            .map(|c| c.credential_label.is_some() || c.oauth.is_none())
 856            .unwrap_or(true)
 857    }
 858}
 859
/// Renders the configuration panel for an extension-provided LLM provider.
impl gpui::Render for ExtensionProviderConfigurationView {
    /// Builds the panel contents. Top to bottom:
    /// - a "Loading..." placeholder while settings or credentials are loading;
    /// - the extension's settings markdown, if any;
    /// - one opt-in checkbox per environment variable declared in the auth
    ///   config, followed (when one of them is supplying the key) by a status
    ///   row, after which rendering returns early;
    /// - a status row with a reset / sign-out button once authenticated
    ///   (also returns early);
    /// - otherwise the sign-in options: an OAuth button and/or an API-key
    ///   editor, separated by an "or" divider when both are offered;
    /// - an OpenAI-compatible-models hint for the "openai" provider, which is
    ///   only reached when neither early return above was taken.
    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_loading = self.loading_settings || self.loading_credentials;
        let is_authenticated = self.is_authenticated(cx);
        let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone();
        let env_var_name_used = self.state.read(cx).env_var_name_used.clone();
        let has_oauth = self.has_oauth_config();
        let has_api_key = self.has_api_key_config();

        // Nothing else is meaningful until settings and credentials have loaded.
        if is_loading {
            return v_flex()
                .gap_2()
                .child(Label::new("Loading...").color(Color::Muted))
                .into_any_element();
        }

        let mut content = v_flex().gap_4().size_full();

        // Render settings markdown if available
        if let Some(markdown) = &self.settings_markdown {
            let style = settings_markdown_style(window, cx);
            content = content.child(MarkdownElement::new(markdown.clone(), style));
        }

        // Render env var checkboxes - one for each env var the extension declares
        if let Some(auth_config) = &self.auth_config {
            if let Some(env_vars) = &auth_config.env_vars {
                for env_var_name in env_vars {
                    let is_allowed = allowed_env_vars.contains(env_var_name);
                    let checkbox_label =
                        format!("Read API key from {} environment variable", env_var_name);
                    // Owned copy moved into the click listener below.
                    let env_var_for_click = env_var_name.clone();

                    content = content.child(
                        h_flex()
                            .gap_2()
                            .child(
                                ui::Checkbox::new(
                                    SharedString::from(format!("env-var-{}", env_var_name)),
                                    is_allowed.into(),
                                )
                                .on_click(cx.listener(
                                    move |this, _, _window, cx| {
                                        this.toggle_env_var_permission(
                                            env_var_for_click.clone(),
                                            cx,
                                        );
                                    },
                                )),
                            )
                            .child(Label::new(checkbox_label).size(LabelSize::Small)),
                    );
                }

                // Show status if any env var is being used
                if let Some(used_var) = &env_var_name_used {
                    let tooltip_label = format!(
                        "To reset this API key, unset the {} environment variable.",
                        used_var
                    );
                    // The key comes from the environment, so the reset button is
                    // rendered disabled and the tooltip explains how to reset it.
                    content = content.child(
                        h_flex()
                            .mt_0p5()
                            .p_1()
                            .justify_between()
                            .rounded_md()
                            .border_1()
                            .border_color(cx.theme().colors().border)
                            .bg(cx.theme().colors().background)
                            .child(
                                h_flex()
                                    .flex_1()
                                    .min_w_0()
                                    .gap_1()
                                    .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
                                    .child(
                                        Label::new(format!(
                                            "API key set in {} environment variable",
                                            used_var
                                        ))
                                        .truncate(),
                                    ),
                            )
                            .child(
                                ui::Button::new("reset-key", "Reset Key")
                                    .label_size(LabelSize::Small)
                                    .icon(ui::IconName::Undo)
                                    .icon_size(ui::IconSize::Small)
                                    .icon_color(Color::Muted)
                                    .icon_position(ui::IconPosition::Start)
                                    .disabled(true)
                                    .tooltip(ui::Tooltip::text(tooltip_label)),
                            ),
                    );
                    return content.into_any_element();
                }
            }
        }

        // NOTE(review): if `env_var_name_used` is Some but the auth config
        // declares no `env_vars`, the status row above is skipped and both
        // branches below are gated on `env_var_name_used.is_none()`, so no
        // auth UI is rendered at all. Confirm that combination cannot occur.

        // If authenticated, show success state with sign out option
        if is_authenticated && env_var_name_used.is_none() {
            // OAuth-only providers say "Sign Out"; anything that accepts an API key
            // says "Reset Key". Both call `reset_api_key`.
            let reset_label = if has_oauth && !has_api_key {
                "Sign Out"
            } else {
                "Reset Key"
            };

            let status_label = if has_oauth && !has_api_key {
                "Signed in"
            } else {
                "API key configured"
            };

            // Reuses the "reset-key" element id; the env-var branch above returns
            // early, so the two buttons are never rendered in the same frame.
            content = content.child(
                h_flex()
                    .mt_0p5()
                    .p_1()
                    .justify_between()
                    .rounded_md()
                    .border_1()
                    .border_color(cx.theme().colors().border)
                    .bg(cx.theme().colors().background)
                    .child(
                        h_flex()
                            .flex_1()
                            .min_w_0()
                            .gap_1()
                            .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
                            .child(Label::new(status_label).truncate()),
                    )
                    .child(
                        ui::Button::new("reset-key", reset_label)
                            .label_size(LabelSize::Small)
                            .icon(ui::IconName::Undo)
                            .icon_size(ui::IconSize::Small)
                            .icon_color(Color::Muted)
                            .icon_position(ui::IconPosition::Start)
                            .on_click(cx.listener(|this, _, window, cx| {
                                this.reset_api_key(window, cx);
                            })),
                    ),
            );

            return content.into_any_element();
        }

        // Not authenticated - show available auth options
        if env_var_name_used.is_none() {
            // Render OAuth sign-in button if configured
            if has_oauth {
                let oauth_config = self.oauth_config();
                // The extension may customize the button label and icon; only the
                // "github" icon name is recognized, anything else renders no icon.
                let button_label = oauth_config
                    .and_then(|c| c.sign_in_button_label.clone())
                    .unwrap_or_else(|| "Sign In".to_string());
                let button_icon = oauth_config
                    .and_then(|c| c.sign_in_button_icon.as_ref())
                    .and_then(|icon_name| match icon_name.as_str() {
                        "github" => Some(ui::IconName::Github),
                        _ => None,
                    });

                let oauth_in_progress = self.oauth_in_progress;

                let oauth_error = self.oauth_error.clone();

                // Disabled while a sign-in is already running so it can't be
                // started twice.
                let mut button = ui::Button::new("oauth-sign-in", button_label)
                    .full_width()
                    .style(ui::ButtonStyle::Outlined)
                    .disabled(oauth_in_progress)
                    .on_click(cx.listener(|this, _, window, cx| {
                        this.start_oauth_sign_in(window, cx);
                    }));
                if let Some(icon) = button_icon {
                    button = button
                        .icon(icon)
                        .icon_position(ui::IconPosition::Start)
                        .icon_size(ui::IconSize::Small)
                        .icon_color(Color::Muted);
                }

                // Button, then an optional progress note, then the last OAuth
                // error (if any) with a warning header.
                content = content.child(
                    v_flex()
                        .gap_2()
                        .child(button)
                        .when(oauth_in_progress, |this| {
                            this.child(
                                Label::new("Sign-in in progress...")
                                    .size(LabelSize::Small)
                                    .color(Color::Muted),
                            )
                        })
                        .when_some(oauth_error, |this, error| {
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .child(
                                        h_flex()
                                            .gap_2()
                                            .child(
                                                ui::Icon::new(ui::IconName::Warning)
                                                    .color(Color::Error)
                                                    .size(ui::IconSize::Small),
                                            )
                                            .child(
                                                Label::new("Authentication failed")
                                                    .color(Color::Error)
                                                    .size(LabelSize::Small),
                                            ),
                                    )
                                    .child(
                                        div().pl_6().child(
                                            Label::new(error)
                                                .color(Color::Error)
                                                .size(LabelSize::Small),
                                        ),
                                    ),
                            )
                        }),
                );
            }

            // Render API key input if configured (and we have both options, show a separator)
            if has_api_key {
                if has_oauth {
                    content = content.child(
                        h_flex()
                            .gap_2()
                            .items_center()
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border))
                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
                    );
                }

                // Heading text comes from the auth config's credential label, if any.
                let credential_label = self
                    .auth_config
                    .as_ref()
                    .and_then(|c| c.credential_label.clone())
                    .unwrap_or_else(|| "API Key".to_string());

                // `save_api_key` is registered as an action handler on this
                // container; presumably the editor dispatches that action on
                // Enter, as the hint label says — confirm against the keymap.
                content = content.child(
                    v_flex()
                        .gap_2()
                        .on_action(cx.listener(Self::save_api_key))
                        .child(
                            Label::new(credential_label)
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        )
                        .child(self.api_key_editor.clone())
                        .child(
                            Label::new("Enter your API key and press Enter to save")
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        ),
                );
            }
        }

        // Show OpenAI-compatible models notification for OpenAI extension
        if self.extension_provider_id == "openai" {
            content = content.child(
                h_flex()
                    .gap_1()
                    .child(
                        ui::Icon::new(ui::IconName::Info)
                            .size(ui::IconSize::Small)
                            .color(Color::Muted),
                    )
                    .child(
                        Label::new("Zed also supports OpenAI-compatible models.")
                            .size(LabelSize::Small)
                            .color(Color::Muted),
                    )
                    .child(
                        ui::Button::new("learn-more", "Learn More")
                            .style(ui::ButtonStyle::Subtle)
                            .label_size(LabelSize::Small)
                            .icon(ui::IconName::ArrowUpRight)
                            .icon_size(ui::IconSize::Small)
                            .icon_color(Color::Muted)
                            .icon_position(ui::IconPosition::End)
                            .on_click(|_, _, cx| {
                                cx.open_url("https://zed.dev/docs/configuring-llm-providers#openai-compatible-providers");
                            }),
                    ),
            );
        }

        content.into_any_element()
    }
}
1152
1153impl Focusable for ExtensionProviderConfigurationView {
1154    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1155        self.api_key_editor.focus_handle(cx)
1156    }
1157}
1158
1159fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1160    let theme_settings = ThemeSettings::get_global(cx);
1161    let colors = cx.theme().colors();
1162    let mut text_style = window.text_style();
1163    text_style.refine(&TextStyleRefinement {
1164        font_family: Some(theme_settings.ui_font.family.clone()),
1165        font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1166        font_features: Some(theme_settings.ui_font.features.clone()),
1167        color: Some(colors.text),
1168        ..Default::default()
1169    });
1170
1171    MarkdownStyle {
1172        base_text_style: text_style,
1173        selection_background_color: colors.element_selection_background,
1174        inline_code: TextStyleRefinement {
1175            background_color: Some(colors.editor_background),
1176            ..Default::default()
1177        },
1178        link: TextStyleRefinement {
1179            color: Some(colors.text_accent),
1180            underline: Some(UnderlineStyle {
1181                color: Some(colors.text_accent.opacity(0.5)),
1182                thickness: px(1.),
1183                ..Default::default()
1184            }),
1185            ..Default::default()
1186        },
1187        syntax: cx.theme().syntax().clone(),
1188        ..Default::default()
1189    }
1190}
1191
1192/// An extension-based language model.
1193pub struct ExtensionLanguageModel {
1194    extension: WasmExtension,
1195    model_info: LlmModelInfo,
1196    provider_id: LanguageModelProviderId,
1197    provider_name: LanguageModelProviderName,
1198    provider_info: LlmProviderInfo,
1199}
1200
1201impl LanguageModel for ExtensionLanguageModel {
1202    fn id(&self) -> LanguageModelId {
1203        LanguageModelId::from(self.model_info.id.clone())
1204    }
1205
1206    fn name(&self) -> LanguageModelName {
1207        LanguageModelName::from(self.model_info.name.clone())
1208    }
1209
1210    fn provider_id(&self) -> LanguageModelProviderId {
1211        self.provider_id.clone()
1212    }
1213
1214    fn provider_name(&self) -> LanguageModelProviderName {
1215        self.provider_name.clone()
1216    }
1217
1218    fn telemetry_id(&self) -> String {
1219        format!("extension-{}", self.model_info.id)
1220    }
1221
1222    fn supports_images(&self) -> bool {
1223        self.model_info.capabilities.supports_images
1224    }
1225
1226    fn supports_tools(&self) -> bool {
1227        self.model_info.capabilities.supports_tools
1228    }
1229
1230    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1231        match choice {
1232            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1233            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1234            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1235        }
1236    }
1237
1238    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1239        match self.model_info.capabilities.tool_input_format {
1240            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1241            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1242        }
1243    }
1244
1245    fn max_token_count(&self) -> u64 {
1246        self.model_info.max_token_count
1247    }
1248
1249    fn max_output_tokens(&self) -> Option<u64> {
1250        self.model_info.max_output_tokens
1251    }
1252
1253    fn count_tokens(
1254        &self,
1255        request: LanguageModelRequest,
1256        cx: &App,
1257    ) -> BoxFuture<'static, Result<u64>> {
1258        let extension = self.extension.clone();
1259        let provider_id = self.provider_info.id.clone();
1260        let model_id = self.model_info.id.clone();
1261
1262        let wit_request = convert_request_to_wit(request);
1263
1264        cx.background_spawn(async move {
1265            extension
1266                .call({
1267                    let provider_id = provider_id.clone();
1268                    let model_id = model_id.clone();
1269                    let wit_request = wit_request.clone();
1270                    |ext, store| {
1271                        async move {
1272                            let count = ext
1273                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1274                                .await?
1275                                .map_err(|e| anyhow!("{}", e))?;
1276                            Ok(count)
1277                        }
1278                        .boxed()
1279                    }
1280                })
1281                .await?
1282        })
1283        .boxed()
1284    }
1285
1286    fn stream_completion(
1287        &self,
1288        request: LanguageModelRequest,
1289        _cx: &AsyncApp,
1290    ) -> BoxFuture<
1291        'static,
1292        Result<
1293            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1294            LanguageModelCompletionError,
1295        >,
1296    > {
1297        let extension = self.extension.clone();
1298        let provider_id = self.provider_info.id.clone();
1299        let model_id = self.model_info.id.clone();
1300
1301        let wit_request = convert_request_to_wit(request);
1302
1303        async move {
1304            // Start the stream
1305            let stream_id_result = extension
1306                .call({
1307                    let provider_id = provider_id.clone();
1308                    let model_id = model_id.clone();
1309                    let wit_request = wit_request.clone();
1310                    |ext, store| {
1311                        async move {
1312                            let id = ext
1313                                .call_llm_stream_completion_start(
1314                                    store,
1315                                    &provider_id,
1316                                    &model_id,
1317                                    &wit_request,
1318                                )
1319                                .await?
1320                                .map_err(|e| anyhow!("{}", e))?;
1321                            Ok(id)
1322                        }
1323                        .boxed()
1324                    }
1325                })
1326                .await;
1327
1328            let stream_id = stream_id_result
1329                .map_err(LanguageModelCompletionError::Other)?
1330                .map_err(LanguageModelCompletionError::Other)?;
1331
1332            // Create a stream that polls for events
1333            let stream = futures::stream::unfold(
1334                (extension.clone(), stream_id, false),
1335                move |(extension, stream_id, done)| async move {
1336                    if done {
1337                        return None;
1338                    }
1339
1340                    let result = extension
1341                        .call({
1342                            let stream_id = stream_id.clone();
1343                            |ext, store| {
1344                                async move {
1345                                    let event = ext
1346                                        .call_llm_stream_completion_next(store, &stream_id)
1347                                        .await?
1348                                        .map_err(|e| anyhow!("{}", e))?;
1349                                    Ok(event)
1350                                }
1351                                .boxed()
1352                            }
1353                        })
1354                        .await
1355                        .and_then(|inner| inner);
1356
1357                    match result {
1358                        Ok(Some(event)) => {
1359                            let converted = convert_completion_event(event);
1360                            let is_done =
1361                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1362                            Some((converted, (extension, stream_id, is_done)))
1363                        }
1364                        Ok(None) => {
1365                            // Stream complete, close it
1366                            let _ = extension
1367                                .call({
1368                                    let stream_id = stream_id.clone();
1369                                    |ext, store| {
1370                                        async move {
1371                                            ext.call_llm_stream_completion_close(store, &stream_id)
1372                                                .await?;
1373                                            Ok::<(), anyhow::Error>(())
1374                                        }
1375                                        .boxed()
1376                                    }
1377                                })
1378                                .await;
1379                            None
1380                        }
1381                        Err(e) => Some((
1382                            Err(LanguageModelCompletionError::Other(e)),
1383                            (extension, stream_id, true),
1384                        )),
1385                    }
1386                },
1387            );
1388
1389            Ok(stream.boxed())
1390        }
1391        .boxed()
1392    }
1393
1394    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1395        // Extensions can implement this via llm_cache_configuration
1396        None
1397    }
1398}
1399
1400fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1401    use language_model::{MessageContent, Role};
1402
1403    let messages: Vec<LlmRequestMessage> = request
1404        .messages
1405        .into_iter()
1406        .map(|msg| {
1407            let role = match msg.role {
1408                Role::User => LlmMessageRole::User,
1409                Role::Assistant => LlmMessageRole::Assistant,
1410                Role::System => LlmMessageRole::System,
1411            };
1412
1413            let content: Vec<LlmMessageContent> = msg
1414                .content
1415                .into_iter()
1416                .map(|c| match c {
1417                    MessageContent::Text(text) => LlmMessageContent::Text(text),
1418                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1419                        source: image.source.to_string(),
1420                        width: Some(image.size.width.0 as u32),
1421                        height: Some(image.size.height.0 as u32),
1422                    }),
1423                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1424                        id: tool_use.id.to_string(),
1425                        name: tool_use.name.to_string(),
1426                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1427                        is_input_complete: tool_use.is_input_complete,
1428                        thought_signature: tool_use.thought_signature,
1429                    }),
1430                    MessageContent::ToolResult(tool_result) => {
1431                        let content = match tool_result.content {
1432                            language_model::LanguageModelToolResultContent::Text(text) => {
1433                                LlmToolResultContent::Text(text.to_string())
1434                            }
1435                            language_model::LanguageModelToolResultContent::Image(image) => {
1436                                LlmToolResultContent::Image(LlmImageData {
1437                                    source: image.source.to_string(),
1438                                    width: Some(image.size.width.0 as u32),
1439                                    height: Some(image.size.height.0 as u32),
1440                                })
1441                            }
1442                        };
1443                        LlmMessageContent::ToolResult(LlmToolResult {
1444                            tool_use_id: tool_result.tool_use_id.to_string(),
1445                            tool_name: tool_result.tool_name.to_string(),
1446                            is_error: tool_result.is_error,
1447                            content,
1448                        })
1449                    }
1450                    MessageContent::Thinking { text, signature } => {
1451                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1452                    }
1453                    MessageContent::RedactedThinking(data) => {
1454                        LlmMessageContent::RedactedThinking(data)
1455                    }
1456                })
1457                .collect();
1458
1459            LlmRequestMessage {
1460                role,
1461                content,
1462                cache: msg.cache,
1463            }
1464        })
1465        .collect();
1466
1467    let tools: Vec<LlmToolDefinition> = request
1468        .tools
1469        .into_iter()
1470        .map(|tool| LlmToolDefinition {
1471            name: tool.name,
1472            description: tool.description,
1473            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1474        })
1475        .collect();
1476
1477    let tool_choice = request.tool_choice.map(|tc| match tc {
1478        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1479        LanguageModelToolChoice::Any => LlmToolChoice::Any,
1480        LanguageModelToolChoice::None => LlmToolChoice::None,
1481    });
1482
1483    LlmCompletionRequest {
1484        messages,
1485        tools,
1486        tool_choice,
1487        stop_sequences: request.stop,
1488        temperature: request.temperature,
1489        thinking_allowed: false,
1490        max_tokens: None,
1491    }
1492}
1493
1494fn convert_completion_event(
1495    event: LlmCompletionEvent,
1496) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1497    match event {
1498        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1499            message_id: String::new(),
1500        }),
1501        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1502        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1503            text: thinking.text,
1504            signature: thinking.signature,
1505        }),
1506        LlmCompletionEvent::RedactedThinking(data) => {
1507            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1508        }
1509        LlmCompletionEvent::ToolUse(tool_use) => {
1510            let raw_input = tool_use.input.clone();
1511            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1512            Ok(LanguageModelCompletionEvent::ToolUse(
1513                LanguageModelToolUse {
1514                    id: LanguageModelToolUseId::from(tool_use.id),
1515                    name: tool_use.name.into(),
1516                    raw_input,
1517                    input,
1518                    is_input_complete: tool_use.is_input_complete,
1519                    thought_signature: tool_use.thought_signature,
1520                },
1521            ))
1522        }
1523        LlmCompletionEvent::ToolUseJsonParseError(error) => {
1524            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1525                id: LanguageModelToolUseId::from(error.id),
1526                tool_name: error.tool_name.into(),
1527                raw_input: error.raw_input.into(),
1528                json_parse_error: error.error,
1529            })
1530        }
1531        LlmCompletionEvent::Stop(reason) => {
1532            let stop_reason = match reason {
1533                LlmStopReason::EndTurn => StopReason::EndTurn,
1534                LlmStopReason::MaxTokens => StopReason::MaxTokens,
1535                LlmStopReason::ToolUse => StopReason::ToolUse,
1536                LlmStopReason::Refusal => StopReason::Refusal,
1537            };
1538            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1539        }
1540        LlmCompletionEvent::Usage(usage) => {
1541            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1542                input_tokens: usage.input_tokens,
1543                output_tokens: usage.output_tokens,
1544                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1545                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1546            }))
1547        }
1548        LlmCompletionEvent::ReasoningDetails(json) => {
1549            Ok(LanguageModelCompletionEvent::ReasoningDetails(
1550                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1551            ))
1552        }
1553    }
1554}