llm_provider.rs

   1use crate::ExtensionSettings;
   2use crate::wasm_host::WasmExtension;
   3
   4use crate::wasm_host::wit::{
   5    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
   6    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
   7    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
   8    LlmToolUse,
   9};
  10use anyhow::{Result, anyhow};
  11use credentials_provider::CredentialsProvider;
  12use editor::Editor;
  13use extension::LanguageModelAuthConfig;
  14use futures::future::BoxFuture;
  15use futures::stream::BoxStream;
  16use futures::{FutureExt, StreamExt};
  17use gpui::Focusable;
  18use gpui::{
  19    AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
  20    TextStyleRefinement, UnderlineStyle, Window, px,
  21};
  22use language_model::tool_schema::LanguageModelToolSchemaFormat;
  23use language_model::{
  24    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
  25    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  27    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  28    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
  29};
  30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
  31use settings::Settings;
  32use std::sync::Arc;
  33use theme::ThemeSettings;
  34use ui::{Label, LabelSize, prelude::*};
  35use util::ResultExt as _;
  36
/// An extension-based language model provider.
///
/// Bundles the WASM extension that implements the provider with the metadata
/// it reported and a shared state entity tracking authentication and models.
pub struct ExtensionLanguageModelProvider {
    /// Handle to the WASM extension that implements this provider.
    pub extension: WasmExtension,
    /// Provider metadata; its `id` and `name` back the provider ID and display name.
    pub provider_info: LlmProviderInfo,
    /// Optional custom icon path, returned as-is from `icon_path()`.
    icon_path: Option<SharedString>,
    /// Optional authentication configuration (for example the environment
    /// variable that may hold the API key).
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared state entity holding the authentication flags and the model list.
    state: Entity<ExtensionLlmProviderState>,
}
  45
/// Mutable state shared between an extension provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether the provider is currently considered authenticated.
    is_authenticated: bool,
    /// Models offered by the provider, as supplied at construction time.
    available_models: Vec<LlmModelInfo>,
    /// Whether this provider is listed in the `allowed_env_var_providers` setting.
    env_var_allowed: bool,
    /// Whether the configured environment variable is allowed, set, and non-empty.
    api_key_from_env: bool,
}
  52
