llm_provider.rs

   1use crate::ExtensionSettings;
   2use crate::wasm_host::WasmExtension;
   3
   4use crate::wasm_host::wit::{
   5    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
   6    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
   7    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
   8    LlmToolUse,
   9};
  10use anyhow::{Result, anyhow};
  11use credentials_provider::CredentialsProvider;
  12use editor::Editor;
  13use extension::LanguageModelAuthConfig;
  14use futures::future::BoxFuture;
  15use futures::stream::BoxStream;
  16use futures::{FutureExt, StreamExt};
  17use gpui::Focusable;
  18use gpui::{
  19    AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
  20    TextStyleRefinement, UnderlineStyle, Window, px,
  21};
  22use language_model::tool_schema::LanguageModelToolSchemaFormat;
  23use language_model::{
  24    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
  25    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  27    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  28    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
  29};
  30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
  31use settings::Settings;
  32use std::sync::Arc;
  33use theme::ThemeSettings;
  34use ui::{Label, LabelSize, prelude::*};
  35use util::ResultExt as _;
  36
/// An extension-based language model provider.
///
/// Bundles the WASM extension that implements the provider with the provider
/// metadata and the shared state entity read and updated throughout this file.
pub struct ExtensionLanguageModelProvider {
    /// The WASM extension whose LLM provider entry points are called.
    pub extension: WasmExtension,
    /// Provider metadata; its `id` and `name` feed the provider id and display name.
    pub provider_info: LlmProviderInfo,
    /// Optional icon path handed back by `icon_path()`.
    icon_path: Option<SharedString>,
    /// Optional auth configuration; its `env_var` and `credential_label` are read in this file.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared authentication and model state.
    state: Entity<ExtensionLlmProviderState>,
}
  45
/// Mutable state shared between the provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether the provider is currently considered authenticated.
    is_authenticated: bool,
    /// Models offered by this provider.
    available_models: Vec<LlmModelInfo>,
    /// Whether this provider is permitted to read its API key from an environment variable.
    env_var_allowed: bool,
    /// Whether a non-empty API key was found in the permitted environment variable.
    api_key_from_env: bool,
}
  52
