llm_provider.rs

   1use crate::ExtensionSettings;
   2use crate::LEGACY_LLM_EXTENSION_IDS;
   3use crate::wasm_host::WasmExtension;
   4use collections::HashSet;
   5
   6use crate::wasm_host::wit::{
   7    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
   8    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
   9    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
  10    LlmToolUse,
  11};
  12use anyhow::{Result, anyhow};
  13use credentials_provider::CredentialsProvider;
  14use editor::Editor;
  15use extension::{LanguageModelAuthConfig, OAuthConfig};
  16use futures::future::BoxFuture;
  17use futures::stream::BoxStream;
  18use futures::{FutureExt, StreamExt};
  19use gpui::Focusable;
  20use gpui::{
  21    AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
  22    MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
  23};
  24use language_model::tool_schema::LanguageModelToolSchemaFormat;
  25use language_model::{
  26    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
  27    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
  28    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  29    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  30    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
  31};
  32use markdown::{Markdown, MarkdownElement, MarkdownStyle};
  33use settings::Settings;
  34use std::sync::Arc;
  35use theme::ThemeSettings;
  36use ui::{Label, LabelSize, prelude::*};
  37use util::ResultExt as _;
  38
  39/// An extension-based language model provider.
  40pub struct ExtensionLanguageModelProvider {
  41    pub extension: WasmExtension,
  42    pub provider_info: LlmProviderInfo,
  43    icon_path: Option<SharedString>,
  44    auth_config: Option<LanguageModelAuthConfig>,
  45    state: Entity<ExtensionLlmProviderState>,
  46}
  47
  48pub struct ExtensionLlmProviderState {
  49    is_authenticated: bool,
  50    available_models: Vec<LlmModelInfo>,
  51    /// Set of env var names that are allowed to be read for this provider.
  52    allowed_env_vars: HashSet<String>,
  53    /// If authenticated via env var, which one was used.
  54    env_var_name_used: Option<String>,
  55}
  56
  57impl EventEmitter<()> for ExtensionLlmProviderState {}
  58
  59impl ExtensionLanguageModelProvider {
  60    pub fn new(
  61        extension: WasmExtension,
  62        provider_info: LlmProviderInfo,
  63        models: Vec<LlmModelInfo>,
  64        is_authenticated: bool,
  65        icon_path: Option<SharedString>,
  66        auth_config: Option<LanguageModelAuthConfig>,
  67        cx: &mut App,
  68    ) -> Self {
  69        let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
  70
  71        // Build set of allowed env vars for this provider
  72        let settings = ExtensionSettings::get_global(cx);
  73        let is_legacy_extension =
  74            LEGACY_LLM_EXTENSION_IDS.contains(&extension.manifest.id.as_ref());
  75
  76        let mut allowed_env_vars = HashSet::default();
  77        if let Some(env_vars) = auth_config.as_ref().and_then(|c| c.env_vars.as_ref()) {
  78            for env_var_name in env_vars {
  79                let key = format!("{}:{}", provider_id_string, env_var_name);
  80                // For legacy extensions, auto-allow if env var is set (migration will persist this)
  81                let env_var_is_set = std::env::var(env_var_name)
  82                    .map(|v| !v.is_empty())
  83                    .unwrap_or(false);
  84                if settings.allowed_env_var_providers.contains(key.as_str())
  85                    || (is_legacy_extension && env_var_is_set)
  86                {
  87                    allowed_env_vars.insert(env_var_name.clone());
  88                }
  89            }
  90        }
  91
  92        // Check if any allowed env var is set
  93        let env_var_name_used = allowed_env_vars.iter().find_map(|env_var_name| {
  94            if let Ok(value) = std::env::var(env_var_name) {
  95                if !value.is_empty() {
  96                    return Some(env_var_name.clone());
  97                }
  98            }
  99            None
 100        });
 101
 102        let is_authenticated = if env_var_name_used.is_some() {
 103            true
 104        } else {
 105            is_authenticated
 106        };
 107
 108        let state = cx.new(|_| ExtensionLlmProviderState {
 109            is_authenticated,
 110            available_models: models,
 111            allowed_env_vars,
 112            env_var_name_used,
 113        });
 114
 115        Self {
 116            extension,
 117            provider_info,
 118            icon_path,
 119            auth_config,
 120            state,
 121        }
 122    }
 123
 124    fn provider_id_string(&self) -> String {
 125        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
 126    }
 127
 128    /// The credential key used for storing the API key in the system keychain.
 129    fn credential_key(&self) -> String {
 130        format!("extension-llm-{}", self.provider_id_string())
 131    }
 132}
 133
 134impl LanguageModelProvider for ExtensionLanguageModelProvider {
 135    fn id(&self) -> LanguageModelProviderId {
 136        LanguageModelProviderId::from(self.provider_id_string())
 137    }
 138
 139    fn name(&self) -> LanguageModelProviderName {
 140        LanguageModelProviderName::from(format!("(Extension) {}", self.provider_info.name))
 141    }
 142
 143    fn icon(&self) -> ui::IconName {
 144        ui::IconName::ZedAssistant
 145    }
 146
 147    fn icon_path(&self) -> Option<SharedString> {
 148        self.icon_path.clone()
 149    }
 150
 151    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 152        let state = self.state.read(cx);
 153        state
 154            .available_models
 155            .iter()
 156            .find(|m| m.is_default)
 157            .or_else(|| state.available_models.first())
 158            .map(|model_info| {
 159                Arc::new(ExtensionLanguageModel {
 160                    extension: self.extension.clone(),
 161                    model_info: model_info.clone(),
 162                    provider_id: self.id(),
 163                    provider_name: self.name(),
 164                    provider_info: self.provider_info.clone(),
 165                }) as Arc<dyn LanguageModel>
 166            })
 167    }
 168
 169    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 170        let state = self.state.read(cx);
 171        state
 172            .available_models
 173            .iter()
 174            .find(|m| m.is_default_fast)
 175            .map(|model_info| {
 176                Arc::new(ExtensionLanguageModel {
 177                    extension: self.extension.clone(),
 178                    model_info: model_info.clone(),
 179                    provider_id: self.id(),
 180                    provider_name: self.name(),
 181                    provider_info: self.provider_info.clone(),
 182                }) as Arc<dyn LanguageModel>
 183            })
 184    }
 185
 186    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 187        let state = self.state.read(cx);
 188        state
 189            .available_models
 190            .iter()
 191            .map(|model_info| {
 192                Arc::new(ExtensionLanguageModel {
 193                    extension: self.extension.clone(),
 194                    model_info: model_info.clone(),
 195                    provider_id: self.id(),
 196                    provider_name: self.name(),
 197                    provider_info: self.provider_info.clone(),
 198                }) as Arc<dyn LanguageModel>
 199            })
 200            .collect()
 201    }
 202
 203    fn is_authenticated(&self, cx: &App) -> bool {
 204        // First check cached state
 205        if self.state.read(cx).is_authenticated {
 206            return true;
 207        }
 208
 209        // Also check env var dynamically (in case settings changed after provider creation)
 210        if let Some(ref auth_config) = self.auth_config {
 211            if let Some(ref env_vars) = auth_config.env_vars {
 212                let provider_id_string = self.provider_id_string();
 213                let settings = ExtensionSettings::get_global(cx);
 214                let is_legacy_extension =
 215                    LEGACY_LLM_EXTENSION_IDS.contains(&self.extension.manifest.id.as_ref());
 216
 217                for env_var_name in env_vars {
 218                    let key = format!("{}:{}", provider_id_string, env_var_name);
 219                    // For legacy extensions, auto-allow if env var is set
 220                    let env_var_is_set = std::env::var(env_var_name)
 221                        .map(|v| !v.is_empty())
 222                        .unwrap_or(false);
 223                    if settings.allowed_env_var_providers.contains(key.as_str())
 224                        || (is_legacy_extension && env_var_is_set)
 225                    {
 226                        if let Ok(value) = std::env::var(env_var_name) {
 227                            if !value.is_empty() {
 228                                return true;
 229                            }
 230                        }
 231                    }
 232                }
 233            }
 234        }
 235
 236        false
 237    }
 238
 239    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 240        // Check if already authenticated via is_authenticated
 241        if self.is_authenticated(cx) {
 242            return Task::ready(Ok(()));
 243        }
 244
 245        // Not authenticated - return error indicating credentials not found
 246        Task::ready(Err(AuthenticateError::CredentialsNotFound))
 247    }
 248
 249    fn configuration_view(
 250        &self,
 251        _target_agent: ConfigurationViewTargetAgent,
 252        window: &mut Window,
 253        cx: &mut App,
 254    ) -> AnyView {
 255        let credential_key = self.credential_key();
 256        let extension = self.extension.clone();
 257        let extension_provider_id = self.provider_info.id.clone();
 258        let full_provider_id = self.provider_id_string();
 259        let state = self.state.clone();
 260        let auth_config = self.auth_config.clone();
 261
 262        cx.new(|cx| {
 263            ExtensionProviderConfigurationView::new(
 264                credential_key,
 265                extension,
 266                extension_provider_id,
 267                full_provider_id,
 268                auth_config,
 269                state,
 270                window,
 271                cx,
 272            )
 273        })
 274        .into()
 275    }
 276
 277    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 278        let extension = self.extension.clone();
 279        let provider_id = self.provider_info.id.clone();
 280        let state = self.state.clone();
 281        let credential_key = self.credential_key();
 282
 283        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 284
 285        cx.spawn(async move |cx| {
 286            // Delete from system keychain
 287            credentials_provider
 288                .delete_credentials(&credential_key, cx)
 289                .await
 290                .log_err();
 291
 292            // Call extension's reset_credentials
 293            let result = extension
 294                .call(|extension, store| {
 295                    async move {
 296                        extension
 297                            .call_llm_provider_reset_credentials(store, &provider_id)
 298                            .await
 299                    }
 300                    .boxed()
 301                })
 302                .await;
 303
 304            // Update state
 305            cx.update(|cx| {
 306                state.update(cx, |state, _| {
 307                    state.is_authenticated = false;
 308                });
 309            })?;
 310
 311            match result {
 312                Ok(Ok(Ok(()))) => Ok(()),
 313                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
 314                Ok(Err(e)) => Err(e),
 315                Err(e) => Err(e),
 316            }
 317        })
 318    }
 319}
 320
 321impl LanguageModelProviderState for ExtensionLanguageModelProvider {
 322    type ObservableEntity = ExtensionLlmProviderState;
 323
 324    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 325        Some(self.state.clone())
 326    }
 327
 328    fn subscribe<T: 'static>(
 329        &self,
 330        cx: &mut Context<T>,
 331        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
 332    ) -> Option<Subscription> {
 333        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
 334    }
 335}
 336
 337/// Configuration view for extension-based LLM providers.
 338struct ExtensionProviderConfigurationView {
 339    credential_key: String,
 340    extension: WasmExtension,
 341    extension_provider_id: String,
 342    full_provider_id: String,
 343    auth_config: Option<LanguageModelAuthConfig>,
 344    state: Entity<ExtensionLlmProviderState>,
 345    settings_markdown: Option<Entity<Markdown>>,
 346    api_key_editor: Entity<Editor>,
 347    loading_settings: bool,
 348    loading_credentials: bool,
 349    oauth_in_progress: bool,
 350    oauth_error: Option<String>,
 351    device_user_code: Option<String>,
 352    _subscriptions: Vec<Subscription>,
 353}
 354
 355impl ExtensionProviderConfigurationView {
 356    fn new(
 357        credential_key: String,
 358        extension: WasmExtension,
 359        extension_provider_id: String,
 360        full_provider_id: String,
 361        auth_config: Option<LanguageModelAuthConfig>,
 362        state: Entity<ExtensionLlmProviderState>,
 363        window: &mut Window,
 364        cx: &mut Context<Self>,
 365    ) -> Self {
 366        // Subscribe to state changes
 367        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
 368            cx.notify();
 369        });
 370
 371        // Create API key editor
 372        let api_key_editor = cx.new(|cx| {
 373            let mut editor = Editor::single_line(window, cx);
 374            editor.set_placeholder_text("Enter API key...", window, cx);
 375            editor
 376        });
 377
 378        let mut this = Self {
 379            credential_key,
 380            extension,
 381            extension_provider_id,
 382            full_provider_id,
 383            auth_config,
 384            state,
 385            settings_markdown: None,
 386            api_key_editor,
 387            loading_settings: true,
 388            loading_credentials: true,
 389            oauth_in_progress: false,
 390            oauth_error: None,
 391            device_user_code: None,
 392            _subscriptions: vec![state_subscription],
 393        };
 394
 395        // Load settings text from extension
 396        this.load_settings_text(cx);
 397
 398        // Load existing credentials
 399        this.load_credentials(cx);
 400
 401        this
 402    }
 403
 404    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
 405        let extension = self.extension.clone();
 406        let provider_id = self.extension_provider_id.clone();
 407
 408        cx.spawn(async move |this, cx| {
 409            let result = extension
 410                .call({
 411                    let provider_id = provider_id.clone();
 412                    |ext, store| {
 413                        async move {
 414                            ext.call_llm_provider_settings_markdown(store, &provider_id)
 415                                .await
 416                        }
 417                        .boxed()
 418                    }
 419                })
 420                .await;
 421
 422            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
 423
 424            this.update(cx, |this, cx| {
 425                this.loading_settings = false;
 426                if let Some(text) = settings_text {
 427                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
 428                    this.settings_markdown = Some(markdown);
 429                }
 430                cx.notify();
 431            })
 432            .log_err();
 433        })
 434        .detach();
 435    }
 436
 437    fn load_credentials(&mut self, cx: &mut Context<Self>) {
 438        let credential_key = self.credential_key.clone();
 439        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 440        let state = self.state.clone();
 441
 442        // Check if we should use env var (already set in state during provider construction)
 443        let using_env_var = self.state.read(cx).env_var_name_used.is_some();
 444
 445        cx.spawn(async move |this, cx| {
 446            // If using env var, we're already authenticated
 447            if using_env_var {
 448                this.update(cx, |this, cx| {
 449                    this.loading_credentials = false;
 450                    cx.notify();
 451                })
 452                .log_err();
 453                return;
 454            }
 455
 456            let credentials = credentials_provider
 457                .read_credentials(&credential_key, cx)
 458                .await
 459                .log_err()
 460                .flatten();
 461
 462            let has_credentials = credentials.is_some();
 463
 464            // Update authentication state based on stored credentials
 465            cx.update(|cx| {
 466                state.update(cx, |state, cx| {
 467                    state.is_authenticated = has_credentials;
 468                    cx.notify();
 469                });
 470            })
 471            .log_err();
 472
 473            this.update(cx, |this, cx| {
 474                this.loading_credentials = false;
 475                cx.notify();
 476            })
 477            .log_err();
 478        })
 479        .detach();
 480    }
 481
 482    fn toggle_env_var_permission(&mut self, env_var_name: String, cx: &mut Context<Self>) {
 483        let full_provider_id = self.full_provider_id.clone();
 484        let settings_key: Arc<str> = format!("{}:{}", full_provider_id, env_var_name).into();
 485
 486        let state = self.state.clone();
 487        let currently_allowed = self.state.read(cx).allowed_env_vars.contains(&env_var_name);
 488
 489        // Update settings file
 490        settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
 491            let settings_key = settings_key.clone();
 492            move |settings, _| {
 493                let allowed = settings
 494                    .extension
 495                    .allowed_env_var_providers
 496                    .get_or_insert_with(Vec::new);
 497
 498                if currently_allowed {
 499                    allowed.retain(|id| id.as_ref() != settings_key.as_ref());
 500                } else {
 501                    if !allowed
 502                        .iter()
 503                        .any(|id| id.as_ref() == settings_key.as_ref())
 504                    {
 505                        allowed.push(settings_key.clone());
 506                    }
 507                }
 508            }
 509        });
 510
 511        // Update local state
 512        let new_allowed = !currently_allowed;
 513        let env_var_name_clone = env_var_name.clone();
 514
 515        state.update(cx, |state, cx| {
 516            if new_allowed {
 517                state.allowed_env_vars.insert(env_var_name_clone.clone());
 518                // Check if this env var is set and update env_var_name_used
 519                if let Ok(value) = std::env::var(&env_var_name_clone) {
 520                    if !value.is_empty() && state.env_var_name_used.is_none() {
 521                        state.env_var_name_used = Some(env_var_name_clone);
 522                        state.is_authenticated = true;
 523                    }
 524                }
 525            } else {
 526                state.allowed_env_vars.remove(&env_var_name_clone);
 527                // If this was the env var being used, clear it and find another
 528                if state.env_var_name_used.as_ref() == Some(&env_var_name_clone) {
 529                    state.env_var_name_used = state.allowed_env_vars.iter().find_map(|var| {
 530                        if let Ok(value) = std::env::var(var) {
 531                            if !value.is_empty() {
 532                                return Some(var.clone());
 533                            }
 534                        }
 535                        None
 536                    });
 537                    if state.env_var_name_used.is_none() {
 538                        // No env var auth available, need to check keychain
 539                        state.is_authenticated = false;
 540                    }
 541                }
 542            }
 543            cx.notify();
 544        });
 545
 546        // If all env vars are being disabled, reload credentials from keychain
 547        if !new_allowed && self.state.read(cx).allowed_env_vars.is_empty() {
 548            self.reload_keychain_credentials(cx);
 549        }
 550
 551        cx.notify();
 552    }
 553
 554    fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
 555        let credential_key = self.credential_key.clone();
 556        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 557        let state = self.state.clone();
 558
 559        cx.spawn(async move |_this, cx| {
 560            let credentials = credentials_provider
 561                .read_credentials(&credential_key, cx)
 562                .await
 563                .log_err()
 564                .flatten();
 565
 566            let has_credentials = credentials.is_some();
 567
 568            cx.update(|cx| {
 569                state.update(cx, |state, cx| {
 570                    state.is_authenticated = has_credentials;
 571                    cx.notify();
 572                });
 573            })
 574            .log_err();
 575        })
 576        .detach();
 577    }
 578
 579    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 580        let api_key = self.api_key_editor.read(cx).text(cx);
 581        if api_key.is_empty() {
 582            return;
 583        }
 584
 585        // Clear the editor
 586        self.api_key_editor
 587            .update(cx, |editor, cx| editor.set_text("", window, cx));
 588
 589        let credential_key = self.credential_key.clone();
 590        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 591        let state = self.state.clone();
 592
 593        cx.spawn(async move |_this, cx| {
 594            // Store in system keychain
 595            credentials_provider
 596                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
 597                .await
 598                .log_err();
 599
 600            // Update state to authenticated
 601            cx.update(|cx| {
 602                state.update(cx, |state, cx| {
 603                    state.is_authenticated = true;
 604                    cx.notify();
 605                });
 606            })
 607            .log_err();
 608        })
 609        .detach();
 610    }
 611
 612    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 613        // Clear the editor
 614        self.api_key_editor
 615            .update(cx, |editor, cx| editor.set_text("", window, cx));
 616
 617        let credential_key = self.credential_key.clone();
 618        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 619        let state = self.state.clone();
 620
 621        cx.spawn(async move |_this, cx| {
 622            // Delete from system keychain
 623            credentials_provider
 624                .delete_credentials(&credential_key, cx)
 625                .await
 626                .log_err();
 627
 628            // Update state to unauthenticated
 629            cx.update(|cx| {
 630                state.update(cx, |state, cx| {
 631                    state.is_authenticated = false;
 632                    cx.notify();
 633                });
 634            })
 635            .log_err();
 636        })
 637        .detach();
 638    }
 639
 640    fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
 641        if self.oauth_in_progress {
 642            return;
 643        }
 644
 645        self.oauth_in_progress = true;
 646        self.oauth_error = None;
 647        self.device_user_code = None;
 648        cx.notify();
 649
 650        let extension = self.extension.clone();
 651        let provider_id = self.extension_provider_id.clone();
 652        let state = self.state.clone();
 653
 654        cx.spawn(async move |this, cx| {
 655            // Step 1: Start device flow - opens browser and returns user code
 656            let start_result = extension
 657                .call({
 658                    let provider_id = provider_id.clone();
 659                    |ext, store| {
 660                        async move {
 661                            ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
 662                                .await
 663                        }
 664                        .boxed()
 665                    }
 666                })
 667                .await;
 668
 669            let user_code = match start_result {
 670                Ok(Ok(Ok(code))) => code,
 671                Ok(Ok(Err(e))) => {
 672                    log::error!("Device flow start failed: {}", e);
 673                    this.update(cx, |this, cx| {
 674                        this.oauth_in_progress = false;
 675                        this.oauth_error = Some(e);
 676                        cx.notify();
 677                    })
 678                    .log_err();
 679                    return;
 680                }
 681                Ok(Err(e)) | Err(e) => {
 682                    log::error!("Device flow start error: {}", e);
 683                    this.update(cx, |this, cx| {
 684                        this.oauth_in_progress = false;
 685                        this.oauth_error = Some(e.to_string());
 686                        cx.notify();
 687                    })
 688                    .log_err();
 689                    return;
 690                }
 691            };
 692
 693            // Update UI to show the user code before polling
 694            this.update(cx, |this, cx| {
 695                this.device_user_code = Some(user_code);
 696                cx.notify();
 697            })
 698            .log_err();
 699
 700            // Step 2: Poll for authentication completion
 701            let poll_result = extension
 702                .call({
 703                    let provider_id = provider_id.clone();
 704                    |ext, store| {
 705                        async move {
 706                            ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
 707                                .await
 708                        }
 709                        .boxed()
 710                    }
 711                })
 712                .await;
 713
 714            let error_message = match poll_result {
 715                Ok(Ok(Ok(()))) => {
 716                    // After successful auth, refresh the models list
 717                    let models_result = extension
 718                        .call({
 719                            let provider_id = provider_id.clone();
 720                            |ext, store| {
 721                                async move {
 722                                    ext.call_llm_provider_models(store, &provider_id).await
 723                                }
 724                                .boxed()
 725                            }
 726                        })
 727                        .await;
 728
 729                    let new_models: Vec<LlmModelInfo> = match models_result {
 730                        Ok(Ok(Ok(models))) => models,
 731                        _ => Vec::new(),
 732                    };
 733
 734                    cx.update(|cx| {
 735                        state.update(cx, |state, cx| {
 736                            state.is_authenticated = true;
 737                            state.available_models = new_models;
 738                            cx.notify();
 739                        });
 740                    })
 741                    .log_err();
 742                    None
 743                }
 744                Ok(Ok(Err(e))) => {
 745                    log::error!("Device flow poll failed: {}", e);
 746                    Some(e)
 747                }
 748                Ok(Err(e)) | Err(e) => {
 749                    log::error!("Device flow poll error: {}", e);
 750                    Some(e.to_string())
 751                }
 752            };
 753
 754            this.update(cx, |this, cx| {
 755                this.oauth_in_progress = false;
 756                this.oauth_error = error_message;
 757                this.device_user_code = None;
 758                cx.notify();
 759            })
 760            .log_err();
 761        })
 762        .detach();
 763    }
 764
 765    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
 766        self.state.read(cx).is_authenticated
 767    }
 768
 769    fn has_oauth_config(&self) -> bool {
 770        self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
 771    }
 772
 773    fn oauth_config(&self) -> Option<&OAuthConfig> {
 774        self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
 775    }
 776
 777    fn has_api_key_config(&self) -> bool {
 778        // API key is available if there's a credential_label or no oauth-only config
 779        self.auth_config
 780            .as_ref()
 781            .map(|c| c.credential_label.is_some() || c.oauth.is_none())
 782            .unwrap_or(true)
 783    }
 784}
 785
 786impl gpui::Render for ExtensionProviderConfigurationView {
 787    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 788        let is_loading = self.loading_settings || self.loading_credentials;
 789        let is_authenticated = self.is_authenticated(cx);
 790        let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone();
 791        let env_var_name_used = self.state.read(cx).env_var_name_used.clone();
 792        let has_oauth = self.has_oauth_config();
 793        let has_api_key = self.has_api_key_config();
 794
 795        if is_loading {
 796            return v_flex()
 797                .gap_2()
 798                .child(Label::new("Loading...").color(Color::Muted))
 799                .into_any_element();
 800        }
 801
 802        let mut content = v_flex().gap_4().size_full();
 803
 804        // Render settings markdown if available
 805        if let Some(markdown) = &self.settings_markdown {
 806            let style = settings_markdown_style(_window, cx);
 807            content = content.child(MarkdownElement::new(markdown.clone(), style));
 808        }
 809
 810        // Render env var checkboxes - one for each env var the extension declares
 811        if let Some(auth_config) = &self.auth_config {
 812            if let Some(env_vars) = &auth_config.env_vars {
 813                for env_var_name in env_vars {
 814                    let is_allowed = allowed_env_vars.contains(env_var_name);
 815                    let checkbox_label =
 816                        format!("Read API key from {} environment variable", env_var_name);
 817                    let env_var_for_click = env_var_name.clone();
 818
 819                    content = content.child(
 820                        h_flex()
 821                            .gap_2()
 822                            .child(
 823                                ui::Checkbox::new(
 824                                    SharedString::from(format!("env-var-{}", env_var_name)),
 825                                    is_allowed.into(),
 826                                )
 827                                .on_click(cx.listener(
 828                                    move |this, _, _window, cx| {
 829                                        this.toggle_env_var_permission(
 830                                            env_var_for_click.clone(),
 831                                            cx,
 832                                        );
 833                                    },
 834                                )),
 835                            )
 836                            .child(Label::new(checkbox_label).size(LabelSize::Small)),
 837                    );
 838                }
 839
 840                // Show status if any env var is being used
 841                if let Some(used_var) = &env_var_name_used {
 842                    let tooltip_label = format!(
 843                        "To reset this API key, unset the {} environment variable.",
 844                        used_var
 845                    );
 846                    content = content.child(
 847                        h_flex()
 848                            .mt_0p5()
 849                            .p_1()
 850                            .justify_between()
 851                            .rounded_md()
 852                            .border_1()
 853                            .border_color(cx.theme().colors().border)
 854                            .bg(cx.theme().colors().background)
 855                            .child(
 856                                h_flex()
 857                                    .flex_1()
 858                                    .min_w_0()
 859                                    .gap_1()
 860                                    .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
 861                                    .child(
 862                                        Label::new(format!(
 863                                            "API key set in {} environment variable",
 864                                            used_var
 865                                        ))
 866                                        .truncate(),
 867                                    ),
 868                            )
 869                            .child(
 870                                ui::Button::new("reset-key", "Reset Key")
 871                                    .label_size(LabelSize::Small)
 872                                    .icon(ui::IconName::Undo)
 873                                    .icon_size(ui::IconSize::Small)
 874                                    .icon_color(Color::Muted)
 875                                    .icon_position(ui::IconPosition::Start)
 876                                    .disabled(true)
 877                                    .tooltip(ui::Tooltip::text(tooltip_label)),
 878                            ),
 879                    );
 880                    return content.into_any_element();
 881                }
 882            }
 883        }
 884
 885        // If authenticated, show success state with sign out option
 886        if is_authenticated && env_var_name_used.is_none() {
 887            let reset_label = if has_oauth && !has_api_key {
 888                "Sign Out"
 889            } else {
 890                "Reset Key"
 891            };
 892
 893            let status_label = if has_oauth && !has_api_key {
 894                "Signed in"
 895            } else {
 896                "API key configured"
 897            };
 898
 899            content = content.child(
 900                h_flex()
 901                    .mt_0p5()
 902                    .p_1()
 903                    .justify_between()
 904                    .rounded_md()
 905                    .border_1()
 906                    .border_color(cx.theme().colors().border)
 907                    .bg(cx.theme().colors().background)
 908                    .child(
 909                        h_flex()
 910                            .flex_1()
 911                            .min_w_0()
 912                            .gap_1()
 913                            .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
 914                            .child(Label::new(status_label).truncate()),
 915                    )
 916                    .child(
 917                        ui::Button::new("reset-key", reset_label)
 918                            .label_size(LabelSize::Small)
 919                            .icon(ui::IconName::Undo)
 920                            .icon_size(ui::IconSize::Small)
 921                            .icon_color(Color::Muted)
 922                            .icon_position(ui::IconPosition::Start)
 923                            .on_click(cx.listener(|this, _, window, cx| {
 924                                this.reset_api_key(window, cx);
 925                            })),
 926                    ),
 927            );
 928
 929            return content.into_any_element();
 930        }
 931
 932        // Not authenticated - show available auth options
 933        if env_var_name_used.is_none() {
 934            // Render OAuth sign-in button if configured
 935            if has_oauth {
 936                let oauth_config = self.oauth_config();
 937                let button_label = oauth_config
 938                    .and_then(|c| c.sign_in_button_label.clone())
 939                    .unwrap_or_else(|| "Sign In".to_string());
 940                let button_icon = oauth_config
 941                    .and_then(|c| c.sign_in_button_icon.as_ref())
 942                    .and_then(|icon_name| match icon_name.as_str() {
 943                        "github" => Some(ui::IconName::Github),
 944                        _ => None,
 945                    });
 946
 947                let oauth_in_progress = self.oauth_in_progress;
 948
 949                let oauth_error = self.oauth_error.clone();
 950
 951                content = content.child(
 952                    v_flex()
 953                        .gap_2()
 954                        .child({
 955                            let mut button = ui::Button::new("oauth-sign-in", button_label)
 956                                .full_width()
 957                                .style(ui::ButtonStyle::Outlined)
 958                                .disabled(oauth_in_progress)
 959                                .on_click(cx.listener(|this, _, _window, cx| {
 960                                    this.start_oauth_sign_in(cx);
 961                                }));
 962                            if let Some(icon) = button_icon {
 963                                button = button
 964                                    .icon(icon)
 965                                    .icon_position(ui::IconPosition::Start)
 966                                    .icon_size(ui::IconSize::Small)
 967                                    .icon_color(Color::Muted);
 968                            }
 969                            button
 970                        })
 971                        .when(oauth_in_progress, |this| {
 972                            let user_code = self.device_user_code.clone();
 973                            this.child(
 974                                v_flex()
 975                                    .gap_1()
 976                                    .when_some(user_code, |this, code| {
 977                                        let copied = cx
 978                                            .read_from_clipboard()
 979                                            .map(|item| item.text().as_ref() == Some(&code))
 980                                            .unwrap_or(false);
 981                                        let code_for_click = code.clone();
 982                                        this.child(
 983                                            h_flex()
 984                                                .gap_1()
 985                                                .child(
 986                                                    Label::new("Enter code:")
 987                                                        .size(LabelSize::Small)
 988                                                        .color(Color::Muted),
 989                                                )
 990                                                .child(
 991                                                    h_flex()
 992                                                        .gap_1()
 993                                                        .px_1()
 994                                                        .border_1()
 995                                                        .border_color(cx.theme().colors().border)
 996                                                        .rounded_sm()
 997                                                        .cursor_pointer()
 998                                                        .on_mouse_down(
 999                                                            MouseButton::Left,
1000                                                            move |_, window, cx| {
1001                                                                cx.write_to_clipboard(
1002                                                                    ClipboardItem::new_string(
1003                                                                        code_for_click.clone(),
1004                                                                    ),
1005                                                                );
1006                                                                window.refresh();
1007                                                            },
1008                                                        )
1009                                                        .child(
1010                                                            Label::new(code)
1011                                                                .size(LabelSize::Small)
1012                                                                .color(Color::Accent),
1013                                                        )
1014                                                        .child(
1015                                                            ui::Icon::new(if copied {
1016                                                                ui::IconName::Check
1017                                                            } else {
1018                                                                ui::IconName::Copy
1019                                                            })
1020                                                            .size(ui::IconSize::Small)
1021                                                            .color(if copied {
1022                                                                Color::Success
1023                                                            } else {
1024                                                                Color::Muted
1025                                                            }),
1026                                                        ),
1027                                                ),
1028                                        )
1029                                    })
1030                                    .child(
1031                                        Label::new("Waiting for authorization in browser...")
1032                                            .size(LabelSize::Small)
1033                                            .color(Color::Muted),
1034                                    ),
1035                            )
1036                        })
1037                        .when_some(oauth_error, |this, error| {
1038                            this.child(
1039                                v_flex()
1040                                    .gap_1()
1041                                    .child(
1042                                        h_flex()
1043                                            .gap_2()
1044                                            .child(
1045                                                ui::Icon::new(ui::IconName::Warning)
1046                                                    .color(Color::Error)
1047                                                    .size(ui::IconSize::Small),
1048                                            )
1049                                            .child(
1050                                                Label::new("Authentication failed")
1051                                                    .color(Color::Error)
1052                                                    .size(LabelSize::Small),
1053                                            ),
1054                                    )
1055                                    .child(
1056                                        div().pl_6().child(
1057                                            Label::new(error)
1058                                                .color(Color::Error)
1059                                                .size(LabelSize::Small),
1060                                        ),
1061                                    ),
1062                            )
1063                        }),
1064                );
1065            }
1066
1067            // Render API key input if configured (and we have both options, show a separator)
1068            if has_api_key {
1069                if has_oauth {
1070                    content = content.child(
1071                        h_flex()
1072                            .gap_2()
1073                            .items_center()
1074                            .child(div().h_px().flex_1().bg(cx.theme().colors().border))
1075                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
1076                            .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
1077                    );
1078                }
1079
1080                let credential_label = self
1081                    .auth_config
1082                    .as_ref()
1083                    .and_then(|c| c.credential_label.clone())
1084                    .unwrap_or_else(|| "API Key".to_string());
1085
1086                content = content.child(
1087                    v_flex()
1088                        .gap_2()
1089                        .on_action(cx.listener(Self::save_api_key))
1090                        .child(
1091                            Label::new(credential_label)
1092                                .size(LabelSize::Small)
1093                                .color(Color::Muted),
1094                        )
1095                        .child(self.api_key_editor.clone())
1096                        .child(
1097                            Label::new("Enter your API key and press Enter to save")
1098                                .size(LabelSize::Small)
1099                                .color(Color::Muted),
1100                        ),
1101                );
1102            }
1103        }
1104
1105        // Show OpenAI-compatible models notification for OpenAI extension
1106        if self.extension_provider_id == "openai" {
1107            content = content.child(
1108                h_flex()
1109                    .gap_1()
1110                    .child(
1111                        ui::Icon::new(ui::IconName::Info)
1112                            .size(ui::IconSize::Small)
1113                            .color(Color::Muted),
1114                    )
1115                    .child(
1116                        Label::new("Zed also supports OpenAI-compatible models.")
1117                            .size(LabelSize::Small)
1118                            .color(Color::Muted),
1119                    )
1120                    .child(
1121                        ui::Button::new("learn-more", "Learn More")
1122                            .style(ui::ButtonStyle::Subtle)
1123                            .label_size(LabelSize::Small)
1124                            .icon(ui::IconName::ArrowUpRight)
1125                            .icon_size(ui::IconSize::Small)
1126                            .icon_color(Color::Muted)
1127                            .icon_position(ui::IconPosition::End)
1128                            .on_click(|_, _, cx| {
1129                                cx.open_url("https://zed.dev/docs/configuring-llm-providers#openai-compatible-providers");
1130                            }),
1131                    ),
1132            );
1133        }
1134
1135        content.into_any_element()
1136    }
1137}
1138
1139impl Focusable for ExtensionProviderConfigurationView {
1140    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1141        self.api_key_editor.focus_handle(cx)
1142    }
1143}
1144
1145fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1146    let theme_settings = ThemeSettings::get_global(cx);
1147    let colors = cx.theme().colors();
1148    let mut text_style = window.text_style();
1149    text_style.refine(&TextStyleRefinement {
1150        font_family: Some(theme_settings.ui_font.family.clone()),
1151        font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1152        font_features: Some(theme_settings.ui_font.features.clone()),
1153        color: Some(colors.text),
1154        ..Default::default()
1155    });
1156
1157    MarkdownStyle {
1158        base_text_style: text_style,
1159        selection_background_color: colors.element_selection_background,
1160        inline_code: TextStyleRefinement {
1161            background_color: Some(colors.editor_background),
1162            ..Default::default()
1163        },
1164        link: TextStyleRefinement {
1165            color: Some(colors.text_accent),
1166            underline: Some(UnderlineStyle {
1167                color: Some(colors.text_accent.opacity(0.5)),
1168                thickness: px(1.),
1169                ..Default::default()
1170            }),
1171            ..Default::default()
1172        },
1173        syntax: cx.theme().syntax().clone(),
1174        ..Default::default()
1175    }
1176}
1177
/// An extension-based language model.
///
/// Adapts a single model exposed by a WASM extension to Zed's [`LanguageModel`]
/// trait. Token counting and completions are forwarded to the extension.
pub struct ExtensionLanguageModel {
    /// The extension that implements this model's token counting and completions.
    extension: WasmExtension,
    /// Model metadata from the extension: id, display name, capabilities, token limits.
    model_info: LlmModelInfo,
    /// Provider id reported to Zed via `LanguageModel::provider_id`.
    /// NOTE(review): may differ from `provider_info.id`; confirm how the two are derived.
    provider_id: LanguageModelProviderId,
    /// Provider display name reported via `LanguageModel::provider_name`.
    provider_name: LanguageModelProviderName,
    /// Extension-side provider description; its `id` is what is passed to the extension's LLM calls.
    provider_info: LlmProviderInfo,
}
1186
1187impl LanguageModel for ExtensionLanguageModel {
1188    fn id(&self) -> LanguageModelId {
1189        LanguageModelId::from(self.model_info.id.clone())
1190    }
1191
1192    fn name(&self) -> LanguageModelName {
1193        LanguageModelName::from(self.model_info.name.clone())
1194    }
1195
1196    fn provider_id(&self) -> LanguageModelProviderId {
1197        self.provider_id.clone()
1198    }
1199
1200    fn provider_name(&self) -> LanguageModelProviderName {
1201        self.provider_name.clone()
1202    }
1203
1204    fn telemetry_id(&self) -> String {
1205        format!("extension-{}", self.model_info.id)
1206    }
1207
1208    fn supports_images(&self) -> bool {
1209        self.model_info.capabilities.supports_images
1210    }
1211
1212    fn supports_tools(&self) -> bool {
1213        self.model_info.capabilities.supports_tools
1214    }
1215
1216    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1217        match choice {
1218            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1219            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1220            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1221        }
1222    }
1223
1224    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1225        match self.model_info.capabilities.tool_input_format {
1226            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1227            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1228        }
1229    }
1230
1231    fn max_token_count(&self) -> u64 {
1232        self.model_info.max_token_count
1233    }
1234
1235    fn max_output_tokens(&self) -> Option<u64> {
1236        self.model_info.max_output_tokens
1237    }
1238
1239    fn count_tokens(
1240        &self,
1241        request: LanguageModelRequest,
1242        cx: &App,
1243    ) -> BoxFuture<'static, Result<u64>> {
1244        let extension = self.extension.clone();
1245        let provider_id = self.provider_info.id.clone();
1246        let model_id = self.model_info.id.clone();
1247
1248        let wit_request = convert_request_to_wit(request);
1249
1250        cx.background_spawn(async move {
1251            extension
1252                .call({
1253                    let provider_id = provider_id.clone();
1254                    let model_id = model_id.clone();
1255                    let wit_request = wit_request.clone();
1256                    |ext, store| {
1257                        async move {
1258                            let count = ext
1259                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1260                                .await?
1261                                .map_err(|e| anyhow!("{}", e))?;
1262                            Ok(count)
1263                        }
1264                        .boxed()
1265                    }
1266                })
1267                .await?
1268        })
1269        .boxed()
1270    }
1271
1272    fn stream_completion(
1273        &self,
1274        request: LanguageModelRequest,
1275        _cx: &AsyncApp,
1276    ) -> BoxFuture<
1277        'static,
1278        Result<
1279            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1280            LanguageModelCompletionError,
1281        >,
1282    > {
1283        let extension = self.extension.clone();
1284        let provider_id = self.provider_info.id.clone();
1285        let model_id = self.model_info.id.clone();
1286
1287        let wit_request = convert_request_to_wit(request);
1288
1289        async move {
1290            // Start the stream
1291            let stream_id_result = extension
1292                .call({
1293                    let provider_id = provider_id.clone();
1294                    let model_id = model_id.clone();
1295                    let wit_request = wit_request.clone();
1296                    |ext, store| {
1297                        async move {
1298                            let id = ext
1299                                .call_llm_stream_completion_start(
1300                                    store,
1301                                    &provider_id,
1302                                    &model_id,
1303                                    &wit_request,
1304                                )
1305                                .await?
1306                                .map_err(|e| anyhow!("{}", e))?;
1307                            Ok(id)
1308                        }
1309                        .boxed()
1310                    }
1311                })
1312                .await;
1313
1314            let stream_id = stream_id_result
1315                .map_err(LanguageModelCompletionError::Other)?
1316                .map_err(LanguageModelCompletionError::Other)?;
1317
1318            // Create a stream that polls for events
1319            let stream = futures::stream::unfold(
1320                (extension.clone(), stream_id, false),
1321                move |(extension, stream_id, done)| async move {
1322                    if done {
1323                        return None;
1324                    }
1325
1326                    let result = extension
1327                        .call({
1328                            let stream_id = stream_id.clone();
1329                            |ext, store| {
1330                                async move {
1331                                    let event = ext
1332                                        .call_llm_stream_completion_next(store, &stream_id)
1333                                        .await?
1334                                        .map_err(|e| anyhow!("{}", e))?;
1335                                    Ok(event)
1336                                }
1337                                .boxed()
1338                            }
1339                        })
1340                        .await
1341                        .and_then(|inner| inner);
1342
1343                    match result {
1344                        Ok(Some(event)) => {
1345                            let converted = convert_completion_event(event);
1346                            let is_done =
1347                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1348                            Some((converted, (extension, stream_id, is_done)))
1349                        }
1350                        Ok(None) => {
1351                            // Stream complete, close it
1352                            let _ = extension
1353                                .call({
1354                                    let stream_id = stream_id.clone();
1355                                    |ext, store| {
1356                                        async move {
1357                                            ext.call_llm_stream_completion_close(store, &stream_id)
1358                                                .await?;
1359                                            Ok::<(), anyhow::Error>(())
1360                                        }
1361                                        .boxed()
1362                                    }
1363                                })
1364                                .await;
1365                            None
1366                        }
1367                        Err(e) => Some((
1368                            Err(LanguageModelCompletionError::Other(e)),
1369                            (extension, stream_id, true),
1370                        )),
1371                    }
1372                },
1373            );
1374
1375            Ok(stream.boxed())
1376        }
1377        .boxed()
1378    }
1379
1380    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1381        // Extensions can implement this via llm_cache_configuration
1382        None
1383    }
1384}
1385
1386fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1387    use language_model::{MessageContent, Role};
1388
1389    let messages: Vec<LlmRequestMessage> = request
1390        .messages
1391        .into_iter()
1392        .map(|msg| {
1393            let role = match msg.role {
1394                Role::User => LlmMessageRole::User,
1395                Role::Assistant => LlmMessageRole::Assistant,
1396                Role::System => LlmMessageRole::System,
1397            };
1398
1399            let content: Vec<LlmMessageContent> = msg
1400                .content
1401                .into_iter()
1402                .map(|c| match c {
1403                    MessageContent::Text(text) => LlmMessageContent::Text(text),
1404                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1405                        source: image.source.to_string(),
1406                        width: Some(image.size.width.0 as u32),
1407                        height: Some(image.size.height.0 as u32),
1408                    }),
1409                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1410                        id: tool_use.id.to_string(),
1411                        name: tool_use.name.to_string(),
1412                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1413                        thought_signature: tool_use.thought_signature,
1414                    }),
1415                    MessageContent::ToolResult(tool_result) => {
1416                        let content = match tool_result.content {
1417                            language_model::LanguageModelToolResultContent::Text(text) => {
1418                                LlmToolResultContent::Text(text.to_string())
1419                            }
1420                            language_model::LanguageModelToolResultContent::Image(image) => {
1421                                LlmToolResultContent::Image(LlmImageData {
1422                                    source: image.source.to_string(),
1423                                    width: Some(image.size.width.0 as u32),
1424                                    height: Some(image.size.height.0 as u32),
1425                                })
1426                            }
1427                        };
1428                        LlmMessageContent::ToolResult(LlmToolResult {
1429                            tool_use_id: tool_result.tool_use_id.to_string(),
1430                            tool_name: tool_result.tool_name.to_string(),
1431                            is_error: tool_result.is_error,
1432                            content,
1433                        })
1434                    }
1435                    MessageContent::Thinking { text, signature } => {
1436                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1437                    }
1438                    MessageContent::RedactedThinking(data) => {
1439                        LlmMessageContent::RedactedThinking(data)
1440                    }
1441                })
1442                .collect();
1443
1444            LlmRequestMessage {
1445                role,
1446                content,
1447                cache: msg.cache,
1448            }
1449        })
1450        .collect();
1451
1452    let tools: Vec<LlmToolDefinition> = request
1453        .tools
1454        .into_iter()
1455        .map(|tool| LlmToolDefinition {
1456            name: tool.name,
1457            description: tool.description,
1458            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1459        })
1460        .collect();
1461
1462    let tool_choice = request.tool_choice.map(|tc| match tc {
1463        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1464        LanguageModelToolChoice::Any => LlmToolChoice::Any,
1465        LanguageModelToolChoice::None => LlmToolChoice::None,
1466    });
1467
1468    LlmCompletionRequest {
1469        messages,
1470        tools,
1471        tool_choice,
1472        stop_sequences: request.stop,
1473        temperature: request.temperature,
1474        thinking_allowed: false,
1475        max_tokens: None,
1476    }
1477}
1478
1479fn convert_completion_event(
1480    event: LlmCompletionEvent,
1481) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1482    match event {
1483        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1484            message_id: String::new(),
1485        }),
1486        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1487        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1488            text: thinking.text,
1489            signature: thinking.signature,
1490        }),
1491        LlmCompletionEvent::RedactedThinking(data) => {
1492            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1493        }
1494        LlmCompletionEvent::ToolUse(tool_use) => {
1495            let raw_input = tool_use.input.clone();
1496            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1497            Ok(LanguageModelCompletionEvent::ToolUse(
1498                LanguageModelToolUse {
1499                    id: LanguageModelToolUseId::from(tool_use.id),
1500                    name: tool_use.name.into(),
1501                    raw_input,
1502                    input,
1503                    is_input_complete: true,
1504                    thought_signature: tool_use.thought_signature,
1505                },
1506            ))
1507        }
1508        LlmCompletionEvent::ToolUseJsonParseError(error) => {
1509            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1510                id: LanguageModelToolUseId::from(error.id),
1511                tool_name: error.tool_name.into(),
1512                raw_input: error.raw_input.into(),
1513                json_parse_error: error.error,
1514            })
1515        }
1516        LlmCompletionEvent::Stop(reason) => {
1517            let stop_reason = match reason {
1518                LlmStopReason::EndTurn => StopReason::EndTurn,
1519                LlmStopReason::MaxTokens => StopReason::MaxTokens,
1520                LlmStopReason::ToolUse => StopReason::ToolUse,
1521                LlmStopReason::Refusal => StopReason::Refusal,
1522            };
1523            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1524        }
1525        LlmCompletionEvent::Usage(usage) => {
1526            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1527                input_tokens: usage.input_tokens,
1528                output_tokens: usage.output_tokens,
1529                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1530                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1531            }))
1532        }
1533        LlmCompletionEvent::ReasoningDetails(json) => {
1534            Ok(LanguageModelCompletionEvent::ReasoningDetails(
1535                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1536            ))
1537        }
1538    }
1539}