// The state emits unit events; `cx.subscribe` calls on this entity listen for them.
impl EventEmitter<()> for ExtensionLlmProviderState {}
  54
  55impl ExtensionLanguageModelProvider {
  56    pub fn new(
  57        extension: WasmExtension,
  58        provider_info: LlmProviderInfo,
  59        models: Vec<LlmModelInfo>,
  60        is_authenticated: bool,
  61        icon_path: Option<SharedString>,
  62        auth_config: Option<LanguageModelAuthConfig>,
  63        cx: &mut App,
  64    ) -> Self {
  65        let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
  66        let env_var_allowed = ExtensionSettings::get_global(cx)
  67            .allowed_env_var_providers
  68            .contains(provider_id_string.as_str());
  69
  70        let (is_authenticated, api_key_from_env) =
  71            if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
  72                let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
  73                if let Ok(value) = std::env::var(env_var_name) {
  74                    if !value.is_empty() {
  75                        (true, true)
  76                    } else {
  77                        (is_authenticated, false)
  78                    }
  79                } else {
  80                    (is_authenticated, false)
  81                }
  82            } else {
  83                (is_authenticated, false)
  84            };
  85
  86        let state = cx.new(|_| ExtensionLlmProviderState {
  87            is_authenticated,
  88            available_models: models,
  89            env_var_allowed,
  90            api_key_from_env,
  91        });
  92
  93        Self {
  94            extension,
  95            provider_info,
  96            icon_path,
  97            auth_config,
  98            state,
  99        }
 100    }
 101
 102    fn provider_id_string(&self) -> String {
 103        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
 104    }
 105
 106    /// The credential key used for storing the API key in the system keychain.
 107    fn credential_key(&self) -> String {
 108        format!("extension-llm-{}", self.provider_id_string())
 109    }
 110}
 111
 112impl LanguageModelProvider for ExtensionLanguageModelProvider {
 113    fn id(&self) -> LanguageModelProviderId {
 114        LanguageModelProviderId::from(self.provider_id_string())
 115    }
 116
 117    fn name(&self) -> LanguageModelProviderName {
 118        LanguageModelProviderName::from(self.provider_info.name.clone())
 119    }
 120
    /// Returns the built-in icon for this provider; always `ZedAssistant`.
    /// A custom icon, if any, is exposed separately through `icon_path`.
    fn icon(&self) -> ui::IconName {
        ui::IconName::ZedAssistant
    }
 124
 125    fn icon_path(&self) -> Option<SharedString> {
 126        self.icon_path.clone()
 127    }
 128
 129    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 130        let state = self.state.read(cx);
 131        state
 132            .available_models
 133            .iter()
 134            .find(|m| m.is_default)
 135            .or_else(|| state.available_models.first())
 136            .map(|model_info| {
 137                Arc::new(ExtensionLanguageModel {
 138                    extension: self.extension.clone(),
 139                    model_info: model_info.clone(),
 140                    provider_id: self.id(),
 141                    provider_name: self.name(),
 142                    provider_info: self.provider_info.clone(),
 143                }) as Arc<dyn LanguageModel>
 144            })
 145    }
 146
 147    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 148        let state = self.state.read(cx);
 149        state
 150            .available_models
 151            .iter()
 152            .find(|m| m.is_default_fast)
 153            .map(|model_info| {
 154                Arc::new(ExtensionLanguageModel {
 155                    extension: self.extension.clone(),
 156                    model_info: model_info.clone(),
 157                    provider_id: self.id(),
 158                    provider_name: self.name(),
 159                    provider_info: self.provider_info.clone(),
 160                }) as Arc<dyn LanguageModel>
 161            })
 162    }
 163
 164    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 165        let state = self.state.read(cx);
 166        state
 167            .available_models
 168            .iter()
 169            .map(|model_info| {
 170                Arc::new(ExtensionLanguageModel {
 171                    extension: self.extension.clone(),
 172                    model_info: model_info.clone(),
 173                    provider_id: self.id(),
 174                    provider_name: self.name(),
 175                    provider_info: self.provider_info.clone(),
 176                }) as Arc<dyn LanguageModel>
 177            })
 178            .collect()
 179    }
 180
 181    fn is_authenticated(&self, cx: &App) -> bool {
 182        self.state.read(cx).is_authenticated
 183    }
 184
 185    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 186        let extension = self.extension.clone();
 187        let provider_id = self.provider_info.id.clone();
 188        let state = self.state.clone();
 189
 190        cx.spawn(async move |cx| {
 191            let result = extension
 192                .call(|extension, store| {
 193                    async move {
 194                        extension
 195                            .call_llm_provider_authenticate(store, &provider_id)
 196                            .await
 197                    }
 198                    .boxed()
 199                })
 200                .await;
 201
 202            match result {
 203                Ok(Ok(Ok(()))) => {
 204                    cx.update(|cx| {
 205                        state.update(cx, |state, _| {
 206                            state.is_authenticated = true;
 207                        });
 208                    })?;
 209                    Ok(())
 210                }
 211                Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
 212                Ok(Err(e)) => Err(AuthenticateError::Other(e)),
 213                Err(e) => Err(AuthenticateError::Other(e)),
 214            }
 215        })
 216    }
 217
 218    fn configuration_view(
 219        &self,
 220        _target_agent: ConfigurationViewTargetAgent,
 221        window: &mut Window,
 222        cx: &mut App,
 223    ) -> AnyView {
 224        let credential_key = self.credential_key();
 225        let extension = self.extension.clone();
 226        let extension_provider_id = self.provider_info.id.clone();
 227        let full_provider_id = self.provider_id_string();
 228        let state = self.state.clone();
 229        let auth_config = self.auth_config.clone();
 230
 231        cx.new(|cx| {
 232            ExtensionProviderConfigurationView::new(
 233                credential_key,
 234                extension,
 235                extension_provider_id,
 236                full_provider_id,
 237                auth_config,
 238                state,
 239                window,
 240                cx,
 241            )
 242        })
 243        .into()
 244    }
 245
 246    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 247        let extension = self.extension.clone();
 248        let provider_id = self.provider_info.id.clone();
 249        let state = self.state.clone();
 250        let credential_key = self.credential_key();
 251
 252        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 253
 254        cx.spawn(async move |cx| {
 255            // Delete from system keychain
 256            credentials_provider
 257                .delete_credentials(&credential_key, cx)
 258                .await
 259                .log_err();
 260
 261            // Call extension's reset_credentials
 262            let result = extension
 263                .call(|extension, store| {
 264                    async move {
 265                        extension
 266                            .call_llm_provider_reset_credentials(store, &provider_id)
 267                            .await
 268                    }
 269                    .boxed()
 270                })
 271                .await;
 272
 273            // Update state
 274            cx.update(|cx| {
 275                state.update(cx, |state, _| {
 276                    state.is_authenticated = false;
 277                });
 278            })?;
 279
 280            match result {
 281                Ok(Ok(Ok(()))) => Ok(()),
 282                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
 283                Ok(Err(e)) => Err(e),
 284                Err(e) => Err(e),
 285            }
 286        })
 287    }
 288}
 289
 290impl LanguageModelProviderState for ExtensionLanguageModelProvider {
 291    type ObservableEntity = ExtensionLlmProviderState;
 292
 293    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 294        Some(self.state.clone())
 295    }
 296
 297    fn subscribe<T: 'static>(
 298        &self,
 299        cx: &mut Context<T>,
 300        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
 301    ) -> Option<Subscription> {
 302        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
 303    }
 304}
 305