// Lets the state entity emit unit (`()`) events; the `cx.subscribe` calls on
// this entity elsewhere in the file depend on this impl.
impl EventEmitter<()> for ExtensionLlmProviderState {}
  54
  55impl ExtensionLanguageModelProvider {
  56    pub fn new(
  57        extension: WasmExtension,
  58        provider_info: LlmProviderInfo,
  59        models: Vec<LlmModelInfo>,
  60        is_authenticated: bool,
  61        icon_path: Option<SharedString>,
  62        auth_config: Option<LanguageModelAuthConfig>,
  63        cx: &mut App,
  64    ) -> Self {
  65        let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
  66        let env_var_allowed = ExtensionSettings::get_global(cx)
  67            .allowed_env_var_providers
  68            .contains(provider_id_string.as_str());
  69
  70        let (is_authenticated, api_key_from_env) =
  71            if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
  72                let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
  73                if let Ok(value) = std::env::var(env_var_name) {
  74                    if !value.is_empty() {
  75                        (true, true)
  76                    } else {
  77                        (is_authenticated, false)
  78                    }
  79                } else {
  80                    (is_authenticated, false)
  81                }
  82            } else {
  83                (is_authenticated, false)
  84            };
  85
  86        let state = cx.new(|_| ExtensionLlmProviderState {
  87            is_authenticated,
  88            available_models: models,
  89            env_var_allowed,
  90            api_key_from_env,
  91        });
  92
  93        Self {
  94            extension,
  95            provider_info,
  96            icon_path,
  97            auth_config,
  98            state,
  99        }
 100    }
 101
 102    fn provider_id_string(&self) -> String {
 103        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
 104    }
 105
 106    /// The credential key used for storing the API key in the system keychain.
 107    fn credential_key(&self) -> String {
 108        format!("extension-llm-{}", self.provider_id_string())
 109    }
 110}
 111
 112impl LanguageModelProvider for ExtensionLanguageModelProvider {
 113    fn id(&self) -> LanguageModelProviderId {
 114        LanguageModelProviderId::from(self.provider_id_string())
 115    }
 116
 117    fn name(&self) -> LanguageModelProviderName {
 118        LanguageModelProviderName::from(self.provider_info.name.clone())
 119    }
 120
 121    fn icon(&self) -> ui::IconName {
 122        ui::IconName::ZedAssistant
 123    }
 124
 125    fn icon_path(&self) -> Option<SharedString> {
 126        self.icon_path.clone()
 127    }
 128
 129    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 130        let state = self.state.read(cx);
 131        state
 132            .available_models
 133            .iter()
 134            .find(|m| m.is_default)
 135            .or_else(|| state.available_models.first())
 136            .map(|model_info| {
 137                Arc::new(ExtensionLanguageModel {
 138                    extension: self.extension.clone(),
 139                    model_info: model_info.clone(),
 140                    provider_id: self.id(),
 141                    provider_name: self.name(),
 142                    provider_info: self.provider_info.clone(),
 143                }) as Arc<dyn LanguageModel>
 144            })
 145    }
 146
 147    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 148        let state = self.state.read(cx);
 149        state
 150            .available_models
 151            .iter()
 152            .find(|m| m.is_default_fast)
 153            .map(|model_info| {
 154                Arc::new(ExtensionLanguageModel {
 155                    extension: self.extension.clone(),
 156                    model_info: model_info.clone(),
 157                    provider_id: self.id(),
 158                    provider_name: self.name(),
 159                    provider_info: self.provider_info.clone(),
 160                }) as Arc<dyn LanguageModel>
 161            })
 162    }
 163
 164    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 165        let state = self.state.read(cx);
 166        state
 167            .available_models
 168            .iter()
 169            .map(|model_info| {
 170                Arc::new(ExtensionLanguageModel {
 171                    extension: self.extension.clone(),
 172                    model_info: model_info.clone(),
 173                    provider_id: self.id(),
 174                    provider_name: self.name(),
 175                    provider_info: self.provider_info.clone(),
 176                }) as Arc<dyn LanguageModel>
 177            })
 178            .collect()
 179    }
 180
 181    fn is_authenticated(&self, cx: &App) -> bool {
 182        self.state.read(cx).is_authenticated
 183    }
 184
 185    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 186        // Authentication is handled via the configuration view UI
 187        Task::ready(Ok(()))
 188    }
 189
 190    fn configuration_view(
 191        &self,
 192        _target_agent: ConfigurationViewTargetAgent,
 193        window: &mut Window,
 194        cx: &mut App,
 195    ) -> AnyView {
 196        let credential_key = self.credential_key();
 197        let extension = self.extension.clone();
 198        let extension_provider_id = self.provider_info.id.clone();
 199        let full_provider_id = self.provider_id_string();
 200        let state = self.state.clone();
 201        let auth_config = self.auth_config.clone();
 202
 203        cx.new(|cx| {
 204            ExtensionProviderConfigurationView::new(
 205                credential_key,
 206                extension,
 207                extension_provider_id,
 208                full_provider_id,
 209                auth_config,
 210                state,
 211                window,
 212                cx,
 213            )
 214        })
 215        .into()
 216    }
 217
 218    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 219        let extension = self.extension.clone();
 220        let provider_id = self.provider_info.id.clone();
 221        let state = self.state.clone();
 222        let credential_key = self.credential_key();
 223
 224        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 225
 226        cx.spawn(async move |cx| {
 227            // Delete from system keychain
 228            credentials_provider
 229                .delete_credentials(&credential_key, cx)
 230                .await
 231                .log_err();
 232
 233            // Call extension's reset_credentials
 234            let result = extension
 235                .call(|extension, store| {
 236                    async move {
 237                        extension
 238                            .call_llm_provider_reset_credentials(store, &provider_id)
 239                            .await
 240                    }
 241                    .boxed()
 242                })
 243                .await;
 244
 245            // Update state
 246            cx.update(|cx| {
 247                state.update(cx, |state, _| {
 248                    state.is_authenticated = false;
 249                });
 250            })?;
 251
 252            match result {
 253                Ok(Ok(Ok(()))) => Ok(()),
 254                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
 255                Ok(Err(e)) => Err(e),
 256                Err(e) => Err(e),
 257            }
 258        })
 259    }
 260}
 261
 262impl LanguageModelProviderState for ExtensionLanguageModelProvider {
 263    type ObservableEntity = ExtensionLlmProviderState;
 264
 265    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 266        Some(self.state.clone())
 267    }
 268
 269    fn subscribe<T: 'static>(
 270        &self,
 271        cx: &mut Context<T>,
 272        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
 273    ) -> Option<Subscription> {
 274        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
 275    }
 276}
 277