/// Configuration view for extension-based LLM providers.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is stored, read, and deleted.
    credential_key: String,
    /// Extension handle used to fetch the settings markdown.
    extension: WasmExtension,
    /// Provider ID as known to the extension (passed to extension calls).
    extension_provider_id: String,
    /// Provider ID as listed in the `allowed_env_var_providers` setting.
    full_provider_id: String,
    /// Optional auth configuration (environment variable name, credential label).
    auth_config: Option<LanguageModelAuthConfig>,
    /// State entity shared with the provider.
    state: Entity<ExtensionLlmProviderState>,
    /// Rendered settings text from the extension, once loaded and if any.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor used to enter the API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings-markdown request has finished.
    loading_settings: bool,
    /// True until the initial credential lookup has finished.
    loading_credentials: bool,
    /// Keeps the state subscription alive for the lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
 320
 321impl ExtensionProviderConfigurationView {
 322    fn new(
 323        credential_key: String,
 324        extension: WasmExtension,
 325        extension_provider_id: String,
 326        full_provider_id: String,
 327        auth_config: Option<LanguageModelAuthConfig>,
 328        state: Entity<ExtensionLlmProviderState>,
 329        window: &mut Window,
 330        cx: &mut Context<Self>,
 331    ) -> Self {
 332        // Subscribe to state changes
 333        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
 334            cx.notify();
 335        });
 336
 337        // Create API key editor
 338        let api_key_editor = cx.new(|cx| {
 339            let mut editor = Editor::single_line(window, cx);
 340            editor.set_placeholder_text("Enter API key...", window, cx);
 341            editor
 342        });
 343
 344        let mut this = Self {
 345            credential_key,
 346            extension,
 347            extension_provider_id,
 348            full_provider_id,
 349            auth_config,
 350            state,
 351            settings_markdown: None,
 352            api_key_editor,
 353            loading_settings: true,
 354            loading_credentials: true,
 355            _subscriptions: vec![state_subscription],
 356        };
 357
 358        // Load settings text from extension
 359        this.load_settings_text(cx);
 360
 361        // Load existing credentials
 362        this.load_credentials(cx);
 363
 364        this
 365    }
 366
 367    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
 368        let extension = self.extension.clone();
 369        let provider_id = self.extension_provider_id.clone();
 370
 371        cx.spawn(async move |this, cx| {
 372            let result = extension
 373                .call({
 374                    let provider_id = provider_id.clone();
 375                    |ext, store| {
 376                        async move {
 377                            ext.call_llm_provider_settings_markdown(store, &provider_id)
 378                                .await
 379                        }
 380                        .boxed()
 381                    }
 382                })
 383                .await;
 384
 385            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
 386
 387            this.update(cx, |this, cx| {
 388                this.loading_settings = false;
 389                if let Some(text) = settings_text {
 390                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
 391                    this.settings_markdown = Some(markdown);
 392                }
 393                cx.notify();
 394            })
 395            .log_err();
 396        })
 397        .detach();
 398    }
 399
 400    fn load_credentials(&mut self, cx: &mut Context<Self>) {
 401        let credential_key = self.credential_key.clone();
 402        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 403        let state = self.state.clone();
 404
 405        // Check if we should use env var (already set in state during provider construction)
 406        let api_key_from_env = self.state.read(cx).api_key_from_env;
 407
 408        cx.spawn(async move |this, cx| {
 409            // If using env var, we're already authenticated
 410            if api_key_from_env {
 411                this.update(cx, |this, cx| {
 412                    this.loading_credentials = false;
 413                    cx.notify();
 414                })
 415                .log_err();
 416                return;
 417            }
 418
 419            let credentials = credentials_provider
 420                .read_credentials(&credential_key, cx)
 421                .await
 422                .log_err()
 423                .flatten();
 424
 425            let has_credentials = credentials.is_some();
 426
 427            // Update authentication state based on stored credentials
 428            let _ = cx.update(|cx| {
 429                state.update(cx, |state, cx| {
 430                    state.is_authenticated = has_credentials;
 431                    cx.notify();
 432                });
 433            });
 434
 435            this.update(cx, |this, cx| {
 436                this.loading_credentials = false;
 437                cx.notify();
 438            })
 439            .log_err();
 440        })
 441        .detach();
 442    }
 443
 444    fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
 445        let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
 446        let env_var_name = match &self.auth_config {
 447            Some(config) => config.env_var.clone(),
 448            None => return,
 449        };
 450
 451        let state = self.state.clone();
 452        let currently_allowed = self.state.read(cx).env_var_allowed;
 453
 454        // Update settings file
 455        settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
 456            let providers = settings
 457                .extension
 458                .allowed_env_var_providers
 459                .get_or_insert_with(Vec::new);
 460
 461            if currently_allowed {
 462                providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
 463            } else {
 464                if !providers
 465                    .iter()
 466                    .any(|id| id.as_ref() == full_provider_id.as_ref())
 467                {
 468                    providers.push(full_provider_id.clone());
 469                }
 470            }
 471        });
 472
 473        // Update local state
 474        let new_allowed = !currently_allowed;
 475        let new_from_env = if new_allowed {
 476            if let Some(var_name) = &env_var_name {
 477                if let Ok(value) = std::env::var(var_name) {
 478                    !value.is_empty()
 479                } else {
 480                    false
 481                }
 482            } else {
 483                false
 484            }
 485        } else {
 486            false
 487        };
 488
 489        state.update(cx, |state, cx| {
 490            state.env_var_allowed = new_allowed;
 491            state.api_key_from_env = new_from_env;
 492            if new_from_env {
 493                state.is_authenticated = true;
 494            }
 495            cx.notify();
 496        });
 497
 498        // If env var is being disabled, reload credentials from keychain
 499        if !new_allowed {
 500            self.reload_keychain_credentials(cx);
 501        }
 502
 503        cx.notify();
 504    }
 505
 506    fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
 507        let credential_key = self.credential_key.clone();
 508        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 509        let state = self.state.clone();
 510
 511        cx.spawn(async move |_this, cx| {
 512            let credentials = credentials_provider
 513                .read_credentials(&credential_key, cx)
 514                .await
 515                .log_err()
 516                .flatten();
 517
 518            let has_credentials = credentials.is_some();
 519
 520            let _ = cx.update(|cx| {
 521                state.update(cx, |state, cx| {
 522                    state.is_authenticated = has_credentials;
 523                    cx.notify();
 524                });
 525            });
 526        })
 527        .detach();
 528    }
 529
 530    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 531        let api_key = self.api_key_editor.read(cx).text(cx);
 532        if api_key.is_empty() {
 533            return;
 534        }
 535
 536        // Clear the editor
 537        self.api_key_editor
 538            .update(cx, |editor, cx| editor.set_text("", window, cx));
 539
 540        let credential_key = self.credential_key.clone();
 541        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 542        let state = self.state.clone();
 543
 544        cx.spawn(async move |_this, cx| {
 545            // Store in system keychain
 546            credentials_provider
 547                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
 548                .await
 549                .log_err();
 550
 551            // Update state to authenticated
 552            let _ = cx.update(|cx| {
 553                state.update(cx, |state, cx| {
 554                    state.is_authenticated = true;
 555                    cx.notify();
 556                });
 557            });
 558        })
 559        .detach();
 560    }
 561
 562    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 563        // Clear the editor
 564        self.api_key_editor
 565            .update(cx, |editor, cx| editor.set_text("", window, cx));
 566
 567        let credential_key = self.credential_key.clone();
 568        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 569        let state = self.state.clone();
 570
 571        cx.spawn(async move |_this, cx| {
 572            // Delete from system keychain
 573            credentials_provider
 574                .delete_credentials(&credential_key, cx)
 575                .await
 576                .log_err();
 577
 578            // Update state to unauthenticated
 579            let _ = cx.update(|cx| {
 580                state.update(cx, |state, cx| {
 581                    state.is_authenticated = false;
 582                    cx.notify();
 583                });
 584            });
 585        })
 586        .detach();
 587    }
 588
 589    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
 590        self.state.read(cx).is_authenticated
 591    }
 592}
 593
 594impl gpui::Render for ExtensionProviderConfigurationView {
 595    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 596        let is_loading = self.loading_settings || self.loading_credentials;
 597        let is_authenticated = self.is_authenticated(cx);
 598        let env_var_allowed = self.state.read(cx).env_var_allowed;
 599        let api_key_from_env = self.state.read(cx).api_key_from_env;
 600
 601        if is_loading {
 602            return v_flex()
 603                .gap_2()
 604                .child(Label::new("Loading...").color(Color::Muted))
 605                .into_any_element();
 606        }
 607
 608        let mut content = v_flex().gap_4().size_full();
 609
 610        // Render settings markdown if available
 611        if let Some(markdown) = &self.settings_markdown {
 612            let style = settings_markdown_style(_window, cx);
 613            content = content.child(
 614                div()
 615                    .p_2()
 616                    .rounded_md()
 617                    .bg(cx.theme().colors().surface_background)
 618                    .child(MarkdownElement::new(markdown.clone(), style)),
 619            );
 620        }
 621
 622        // Render env var checkbox if the extension specifies an env var
 623        if let Some(auth_config) = &self.auth_config {
 624            if let Some(env_var_name) = &auth_config.env_var {
 625                let env_var_name = env_var_name.clone();
 626                let checkbox_label =
 627                    format!("Read API key from {} environment variable", env_var_name);
 628
 629                content = content.child(
 630                    h_flex()
 631                        .gap_2()
 632                        .child(
 633                            ui::Checkbox::new("env-var-permission", env_var_allowed.into())
 634                                .on_click(cx.listener(|this, _, _window, cx| {
 635                                    this.toggle_env_var_permission(cx);
 636                                })),
 637                        )
 638                        .child(Label::new(checkbox_label).size(LabelSize::Small)),
 639                );
 640
 641                // Show status if env var is allowed
 642                if env_var_allowed {
 643                    if api_key_from_env {
 644                        content = content.child(
 645                            h_flex()
 646                                .gap_2()
 647                                .child(
 648                                    ui::Icon::new(ui::IconName::Check)
 649                                        .color(Color::Success)
 650                                        .size(ui::IconSize::Small),
 651                                )
 652                                .child(
 653                                    Label::new(format!("API key loaded from {}", env_var_name))
 654                                        .color(Color::Success),
 655                                ),
 656                        );
 657                        return content.into_any_element();
 658                    } else {
 659                        content = content.child(
 660                            h_flex()
 661                                .gap_2()
 662                                .child(
 663                                    ui::Icon::new(ui::IconName::Warning)
 664                                        .color(Color::Warning)
 665                                        .size(ui::IconSize::Small),
 666                                )
 667                                .child(
 668                                    Label::new(format!(
 669                                        "{} is not set or empty. You can set it and restart Zed, or enter an API key below.",
 670                                        env_var_name
 671                                    ))
 672                                    .color(Color::Warning)
 673                                    .size(LabelSize::Small),
 674                                ),
 675                        );
 676                    }
 677                }
 678            }
 679        }
 680
 681        // Render API key section
 682        if is_authenticated && !api_key_from_env {
 683            content = content.child(
 684                v_flex()
 685                    .gap_2()
 686                    .child(
 687                        h_flex()
 688                            .gap_2()
 689                            .child(
 690                                ui::Icon::new(ui::IconName::Check)
 691                                    .color(Color::Success)
 692                                    .size(ui::IconSize::Small),
 693                            )
 694                            .child(Label::new("API key configured").color(Color::Success)),
 695                    )
 696                    .child(
 697                        ui::Button::new("reset-api-key", "Reset API Key")
 698                            .style(ui::ButtonStyle::Subtle)
 699                            .on_click(cx.listener(|this, _, window, cx| {
 700                                this.reset_api_key(window, cx);
 701                            })),
 702                    ),
 703            );
 704        } else if !api_key_from_env {
 705            let credential_label = self
 706                .auth_config
 707                .as_ref()
 708                .and_then(|c| c.credential_label.clone())
 709                .unwrap_or_else(|| "API Key".to_string());
 710
 711            content = content.child(
 712                v_flex()
 713                    .gap_2()
 714                    .on_action(cx.listener(Self::save_api_key))
 715                    .child(
 716                        Label::new(credential_label)
 717                            .size(LabelSize::Small)
 718                            .color(Color::Muted),
 719                    )
 720                    .child(self.api_key_editor.clone())
 721                    .child(
 722                        Label::new("Enter your API key and press Enter to save")
 723                            .size(LabelSize::Small)
 724                            .color(Color::Muted),
 725                    ),
 726            );
 727        }
 728
 729        content.into_any_element()
 730    }
 731}
 732
 733impl Focusable for ExtensionProviderConfigurationView {
 734    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 735        self.api_key_editor.focus_handle(cx)
 736    }
 737}
 738
 739fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
 740    let theme_settings = ThemeSettings::get_global(cx);
 741    let colors = cx.theme().colors();
 742    let mut text_style = window.text_style();
 743    text_style.refine(&TextStyleRefinement {
 744        font_family: Some(theme_settings.ui_font.family.clone()),
 745        font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
 746        font_features: Some(theme_settings.ui_font.features.clone()),
 747        color: Some(colors.text),
 748        ..Default::default()
 749    });
 750
 751    MarkdownStyle {
 752        base_text_style: text_style,
 753        selection_background_color: colors.element_selection_background,
 754        inline_code: TextStyleRefinement {
 755            background_color: Some(colors.editor_background),
 756            ..Default::default()
 757        },
 758        link: TextStyleRefinement {
 759            color: Some(colors.text_accent),
 760            underline: Some(UnderlineStyle {
 761                color: Some(colors.text_accent.opacity(0.5)),
 762                thickness: px(1.),
 763                ..Default::default()
 764            }),
 765            ..Default::default()
 766        },
 767        syntax: cx.theme().syntax().clone(),
 768        ..Default::default()
 769    }
 770}
 771