/// Configuration view for extension-based LLM providers.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is stored.
    credential_key: String,
    /// The WASM extension queried for the settings markdown.
    extension: WasmExtension,
    /// Provider id as known to the extension (passed to extension calls).
    extension_provider_id: String,
    /// Fully-qualified provider id, used as the entry in `allowed_env_var_providers`.
    full_provider_id: String,
    /// Optional auth configuration; supplies the env var name and credential label.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Provider state handle, shared with the provider that created this view.
    state: Entity<ExtensionLlmProviderState>,
    /// Markdown returned by the extension, once loaded; `None` if absent or not yet loaded.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor used to enter the API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings markdown request has finished.
    loading_settings: bool,
    /// True until the initial credential lookup has finished.
    loading_credentials: bool,
    /// Keeps the subscriptions created in `new` alive for the lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
 292
 293impl ExtensionProviderConfigurationView {
 294    fn new(
 295        credential_key: String,
 296        extension: WasmExtension,
 297        extension_provider_id: String,
 298        full_provider_id: String,
 299        auth_config: Option<LanguageModelAuthConfig>,
 300        state: Entity<ExtensionLlmProviderState>,
 301        window: &mut Window,
 302        cx: &mut Context<Self>,
 303    ) -> Self {
 304        // Subscribe to state changes
 305        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
 306            cx.notify();
 307        });
 308
 309        // Create API key editor
 310        let api_key_editor = cx.new(|cx| {
 311            let mut editor = Editor::single_line(window, cx);
 312            editor.set_placeholder_text("Enter API key...", window, cx);
 313            editor
 314        });
 315
 316        let mut this = Self {
 317            credential_key,
 318            extension,
 319            extension_provider_id,
 320            full_provider_id,
 321            auth_config,
 322            state,
 323            settings_markdown: None,
 324            api_key_editor,
 325            loading_settings: true,
 326            loading_credentials: true,
 327            _subscriptions: vec![state_subscription],
 328        };
 329
 330        // Load settings text from extension
 331        this.load_settings_text(cx);
 332
 333        // Load existing credentials
 334        this.load_credentials(cx);
 335
 336        this
 337    }
 338
 339    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
 340        let extension = self.extension.clone();
 341        let provider_id = self.extension_provider_id.clone();
 342
 343        cx.spawn(async move |this, cx| {
 344            let result = extension
 345                .call({
 346                    let provider_id = provider_id.clone();
 347                    |ext, store| {
 348                        async move {
 349                            ext.call_llm_provider_settings_markdown(store, &provider_id)
 350                                .await
 351                        }
 352                        .boxed()
 353                    }
 354                })
 355                .await;
 356
 357            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
 358
 359            this.update(cx, |this, cx| {
 360                this.loading_settings = false;
 361                if let Some(text) = settings_text {
 362                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
 363                    this.settings_markdown = Some(markdown);
 364                }
 365                cx.notify();
 366            })
 367            .log_err();
 368        })
 369        .detach();
 370    }
 371
 372    fn load_credentials(&mut self, cx: &mut Context<Self>) {
 373        let credential_key = self.credential_key.clone();
 374        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 375        let state = self.state.clone();
 376
 377        // Check if we should use env var (already set in state during provider construction)
 378        let api_key_from_env = self.state.read(cx).api_key_from_env;
 379
 380        cx.spawn(async move |this, cx| {
 381            // If using env var, we're already authenticated
 382            if api_key_from_env {
 383                this.update(cx, |this, cx| {
 384                    this.loading_credentials = false;
 385                    cx.notify();
 386                })
 387                .log_err();
 388                return;
 389            }
 390
 391            let credentials = credentials_provider
 392                .read_credentials(&credential_key, cx)
 393                .await
 394                .log_err()
 395                .flatten();
 396
 397            let has_credentials = credentials.is_some();
 398
 399            // Update authentication state based on stored credentials
 400            let _ = cx.update(|cx| {
 401                state.update(cx, |state, cx| {
 402                    state.is_authenticated = has_credentials;
 403                    cx.notify();
 404                });
 405            });
 406
 407            this.update(cx, |this, cx| {
 408                this.loading_credentials = false;
 409                cx.notify();
 410            })
 411            .log_err();
 412        })
 413        .detach();
 414    }
 415
 416    fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
 417        let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
 418        let env_var_name = match &self.auth_config {
 419            Some(config) => config.env_var.clone(),
 420            None => return,
 421        };
 422
 423        let state = self.state.clone();
 424        let currently_allowed = self.state.read(cx).env_var_allowed;
 425
 426        // Update settings file
 427        settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
 428            let providers = settings
 429                .extension
 430                .allowed_env_var_providers
 431                .get_or_insert_with(Vec::new);
 432
 433            if currently_allowed {
 434                providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
 435            } else {
 436                if !providers
 437                    .iter()
 438                    .any(|id| id.as_ref() == full_provider_id.as_ref())
 439                {
 440                    providers.push(full_provider_id.clone());
 441                }
 442            }
 443        });
 444
 445        // Update local state
 446        let new_allowed = !currently_allowed;
 447        let new_from_env = if new_allowed {
 448            if let Some(var_name) = &env_var_name {
 449                if let Ok(value) = std::env::var(var_name) {
 450                    !value.is_empty()
 451                } else {
 452                    false
 453                }
 454            } else {
 455                false
 456            }
 457        } else {
 458            false
 459        };
 460
 461        state.update(cx, |state, cx| {
 462            state.env_var_allowed = new_allowed;
 463            state.api_key_from_env = new_from_env;
 464            if new_from_env {
 465                state.is_authenticated = true;
 466            }
 467            cx.notify();
 468        });
 469
 470        // If env var is being enabled, clear any stored keychain credentials
 471        // so there's only one source of truth for the API key
 472        if new_allowed {
 473            let credential_key = self.credential_key.clone();
 474            let credentials_provider = <dyn CredentialsProvider>::global(cx);
 475            cx.spawn(async move |_this, cx| {
 476                credentials_provider
 477                    .delete_credentials(&credential_key, cx)
 478                    .await
 479                    .log_err();
 480            })
 481            .detach();
 482        }
 483
 484        // If env var is being disabled, reload credentials from keychain
 485        if !new_allowed {
 486            self.reload_keychain_credentials(cx);
 487        }
 488
 489        cx.notify();
 490    }
 491
 492    fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
 493        let credential_key = self.credential_key.clone();
 494        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 495        let state = self.state.clone();
 496
 497        cx.spawn(async move |_this, cx| {
 498            let credentials = credentials_provider
 499                .read_credentials(&credential_key, cx)
 500                .await
 501                .log_err()
 502                .flatten();
 503
 504            let has_credentials = credentials.is_some();
 505
 506            let _ = cx.update(|cx| {
 507                state.update(cx, |state, cx| {
 508                    state.is_authenticated = has_credentials;
 509                    cx.notify();
 510                });
 511            });
 512        })
 513        .detach();
 514    }
 515
 516    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 517        let api_key = self.api_key_editor.read(cx).text(cx);
 518        if api_key.is_empty() {
 519            return;
 520        }
 521
 522        // Clear the editor
 523        self.api_key_editor
 524            .update(cx, |editor, cx| editor.set_text("", window, cx));
 525
 526        let credential_key = self.credential_key.clone();
 527        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 528        let state = self.state.clone();
 529
 530        cx.spawn(async move |_this, cx| {
 531            // Store in system keychain
 532            credentials_provider
 533                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
 534                .await
 535                .log_err();
 536
 537            // Update state to authenticated
 538            let _ = cx.update(|cx| {
 539                state.update(cx, |state, cx| {
 540                    state.is_authenticated = true;
 541                    cx.notify();
 542                });
 543            });
 544        })
 545        .detach();
 546    }
 547
 548    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 549        // Clear the editor
 550        self.api_key_editor
 551            .update(cx, |editor, cx| editor.set_text("", window, cx));
 552
 553        let credential_key = self.credential_key.clone();
 554        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 555        let state = self.state.clone();
 556
 557        cx.spawn(async move |_this, cx| {
 558            // Delete from system keychain
 559            credentials_provider
 560                .delete_credentials(&credential_key, cx)
 561                .await
 562                .log_err();
 563
 564            // Update state to unauthenticated
 565            let _ = cx.update(|cx| {
 566                state.update(cx, |state, cx| {
 567                    state.is_authenticated = false;
 568                    cx.notify();
 569                });
 570            });
 571        })
 572        .detach();
 573    }
 574
 575    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
 576        self.state.read(cx).is_authenticated
 577    }
 578}
 579
 580impl gpui::Render for ExtensionProviderConfigurationView {
 581    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 582        let is_loading = self.loading_settings || self.loading_credentials;
 583        let is_authenticated = self.is_authenticated(cx);
 584        let env_var_allowed = self.state.read(cx).env_var_allowed;
 585        let api_key_from_env = self.state.read(cx).api_key_from_env;
 586
 587        if is_loading {
 588            return v_flex()
 589                .gap_2()
 590                .child(Label::new("Loading...").color(Color::Muted))
 591                .into_any_element();
 592        }
 593
 594        let mut content = v_flex().gap_4().size_full();
 595
 596        // Render settings markdown if available
 597        if let Some(markdown) = &self.settings_markdown {
 598            let style = settings_markdown_style(_window, cx);
 599            content = content.child(
 600                div()
 601                    .p_2()
 602                    .rounded_md()
 603                    .bg(cx.theme().colors().surface_background)
 604                    .child(MarkdownElement::new(markdown.clone(), style)),
 605            );
 606        }
 607
 608        // Render env var checkbox if the extension specifies an env var
 609        if let Some(auth_config) = &self.auth_config {
 610            if let Some(env_var_name) = &auth_config.env_var {
 611                let env_var_name = env_var_name.clone();
 612                let checkbox_label =
 613                    format!("Read API key from {} environment variable", env_var_name);
 614
 615                content = content.child(
 616                    h_flex()
 617                        .gap_2()
 618                        .child(
 619                            ui::Checkbox::new("env-var-permission", env_var_allowed.into())
 620                                .on_click(cx.listener(|this, _, _window, cx| {
 621                                    this.toggle_env_var_permission(cx);
 622                                })),
 623                        )
 624                        .child(Label::new(checkbox_label).size(LabelSize::Small)),
 625                );
 626
 627                // Show status if env var is allowed
 628                if env_var_allowed {
 629                    if api_key_from_env {
 630                        content = content.child(
 631                            h_flex()
 632                                .gap_2()
 633                                .child(
 634                                    ui::Icon::new(ui::IconName::Check)
 635                                        .color(Color::Success)
 636                                        .size(ui::IconSize::Small),
 637                                )
 638                                .child(
 639                                    Label::new(format!("API key loaded from {}", env_var_name))
 640                                        .color(Color::Success),
 641                                ),
 642                        );
 643                        return content.into_any_element();
 644                    } else {
 645                        content = content.child(
 646                            h_flex()
 647                                .gap_2()
 648                                .child(
 649                                    ui::Icon::new(ui::IconName::Warning)
 650                                        .color(Color::Warning)
 651                                        .size(ui::IconSize::Small),
 652                                )
 653                                .child(
 654                                    Label::new(format!(
 655                                        "{} is not set or empty. You can set it and restart Zed, or enter an API key below.",
 656                                        env_var_name
 657                                    ))
 658                                    .color(Color::Warning)
 659                                    .size(LabelSize::Small),
 660                                ),
 661                        );
 662                    }
 663                }
 664            }
 665        }
 666
 667        // Render API key section
 668        if is_authenticated && !api_key_from_env {
 669            content = content.child(
 670                v_flex()
 671                    .gap_2()
 672                    .child(
 673                        h_flex()
 674                            .gap_2()
 675                            .child(
 676                                ui::Icon::new(ui::IconName::Check)
 677                                    .color(Color::Success)
 678                                    .size(ui::IconSize::Small),
 679                            )
 680                            .child(Label::new("API key configured").color(Color::Success)),
 681                    )
 682                    .child(
 683                        ui::Button::new("reset-api-key", "Reset API Key")
 684                            .style(ui::ButtonStyle::Subtle)
 685                            .on_click(cx.listener(|this, _, window, cx| {
 686                                this.reset_api_key(window, cx);
 687                            })),
 688                    ),
 689            );
 690        } else if !api_key_from_env {
 691            let credential_label = self
 692                .auth_config
 693                .as_ref()
 694                .and_then(|c| c.credential_label.clone())
 695                .unwrap_or_else(|| "API Key".to_string());
 696
 697            content = content.child(
 698                v_flex()
 699                    .gap_2()
 700                    .on_action(cx.listener(Self::save_api_key))
 701                    .child(
 702                        Label::new(credential_label)
 703                            .size(LabelSize::Small)
 704                            .color(Color::Muted),
 705                    )
 706                    .child(self.api_key_editor.clone())
 707                    .child(
 708                        Label::new("Enter your API key and press Enter to save")
 709                            .size(LabelSize::Small)
 710                            .color(Color::Muted),
 711                    ),
 712            );
 713        }
 714
 715        content.into_any_element()
 716    }
 717}
 718
 719impl Focusable for ExtensionProviderConfigurationView {
 720    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 721        self.api_key_editor.focus_handle(cx)
 722    }
 723}
 724
 725fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
 726    let theme_settings = ThemeSettings::get_global(cx);
 727    let colors = cx.theme().colors();
 728    let mut text_style = window.text_style();
 729    text_style.refine(&TextStyleRefinement {
 730        font_family: Some(theme_settings.ui_font.family.clone()),
 731        font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
 732        font_features: Some(theme_settings.ui_font.features.clone()),
 733        color: Some(colors.text),
 734        ..Default::default()
 735    });
 736
 737    MarkdownStyle {
 738        base_text_style: text_style,
 739        selection_background_color: colors.element_selection_background,
 740        inline_code: TextStyleRefinement {
 741            background_color: Some(colors.editor_background),
 742            ..Default::default()
 743        },
 744        link: TextStyleRefinement {
 745            color: Some(colors.text_accent),
 746            underline: Some(UnderlineStyle {
 747                color: Some(colors.text_accent.opacity(0.5)),
 748                thickness: px(1.),
 749                ..Default::default()
 750            }),
 751            ..Default::default()
 752        },
 753        syntax: cx.theme().syntax().clone(),
 754        ..Default::default()
 755    }
 756}
 757
/// An extension-based language model.
pub struct ExtensionLanguageModel {
    /// The WASM extension that serves requests for this model.
    extension: WasmExtension,
    /// Model metadata: id, name, capabilities and token limits.
    model_info: LlmModelInfo,
    /// Id of the owning provider, returned from `provider_id()`.
    provider_id: LanguageModelProviderId,
    /// Display name of the owning provider, returned from `provider_name()`.
    provider_name: LanguageModelProviderName,
    /// Provider metadata; its `id` is passed to extension calls.
    provider_info: LlmProviderInfo,
}
 766
 767impl LanguageModel for ExtensionLanguageModel {
 768    fn id(&self) -> LanguageModelId {
 769        LanguageModelId::from(self.model_info.id.clone())
 770    }
 771
 772    fn name(&self) -> LanguageModelName {
 773        LanguageModelName::from(self.model_info.name.clone())
 774    }
 775
 776    fn provider_id(&self) -> LanguageModelProviderId {
 777        self.provider_id.clone()
 778    }
 779
 780    fn provider_name(&self) -> LanguageModelProviderName {
 781        self.provider_name.clone()
 782    }
 783
 784    fn telemetry_id(&self) -> String {
 785        format!("extension-{}", self.model_info.id)
 786    }
 787
 788    fn supports_images(&self) -> bool {
 789        self.model_info.capabilities.supports_images
 790    }
 791
 792    fn supports_tools(&self) -> bool {
 793        self.model_info.capabilities.supports_tools
 794    }
 795
 796    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 797        match choice {
 798            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
 799            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
 800            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
 801        }
 802    }
 803
 804    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 805        match self.model_info.capabilities.tool_input_format {
 806            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
 807            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
 808        }
 809    }
 810
 811    fn max_token_count(&self) -> u64 {
 812        self.model_info.max_token_count
 813    }
 814
 815    fn max_output_tokens(&self) -> Option<u64> {
 816        self.model_info.max_output_tokens
 817    }
 818
 819    fn count_tokens(
 820        &self,
 821        request: LanguageModelRequest,
 822        cx: &App,
 823    ) -> BoxFuture<'static, Result<u64>> {
 824        let extension = self.extension.clone();
 825        let provider_id = self.provider_info.id.clone();
 826        let model_id = self.model_info.id.clone();
 827
 828        let wit_request = convert_request_to_wit(request);
 829
 830        cx.background_spawn(async move {
 831            extension
 832                .call({
 833                    let provider_id = provider_id.clone();
 834                    let model_id = model_id.clone();
 835                    let wit_request = wit_request.clone();
 836                    |ext, store| {
 837                        async move {
 838                            let count = ext
 839                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
 840                                .await?
 841                                .map_err(|e| anyhow!("{}", e))?;
 842                            Ok(count)
 843                        }
 844                        .boxed()
 845                    }
 846                })
 847                .await?
 848        })
 849        .boxed()
 850    }
 851
 852    fn stream_completion(
 853        &self,
 854        request: LanguageModelRequest,
 855        _cx: &AsyncApp,
 856    ) -> BoxFuture<
 857        'static,
 858        Result<
 859            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 860            LanguageModelCompletionError,
 861        >,
 862    > {
 863        let extension = self.extension.clone();
 864        let provider_id = self.provider_info.id.clone();
 865        let model_id = self.model_info.id.clone();
 866
 867        let wit_request = convert_request_to_wit(request);
 868
 869        async move {
 870            // Start the stream
 871            let stream_id_result = extension
 872                .call({
 873                    let provider_id = provider_id.clone();
 874                    let model_id = model_id.clone();
 875                    let wit_request = wit_request.clone();
 876                    |ext, store| {
 877                        async move {
 878                            let id = ext
 879                                .call_llm_stream_completion_start(
 880                                    store,
 881                                    &provider_id,
 882                                    &model_id,
 883                                    &wit_request,
 884                                )
 885                                .await?
 886                                .map_err(|e| anyhow!("{}", e))?;
 887                            Ok(id)
 888                        }
 889                        .boxed()
 890                    }
 891                })
 892                .await;
 893
 894            let stream_id = stream_id_result
 895                .map_err(LanguageModelCompletionError::Other)?
 896                .map_err(LanguageModelCompletionError::Other)?;
 897
 898            // Create a stream that polls for events
 899            let stream = futures::stream::unfold(
 900                (extension.clone(), stream_id, false),
 901                move |(extension, stream_id, done)| async move {
 902                    if done {
 903                        return None;
 904                    }
 905
 906                    let result = extension
 907                        .call({
 908                            let stream_id = stream_id.clone();
 909                            |ext, store| {
 910                                async move {
 911                                    let event = ext
 912                                        .call_llm_stream_completion_next(store, &stream_id)
 913                                        .await?
 914                                        .map_err(|e| anyhow!("{}", e))?;
 915                                    Ok(event)
 916                                }
 917                                .boxed()
 918                            }
 919                        })
 920                        .await
 921                        .and_then(|inner| inner);
 922
 923                    match result {
 924                        Ok(Some(event)) => {
 925                            let converted = convert_completion_event(event);
 926                            let is_done =
 927                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
 928                            Some((converted, (extension, stream_id, is_done)))
 929                        }
 930                        Ok(None) => {
 931                            // Stream complete, close it
 932                            let _ = extension
 933                                .call({
 934                                    let stream_id = stream_id.clone();
 935                                    |ext, store| {
 936                                        async move {
 937                                            ext.call_llm_stream_completion_close(store, &stream_id)
 938                                                .await?;
 939                                            Ok::<(), anyhow::Error>(())
 940                                        }
 941                                        .boxed()
 942                                    }
 943                                })
 944                                .await;
 945                            None
 946                        }
 947                        Err(e) => Some((
 948                            Err(LanguageModelCompletionError::Other(e)),
 949                            (extension, stream_id, true),
 950                        )),
 951                    }
 952                },
 953            );
 954
 955            Ok(stream.boxed())
 956        }
 957        .boxed()
 958    }
 959
 960    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 961        // Extensions can implement this via llm_cache_configuration
 962        None
 963    }
 964}
 965
 966fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
 967    use language_model::{MessageContent, Role};
 968
 969    let messages: Vec<LlmRequestMessage> = request
 970        .messages
 971        .into_iter()
 972        .map(|msg| {
 973            let role = match msg.role {
 974                Role::User => LlmMessageRole::User,
 975                Role::Assistant => LlmMessageRole::Assistant,
 976                Role::System => LlmMessageRole::System,
 977            };
 978
 979            let content: Vec<LlmMessageContent> = msg
 980                .content
 981                .into_iter()
 982                .map(|c| match c {
 983                    MessageContent::Text(text) => LlmMessageContent::Text(text),
 984                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
 985                        source: image.source.to_string(),
 986                        width: Some(image.size.width.0 as u32),
 987                        height: Some(image.size.height.0 as u32),
 988                    }),
 989                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
 990                        id: tool_use.id.to_string(),
 991                        name: tool_use.name.to_string(),
 992                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
 993                        thought_signature: tool_use.thought_signature,
 994                    }),
 995                    MessageContent::ToolResult(tool_result) => {
 996                        let content = match tool_result.content {
 997                            language_model::LanguageModelToolResultContent::Text(text) => {
 998                                LlmToolResultContent::Text(text.to_string())
 999                            }
1000                            language_model::LanguageModelToolResultContent::Image(image) => {
1001                                LlmToolResultContent::Image(LlmImageData {
1002                                    source: image.source.to_string(),
1003                                    width: Some(image.size.width.0 as u32),
1004                                    height: Some(image.size.height.0 as u32),
1005                                })
1006                            }
1007                        };
1008                        LlmMessageContent::ToolResult(LlmToolResult {
1009                            tool_use_id: tool_result.tool_use_id.to_string(),
1010                            tool_name: tool_result.tool_name.to_string(),
1011                            is_error: tool_result.is_error,
1012                            content,
1013                        })
1014                    }
1015                    MessageContent::Thinking { text, signature } => {
1016                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1017                    }
1018                    MessageContent::RedactedThinking(data) => {
1019                        LlmMessageContent::RedactedThinking(data)
1020                    }
1021                })
1022                .collect();
1023
1024            LlmRequestMessage {
1025                role,
1026                content,
1027                cache: msg.cache,
1028            }
1029        })
1030        .collect();
1031
1032    let tools: Vec<LlmToolDefinition> = request
1033        .tools
1034        .into_iter()
1035        .map(|tool| LlmToolDefinition {
1036            name: tool.name,
1037            description: tool.description,
1038            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1039        })
1040        .collect();
1041
1042    let tool_choice = request.tool_choice.map(|tc| match tc {
1043        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1044        LanguageModelToolChoice::Any => LlmToolChoice::Any,
1045        LanguageModelToolChoice::None => LlmToolChoice::None,
1046    });
1047
1048    LlmCompletionRequest {
1049        messages,
1050        tools,
1051        tool_choice,
1052        stop_sequences: request.stop,
1053        temperature: request.temperature,
1054        thinking_allowed: false,
1055        max_tokens: None,
1056    }
1057}
1058
1059fn convert_completion_event(
1060    event: LlmCompletionEvent,
1061) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1062    match event {
1063        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1064            message_id: String::new(),
1065        }),
1066        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1067        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1068            text: thinking.text,
1069            signature: thinking.signature,
1070        }),
1071        LlmCompletionEvent::RedactedThinking(data) => {
1072            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1073        }
1074        LlmCompletionEvent::ToolUse(tool_use) => {
1075            let raw_input = tool_use.input.clone();
1076            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1077            Ok(LanguageModelCompletionEvent::ToolUse(
1078                LanguageModelToolUse {
1079                    id: LanguageModelToolUseId::from(tool_use.id),
1080                    name: tool_use.name.into(),
1081                    raw_input,
1082                    input,
1083                    is_input_complete: true,
1084                    thought_signature: tool_use.thought_signature,
1085                },
1086            ))
1087        }
1088        LlmCompletionEvent::ToolUseJsonParseError(error) => {
1089            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1090                id: LanguageModelToolUseId::from(error.id),
1091                tool_name: error.tool_name.into(),
1092                raw_input: error.raw_input.into(),
1093                json_parse_error: error.error,
1094            })
1095        }
1096        LlmCompletionEvent::Stop(reason) => {
1097            let stop_reason = match reason {
1098                LlmStopReason::EndTurn => StopReason::EndTurn,
1099                LlmStopReason::MaxTokens => StopReason::MaxTokens,
1100                LlmStopReason::ToolUse => StopReason::ToolUse,
1101                LlmStopReason::Refusal => StopReason::Refusal,
1102            };
1103            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1104        }
1105        LlmCompletionEvent::Usage(usage) => {
1106            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1107                input_tokens: usage.input_tokens,
1108                output_tokens: usage.output_tokens,
1109                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1110                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1111            }))
1112        }
1113        LlmCompletionEvent::ReasoningDetails(json) => {
1114            Ok(LanguageModelCompletionEvent::ReasoningDetails(
1115                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1116            ))
1117        }
1118    }
1119}