/// An extension-based language model.
pub struct ExtensionLanguageModel {
    /// Handle to the WASM extension that serves requests for this model.
    extension: WasmExtension,
    /// Model metadata: ID, name, capabilities and token limits.
    model_info: LlmModelInfo,
    /// Provider ID returned from `provider_id()`.
    provider_id: LanguageModelProviderId,
    /// Provider display name returned from `provider_name()`.
    provider_name: LanguageModelProviderName,
    /// Provider metadata; its `id` is the one passed to extension calls.
    provider_info: LlmProviderInfo,
}
 780
 781impl LanguageModel for ExtensionLanguageModel {
 782    fn id(&self) -> LanguageModelId {
 783        LanguageModelId::from(self.model_info.id.clone())
 784    }
 785
 786    fn name(&self) -> LanguageModelName {
 787        LanguageModelName::from(self.model_info.name.clone())
 788    }
 789
    /// Returns a copy of the provider ID this model was created with.
    fn provider_id(&self) -> LanguageModelProviderId {
        self.provider_id.clone()
    }
 793
    /// Returns a copy of the provider display name this model was created with.
    fn provider_name(&self) -> LanguageModelProviderName {
        self.provider_name.clone()
    }
 797
 798    fn telemetry_id(&self) -> String {
 799        format!("extension-{}", self.model_info.id)
 800    }
 801
 802    fn supports_images(&self) -> bool {
 803        self.model_info.capabilities.supports_images
 804    }
 805
 806    fn supports_tools(&self) -> bool {
 807        self.model_info.capabilities.supports_tools
 808    }
 809
 810    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 811        match choice {
 812            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
 813            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
 814            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
 815        }
 816    }
 817
 818    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 819        match self.model_info.capabilities.tool_input_format {
 820            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
 821            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
 822        }
 823    }
 824
 825    fn max_token_count(&self) -> u64 {
 826        self.model_info.max_token_count
 827    }
 828
 829    fn max_output_tokens(&self) -> Option<u64> {
 830        self.model_info.max_output_tokens
 831    }
 832
 833    fn count_tokens(
 834        &self,
 835        request: LanguageModelRequest,
 836        cx: &App,
 837    ) -> BoxFuture<'static, Result<u64>> {
 838        let extension = self.extension.clone();
 839        let provider_id = self.provider_info.id.clone();
 840        let model_id = self.model_info.id.clone();
 841
 842        let wit_request = convert_request_to_wit(request);
 843
 844        cx.background_spawn(async move {
 845            extension
 846                .call({
 847                    let provider_id = provider_id.clone();
 848                    let model_id = model_id.clone();
 849                    let wit_request = wit_request.clone();
 850                    |ext, store| {
 851                        async move {
 852                            let count = ext
 853                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
 854                                .await?
 855                                .map_err(|e| anyhow!("{}", e))?;
 856                            Ok(count)
 857                        }
 858                        .boxed()
 859                    }
 860                })
 861                .await?
 862        })
 863        .boxed()
 864    }
 865
 866    fn stream_completion(
 867        &self,
 868        request: LanguageModelRequest,
 869        _cx: &AsyncApp,
 870    ) -> BoxFuture<
 871        'static,
 872        Result<
 873            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 874            LanguageModelCompletionError,
 875        >,
 876    > {
 877        let extension = self.extension.clone();
 878        let provider_id = self.provider_info.id.clone();
 879        let model_id = self.model_info.id.clone();
 880
 881        let wit_request = convert_request_to_wit(request);
 882
 883        async move {
 884            // Start the stream
 885            let stream_id_result = extension
 886                .call({
 887                    let provider_id = provider_id.clone();
 888                    let model_id = model_id.clone();
 889                    let wit_request = wit_request.clone();
 890                    |ext, store| {
 891                        async move {
 892                            let id = ext
 893                                .call_llm_stream_completion_start(
 894                                    store,
 895                                    &provider_id,
 896                                    &model_id,
 897                                    &wit_request,
 898                                )
 899                                .await?
 900                                .map_err(|e| anyhow!("{}", e))?;
 901                            Ok(id)
 902                        }
 903                        .boxed()
 904                    }
 905                })
 906                .await;
 907
 908            let stream_id = stream_id_result
 909                .map_err(LanguageModelCompletionError::Other)?
 910                .map_err(LanguageModelCompletionError::Other)?;
 911
 912            // Create a stream that polls for events
 913            let stream = futures::stream::unfold(
 914                (extension.clone(), stream_id, false),
 915                move |(extension, stream_id, done)| async move {
 916                    if done {
 917                        return None;
 918                    }
 919
 920                    let result = extension
 921                        .call({
 922                            let stream_id = stream_id.clone();
 923                            |ext, store| {
 924                                async move {
 925                                    let event = ext
 926                                        .call_llm_stream_completion_next(store, &stream_id)
 927                                        .await?
 928                                        .map_err(|e| anyhow!("{}", e))?;
 929                                    Ok(event)
 930                                }
 931                                .boxed()
 932                            }
 933                        })
 934                        .await
 935                        .and_then(|inner| inner);
 936
 937                    match result {
 938                        Ok(Some(event)) => {
 939                            let converted = convert_completion_event(event);
 940                            let is_done =
 941                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
 942                            Some((converted, (extension, stream_id, is_done)))
 943                        }
 944                        Ok(None) => {
 945                            // Stream complete, close it
 946                            let _ = extension
 947                                .call({
 948                                    let stream_id = stream_id.clone();
 949                                    |ext, store| {
 950                                        async move {
 951                                            ext.call_llm_stream_completion_close(store, &stream_id)
 952                                                .await?;
 953                                            Ok::<(), anyhow::Error>(())
 954                                        }
 955                                        .boxed()
 956                                    }
 957                                })
 958                                .await;
 959                            None
 960                        }
 961                        Err(e) => Some((
 962                            Err(LanguageModelCompletionError::Other(e)),
 963                            (extension, stream_id, true),
 964                        )),
 965                    }
 966                },
 967            );
 968
 969            Ok(stream.boxed())
 970        }
 971        .boxed()
 972    }
 973
 974    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 975        // Extensions can implement this via llm_cache_configuration
 976        None
 977    }
 978}
 979
 980fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
 981    use language_model::{MessageContent, Role};
 982
 983    let messages: Vec<LlmRequestMessage> = request
 984        .messages
 985        .into_iter()
 986        .map(|msg| {
 987            let role = match msg.role {
 988                Role::User => LlmMessageRole::User,
 989                Role::Assistant => LlmMessageRole::Assistant,
 990                Role::System => LlmMessageRole::System,
 991            };
 992
 993            let content: Vec<LlmMessageContent> = msg
 994                .content
 995                .into_iter()
 996                .map(|c| match c {
 997                    MessageContent::Text(text) => LlmMessageContent::Text(text),
 998                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
 999                        source: image.source.to_string(),
1000                        width: Some(image.size.width.0 as u32),
1001                        height: Some(image.size.height.0 as u32),
1002                    }),
1003                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1004                        id: tool_use.id.to_string(),
1005                        name: tool_use.name.to_string(),
1006                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1007                        thought_signature: tool_use.thought_signature,
1008                    }),
1009                    MessageContent::ToolResult(tool_result) => {
1010                        let content = match tool_result.content {
1011                            language_model::LanguageModelToolResultContent::Text(text) => {
1012                                LlmToolResultContent::Text(text.to_string())
1013                            }
1014                            language_model::LanguageModelToolResultContent::Image(image) => {
1015                                LlmToolResultContent::Image(LlmImageData {
1016                                    source: image.source.to_string(),
1017                                    width: Some(image.size.width.0 as u32),
1018                                    height: Some(image.size.height.0 as u32),
1019                                })
1020                            }
1021                        };
1022                        LlmMessageContent::ToolResult(LlmToolResult {
1023                            tool_use_id: tool_result.tool_use_id.to_string(),
1024                            tool_name: tool_result.tool_name.to_string(),
1025                            is_error: tool_result.is_error,
1026                            content,
1027                        })
1028                    }
1029                    MessageContent::Thinking { text, signature } => {
1030                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1031                    }
1032                    MessageContent::RedactedThinking(data) => {
1033                        LlmMessageContent::RedactedThinking(data)
1034                    }
1035                })
1036                .collect();
1037
1038            LlmRequestMessage {
1039                role,
1040                content,
1041                cache: msg.cache,
1042            }
1043        })
1044        .collect();
1045
1046    let tools: Vec<LlmToolDefinition> = request
1047        .tools
1048        .into_iter()
1049        .map(|tool| LlmToolDefinition {
1050            name: tool.name,
1051            description: tool.description,
1052            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1053        })
1054        .collect();
1055
1056    let tool_choice = request.tool_choice.map(|tc| match tc {
1057        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1058        LanguageModelToolChoice::Any => LlmToolChoice::Any,
1059        LanguageModelToolChoice::None => LlmToolChoice::None,
1060    });
1061
1062    LlmCompletionRequest {
1063        messages,
1064        tools,
1065        tool_choice,
1066        stop_sequences: request.stop,
1067        temperature: request.temperature,
1068        thinking_allowed: false,
1069        max_tokens: None,
1070    }
1071}
1072
1073fn convert_completion_event(
1074    event: LlmCompletionEvent,
1075) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1076    match event {
1077        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1078            message_id: String::new(),
1079        }),
1080        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1081        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1082            text: thinking.text,
1083            signature: thinking.signature,
1084        }),
1085        LlmCompletionEvent::RedactedThinking(data) => {
1086            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1087        }
1088        LlmCompletionEvent::ToolUse(tool_use) => {
1089            let raw_input = tool_use.input.clone();
1090            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1091            Ok(LanguageModelCompletionEvent::ToolUse(
1092                LanguageModelToolUse {
1093                    id: LanguageModelToolUseId::from(tool_use.id),
1094                    name: tool_use.name.into(),
1095                    raw_input,
1096                    input,
1097                    is_input_complete: true,
1098                    thought_signature: tool_use.thought_signature,
1099                },
1100            ))
1101        }
1102        LlmCompletionEvent::ToolUseJsonParseError(error) => {
1103            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1104                id: LanguageModelToolUseId::from(error.id),
1105                tool_name: error.tool_name.into(),
1106                raw_input: error.raw_input.into(),
1107                json_parse_error: error.error,
1108            })
1109        }
1110        LlmCompletionEvent::Stop(reason) => {
1111            let stop_reason = match reason {
1112                LlmStopReason::EndTurn => StopReason::EndTurn,
1113                LlmStopReason::MaxTokens => StopReason::MaxTokens,
1114                LlmStopReason::ToolUse => StopReason::ToolUse,
1115                LlmStopReason::Refusal => StopReason::Refusal,
1116            };
1117            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1118        }
1119        LlmCompletionEvent::Usage(usage) => {
1120            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1121                input_tokens: usage.input_tokens,
1122                output_tokens: usage.output_tokens,
1123                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1124                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1125            }))
1126        }
1127        LlmCompletionEvent::ReasoningDetails(json) => {
1128            Ok(LanguageModelCompletionEvent::ReasoningDetails(
1129                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1130            ))
1131        }
1132    }
1133}