llm_provider.rs

   1use crate::ExtensionSettings;
   2use crate::wasm_host::WasmExtension;
   3
   4use crate::wasm_host::wit::{
   5    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
   6    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
   7    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
   8    LlmToolUse,
   9};
  10use anyhow::{Result, anyhow};
  11use credentials_provider::CredentialsProvider;
  12use editor::Editor;
  13use extension::LanguageModelAuthConfig;
  14use futures::future::BoxFuture;
  15use futures::stream::BoxStream;
  16use futures::{FutureExt, StreamExt};
  17use gpui::Focusable;
  18use gpui::{
  19    AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
  20    TextStyleRefinement, UnderlineStyle, Window, px,
  21};
  22use language_model::tool_schema::LanguageModelToolSchemaFormat;
  23use language_model::{
  24    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
  25    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  27    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  28    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
  29};
  30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
  31use settings::Settings;
  32use std::sync::Arc;
  33use theme::ThemeSettings;
  34use ui::{Label, LabelSize, prelude::*};
  35use util::ResultExt as _;
  36
  37/// An extension-based language model provider.
  38pub struct ExtensionLanguageModelProvider {
  39    pub extension: WasmExtension,
  40    pub provider_info: LlmProviderInfo,
  41    icon_path: Option<SharedString>,
  42    auth_config: Option<LanguageModelAuthConfig>,
  43    state: Entity<ExtensionLlmProviderState>,
  44}
  45
  46pub struct ExtensionLlmProviderState {
  47    is_authenticated: bool,
  48    available_models: Vec<LlmModelInfo>,
  49    env_var_allowed: bool,
  50    api_key_from_env: bool,
  51}
  52
  53impl EventEmitter<()> for ExtensionLlmProviderState {}
  54
  55impl ExtensionLanguageModelProvider {
  56    pub fn new(
  57        extension: WasmExtension,
  58        provider_info: LlmProviderInfo,
  59        models: Vec<LlmModelInfo>,
  60        is_authenticated: bool,
  61        icon_path: Option<SharedString>,
  62        auth_config: Option<LanguageModelAuthConfig>,
  63        cx: &mut App,
  64    ) -> Self {
  65        let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
  66        let env_var_allowed = ExtensionSettings::get_global(cx)
  67            .allowed_env_var_providers
  68            .contains(provider_id_string.as_str());
  69
  70        let (is_authenticated, api_key_from_env) =
  71            if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
  72                let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
  73                if let Ok(value) = std::env::var(env_var_name) {
  74                    if !value.is_empty() {
  75                        (true, true)
  76                    } else {
  77                        (is_authenticated, false)
  78                    }
  79                } else {
  80                    (is_authenticated, false)
  81                }
  82            } else {
  83                (is_authenticated, false)
  84            };
  85
  86        let state = cx.new(|_| ExtensionLlmProviderState {
  87            is_authenticated,
  88            available_models: models,
  89            env_var_allowed,
  90            api_key_from_env,
  91        });
  92
  93        Self {
  94            extension,
  95            provider_info,
  96            icon_path,
  97            auth_config,
  98            state,
  99        }
 100    }
 101
 102    fn provider_id_string(&self) -> String {
 103        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
 104    }
 105
 106    /// The credential key used for storing the API key in the system keychain.
 107    fn credential_key(&self) -> String {
 108        format!("extension-llm-{}", self.provider_id_string())
 109    }
 110}
 111
 112impl LanguageModelProvider for ExtensionLanguageModelProvider {
 113    fn id(&self) -> LanguageModelProviderId {
 114        LanguageModelProviderId::from(self.provider_id_string())
 115    }
 116
 117    fn name(&self) -> LanguageModelProviderName {
 118        LanguageModelProviderName::from(self.provider_info.name.clone())
 119    }
 120
 121    fn icon(&self) -> ui::IconName {
 122        ui::IconName::ZedAssistant
 123    }
 124
 125    fn icon_path(&self) -> Option<SharedString> {
 126        self.icon_path.clone()
 127    }
 128
 129    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 130        let state = self.state.read(cx);
 131        state
 132            .available_models
 133            .iter()
 134            .find(|m| m.is_default)
 135            .or_else(|| state.available_models.first())
 136            .map(|model_info| {
 137                Arc::new(ExtensionLanguageModel {
 138                    extension: self.extension.clone(),
 139                    model_info: model_info.clone(),
 140                    provider_id: self.id(),
 141                    provider_name: self.name(),
 142                    provider_info: self.provider_info.clone(),
 143                }) as Arc<dyn LanguageModel>
 144            })
 145    }
 146
 147    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 148        let state = self.state.read(cx);
 149        state
 150            .available_models
 151            .iter()
 152            .find(|m| m.is_default_fast)
 153            .map(|model_info| {
 154                Arc::new(ExtensionLanguageModel {
 155                    extension: self.extension.clone(),
 156                    model_info: model_info.clone(),
 157                    provider_id: self.id(),
 158                    provider_name: self.name(),
 159                    provider_info: self.provider_info.clone(),
 160                }) as Arc<dyn LanguageModel>
 161            })
 162    }
 163
 164    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 165        let state = self.state.read(cx);
 166        state
 167            .available_models
 168            .iter()
 169            .map(|model_info| {
 170                Arc::new(ExtensionLanguageModel {
 171                    extension: self.extension.clone(),
 172                    model_info: model_info.clone(),
 173                    provider_id: self.id(),
 174                    provider_name: self.name(),
 175                    provider_info: self.provider_info.clone(),
 176                }) as Arc<dyn LanguageModel>
 177            })
 178            .collect()
 179    }
 180
 181    fn is_authenticated(&self, cx: &App) -> bool {
 182        self.state.read(cx).is_authenticated
 183    }
 184
 185    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 186        let extension = self.extension.clone();
 187        let provider_id = self.provider_info.id.clone();
 188        let state = self.state.clone();
 189
 190        cx.spawn(async move |cx| {
 191            let result = extension
 192                .call(|extension, store| {
 193                    async move {
 194                        extension
 195                            .call_llm_provider_authenticate(store, &provider_id)
 196                            .await
 197                    }
 198                    .boxed()
 199                })
 200                .await;
 201
 202            match result {
 203                Ok(Ok(Ok(()))) => {
 204                    cx.update(|cx| {
 205                        state.update(cx, |state, _| {
 206                            state.is_authenticated = true;
 207                        });
 208                    })?;
 209                    Ok(())
 210                }
 211                Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
 212                Ok(Err(e)) => Err(AuthenticateError::Other(e)),
 213                Err(e) => Err(AuthenticateError::Other(e)),
 214            }
 215        })
 216    }
 217
 218    fn configuration_view(
 219        &self,
 220        _target_agent: ConfigurationViewTargetAgent,
 221        window: &mut Window,
 222        cx: &mut App,
 223    ) -> AnyView {
 224        let credential_key = self.credential_key();
 225        let extension = self.extension.clone();
 226        let extension_provider_id = self.provider_info.id.clone();
 227        let full_provider_id = self.provider_id_string();
 228        let state = self.state.clone();
 229        let auth_config = self.auth_config.clone();
 230
 231        cx.new(|cx| {
 232            ExtensionProviderConfigurationView::new(
 233                credential_key,
 234                extension,
 235                extension_provider_id,
 236                full_provider_id,
 237                auth_config,
 238                state,
 239                window,
 240                cx,
 241            )
 242        })
 243        .into()
 244    }
 245
 246    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 247        let extension = self.extension.clone();
 248        let provider_id = self.provider_info.id.clone();
 249        let state = self.state.clone();
 250        let credential_key = self.credential_key();
 251
 252        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 253
 254        cx.spawn(async move |cx| {
 255            // Delete from system keychain
 256            credentials_provider
 257                .delete_credentials(&credential_key, cx)
 258                .await
 259                .log_err();
 260
 261            // Call extension's reset_credentials
 262            let result = extension
 263                .call(|extension, store| {
 264                    async move {
 265                        extension
 266                            .call_llm_provider_reset_credentials(store, &provider_id)
 267                            .await
 268                    }
 269                    .boxed()
 270                })
 271                .await;
 272
 273            // Update state
 274            cx.update(|cx| {
 275                state.update(cx, |state, _| {
 276                    state.is_authenticated = false;
 277                });
 278            })?;
 279
 280            match result {
 281                Ok(Ok(Ok(()))) => Ok(()),
 282                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
 283                Ok(Err(e)) => Err(e),
 284                Err(e) => Err(e),
 285            }
 286        })
 287    }
 288}
 289
 290impl LanguageModelProviderState for ExtensionLanguageModelProvider {
 291    type ObservableEntity = ExtensionLlmProviderState;
 292
 293    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 294        Some(self.state.clone())
 295    }
 296
 297    fn subscribe<T: 'static>(
 298        &self,
 299        cx: &mut Context<T>,
 300        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
 301    ) -> Option<Subscription> {
 302        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
 303    }
 304}
 305
 306/// Configuration view for extension-based LLM providers.
 307struct ExtensionProviderConfigurationView {
 308    credential_key: String,
 309    extension: WasmExtension,
 310    extension_provider_id: String,
 311    full_provider_id: String,
 312    auth_config: Option<LanguageModelAuthConfig>,
 313    state: Entity<ExtensionLlmProviderState>,
 314    settings_markdown: Option<Entity<Markdown>>,
 315    api_key_editor: Entity<Editor>,
 316    loading_settings: bool,
 317    loading_credentials: bool,
 318    _subscriptions: Vec<Subscription>,
 319}
 320
 321impl ExtensionProviderConfigurationView {
 322    fn new(
 323        credential_key: String,
 324        extension: WasmExtension,
 325        extension_provider_id: String,
 326        full_provider_id: String,
 327        auth_config: Option<LanguageModelAuthConfig>,
 328        state: Entity<ExtensionLlmProviderState>,
 329        window: &mut Window,
 330        cx: &mut Context<Self>,
 331    ) -> Self {
 332        // Subscribe to state changes
 333        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
 334            cx.notify();
 335        });
 336
 337        // Create API key editor
 338        let api_key_editor = cx.new(|cx| {
 339            let mut editor = Editor::single_line(window, cx);
 340            editor.set_placeholder_text("Enter API key...", window, cx);
 341            editor
 342        });
 343
 344        let mut this = Self {
 345            credential_key,
 346            extension,
 347            extension_provider_id,
 348            full_provider_id,
 349            auth_config,
 350            state,
 351            settings_markdown: None,
 352            api_key_editor,
 353            loading_settings: true,
 354            loading_credentials: true,
 355            _subscriptions: vec![state_subscription],
 356        };
 357
 358        // Load settings text from extension
 359        this.load_settings_text(cx);
 360
 361        // Load existing credentials
 362        this.load_credentials(cx);
 363
 364        this
 365    }
 366
 367    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
 368        let extension = self.extension.clone();
 369        let provider_id = self.extension_provider_id.clone();
 370
 371        cx.spawn(async move |this, cx| {
 372            let result = extension
 373                .call({
 374                    let provider_id = provider_id.clone();
 375                    |ext, store| {
 376                        async move {
 377                            ext.call_llm_provider_settings_markdown(store, &provider_id)
 378                                .await
 379                        }
 380                        .boxed()
 381                    }
 382                })
 383                .await;
 384
 385            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
 386
 387            this.update(cx, |this, cx| {
 388                this.loading_settings = false;
 389                if let Some(text) = settings_text {
 390                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
 391                    this.settings_markdown = Some(markdown);
 392                }
 393                cx.notify();
 394            })
 395            .log_err();
 396        })
 397        .detach();
 398    }
 399
 400    fn load_credentials(&mut self, cx: &mut Context<Self>) {
 401        let credential_key = self.credential_key.clone();
 402        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 403        let state = self.state.clone();
 404
 405        // Check if we should use env var (already set in state during provider construction)
 406        let api_key_from_env = self.state.read(cx).api_key_from_env;
 407
 408        cx.spawn(async move |this, cx| {
 409            // If using env var, we're already authenticated
 410            if api_key_from_env {
 411                this.update(cx, |this, cx| {
 412                    this.loading_credentials = false;
 413                    cx.notify();
 414                })
 415                .log_err();
 416                return;
 417            }
 418
 419            let credentials = credentials_provider
 420                .read_credentials(&credential_key, cx)
 421                .await
 422                .log_err()
 423                .flatten();
 424
 425            let has_credentials = credentials.is_some();
 426
 427            // Update authentication state based on stored credentials
 428            let _ = cx.update(|cx| {
 429                state.update(cx, |state, cx| {
 430                    state.is_authenticated = has_credentials;
 431                    cx.notify();
 432                });
 433            });
 434
 435            this.update(cx, |this, cx| {
 436                this.loading_credentials = false;
 437                cx.notify();
 438            })
 439            .log_err();
 440        })
 441        .detach();
 442    }
 443
 444    fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
 445        let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
 446        let env_var_name = match &self.auth_config {
 447            Some(config) => config.env_var.clone(),
 448            None => return,
 449        };
 450
 451        let state = self.state.clone();
 452        let currently_allowed = self.state.read(cx).env_var_allowed;
 453
 454        // Update settings file
 455        settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
 456            let providers = settings
 457                .extension
 458                .allowed_env_var_providers
 459                .get_or_insert_with(Vec::new);
 460
 461            if currently_allowed {
 462                providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
 463            } else {
 464                if !providers
 465                    .iter()
 466                    .any(|id| id.as_ref() == full_provider_id.as_ref())
 467                {
 468                    providers.push(full_provider_id.clone());
 469                }
 470            }
 471        });
 472
 473        // Update local state
 474        let new_allowed = !currently_allowed;
 475        let new_from_env = if new_allowed {
 476            if let Some(var_name) = &env_var_name {
 477                if let Ok(value) = std::env::var(var_name) {
 478                    !value.is_empty()
 479                } else {
 480                    false
 481                }
 482            } else {
 483                false
 484            }
 485        } else {
 486            false
 487        };
 488
 489        state.update(cx, |state, cx| {
 490            state.env_var_allowed = new_allowed;
 491            state.api_key_from_env = new_from_env;
 492            if new_from_env {
 493                state.is_authenticated = true;
 494            }
 495            cx.notify();
 496        });
 497
 498        // If env var is being enabled, clear any stored keychain credentials
 499        // so there's only one source of truth for the API key
 500        if new_allowed {
 501            let credential_key = self.credential_key.clone();
 502            let credentials_provider = <dyn CredentialsProvider>::global(cx);
 503            cx.spawn(async move |_this, cx| {
 504                credentials_provider
 505                    .delete_credentials(&credential_key, cx)
 506                    .await
 507                    .log_err();
 508            })
 509            .detach();
 510        }
 511
 512        // If env var is being disabled, reload credentials from keychain
 513        if !new_allowed {
 514            self.reload_keychain_credentials(cx);
 515        }
 516
 517        cx.notify();
 518    }
 519
 520    fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
 521        let credential_key = self.credential_key.clone();
 522        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 523        let state = self.state.clone();
 524
 525        cx.spawn(async move |_this, cx| {
 526            let credentials = credentials_provider
 527                .read_credentials(&credential_key, cx)
 528                .await
 529                .log_err()
 530                .flatten();
 531
 532            let has_credentials = credentials.is_some();
 533
 534            let _ = cx.update(|cx| {
 535                state.update(cx, |state, cx| {
 536                    state.is_authenticated = has_credentials;
 537                    cx.notify();
 538                });
 539            });
 540        })
 541        .detach();
 542    }
 543
 544    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 545        let api_key = self.api_key_editor.read(cx).text(cx);
 546        if api_key.is_empty() {
 547            return;
 548        }
 549
 550        // Clear the editor
 551        self.api_key_editor
 552            .update(cx, |editor, cx| editor.set_text("", window, cx));
 553
 554        let credential_key = self.credential_key.clone();
 555        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 556        let state = self.state.clone();
 557
 558        cx.spawn(async move |_this, cx| {
 559            // Store in system keychain
 560            credentials_provider
 561                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
 562                .await
 563                .log_err();
 564
 565            // Update state to authenticated
 566            let _ = cx.update(|cx| {
 567                state.update(cx, |state, cx| {
 568                    state.is_authenticated = true;
 569                    cx.notify();
 570                });
 571            });
 572        })
 573        .detach();
 574    }
 575
 576    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 577        // Clear the editor
 578        self.api_key_editor
 579            .update(cx, |editor, cx| editor.set_text("", window, cx));
 580
 581        let credential_key = self.credential_key.clone();
 582        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 583        let state = self.state.clone();
 584
 585        cx.spawn(async move |_this, cx| {
 586            // Delete from system keychain
 587            credentials_provider
 588                .delete_credentials(&credential_key, cx)
 589                .await
 590                .log_err();
 591
 592            // Update state to unauthenticated
 593            let _ = cx.update(|cx| {
 594                state.update(cx, |state, cx| {
 595                    state.is_authenticated = false;
 596                    cx.notify();
 597                });
 598            });
 599        })
 600        .detach();
 601    }
 602
 603    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
 604        self.state.read(cx).is_authenticated
 605    }
 606}
 607
 608impl gpui::Render for ExtensionProviderConfigurationView {
 609    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 610        let is_loading = self.loading_settings || self.loading_credentials;
 611        let is_authenticated = self.is_authenticated(cx);
 612        let env_var_allowed = self.state.read(cx).env_var_allowed;
 613        let api_key_from_env = self.state.read(cx).api_key_from_env;
 614
 615        if is_loading {
 616            return v_flex()
 617                .gap_2()
 618                .child(Label::new("Loading...").color(Color::Muted))
 619                .into_any_element();
 620        }
 621
 622        let mut content = v_flex().gap_4().size_full();
 623
 624        // Render settings markdown if available
 625        if let Some(markdown) = &self.settings_markdown {
 626            let style = settings_markdown_style(_window, cx);
 627            content = content.child(
 628                div()
 629                    .p_2()
 630                    .rounded_md()
 631                    .bg(cx.theme().colors().surface_background)
 632                    .child(MarkdownElement::new(markdown.clone(), style)),
 633            );
 634        }
 635
 636        // Render env var checkbox if the extension specifies an env var
 637        if let Some(auth_config) = &self.auth_config {
 638            if let Some(env_var_name) = &auth_config.env_var {
 639                let env_var_name = env_var_name.clone();
 640                let checkbox_label =
 641                    format!("Read API key from {} environment variable", env_var_name);
 642
 643                content = content.child(
 644                    h_flex()
 645                        .gap_2()
 646                        .child(
 647                            ui::Checkbox::new("env-var-permission", env_var_allowed.into())
 648                                .on_click(cx.listener(|this, _, _window, cx| {
 649                                    this.toggle_env_var_permission(cx);
 650                                })),
 651                        )
 652                        .child(Label::new(checkbox_label).size(LabelSize::Small)),
 653                );
 654
 655                // Show status if env var is allowed
 656                if env_var_allowed {
 657                    if api_key_from_env {
 658                        content = content.child(
 659                            h_flex()
 660                                .gap_2()
 661                                .child(
 662                                    ui::Icon::new(ui::IconName::Check)
 663                                        .color(Color::Success)
 664                                        .size(ui::IconSize::Small),
 665                                )
 666                                .child(
 667                                    Label::new(format!("API key loaded from {}", env_var_name))
 668                                        .color(Color::Success),
 669                                ),
 670                        );
 671                        return content.into_any_element();
 672                    } else {
 673                        content = content.child(
 674                            h_flex()
 675                                .gap_2()
 676                                .child(
 677                                    ui::Icon::new(ui::IconName::Warning)
 678                                        .color(Color::Warning)
 679                                        .size(ui::IconSize::Small),
 680                                )
 681                                .child(
 682                                    Label::new(format!(
 683                                        "{} is not set or empty. You can set it and restart Zed, or enter an API key below.",
 684                                        env_var_name
 685                                    ))
 686                                    .color(Color::Warning)
 687                                    .size(LabelSize::Small),
 688                                ),
 689                        );
 690                    }
 691                }
 692            }
 693        }
 694
 695        // Render API key section
 696        if is_authenticated && !api_key_from_env {
 697            content = content.child(
 698                v_flex()
 699                    .gap_2()
 700                    .child(
 701                        h_flex()
 702                            .gap_2()
 703                            .child(
 704                                ui::Icon::new(ui::IconName::Check)
 705                                    .color(Color::Success)
 706                                    .size(ui::IconSize::Small),
 707                            )
 708                            .child(Label::new("API key configured").color(Color::Success)),
 709                    )
 710                    .child(
 711                        ui::Button::new("reset-api-key", "Reset API Key")
 712                            .style(ui::ButtonStyle::Subtle)
 713                            .on_click(cx.listener(|this, _, window, cx| {
 714                                this.reset_api_key(window, cx);
 715                            })),
 716                    ),
 717            );
 718        } else if !api_key_from_env {
 719            let credential_label = self
 720                .auth_config
 721                .as_ref()
 722                .and_then(|c| c.credential_label.clone())
 723                .unwrap_or_else(|| "API Key".to_string());
 724
 725            content = content.child(
 726                v_flex()
 727                    .gap_2()
 728                    .on_action(cx.listener(Self::save_api_key))
 729                    .child(
 730                        Label::new(credential_label)
 731                            .size(LabelSize::Small)
 732                            .color(Color::Muted),
 733                    )
 734                    .child(self.api_key_editor.clone())
 735                    .child(
 736                        Label::new("Enter your API key and press Enter to save")
 737                            .size(LabelSize::Small)
 738                            .color(Color::Muted),
 739                    ),
 740            );
 741        }
 742
 743        content.into_any_element()
 744    }
 745}
 746
 747impl Focusable for ExtensionProviderConfigurationView {
 748    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 749        self.api_key_editor.focus_handle(cx)
 750    }
 751}
 752
 753fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
 754    let theme_settings = ThemeSettings::get_global(cx);
 755    let colors = cx.theme().colors();
 756    let mut text_style = window.text_style();
 757    text_style.refine(&TextStyleRefinement {
 758        font_family: Some(theme_settings.ui_font.family.clone()),
 759        font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
 760        font_features: Some(theme_settings.ui_font.features.clone()),
 761        color: Some(colors.text),
 762        ..Default::default()
 763    });
 764
 765    MarkdownStyle {
 766        base_text_style: text_style,
 767        selection_background_color: colors.element_selection_background,
 768        inline_code: TextStyleRefinement {
 769            background_color: Some(colors.editor_background),
 770            ..Default::default()
 771        },
 772        link: TextStyleRefinement {
 773            color: Some(colors.text_accent),
 774            underline: Some(UnderlineStyle {
 775                color: Some(colors.text_accent.opacity(0.5)),
 776                thickness: px(1.),
 777                ..Default::default()
 778            }),
 779            ..Default::default()
 780        },
 781        syntax: cx.theme().syntax().clone(),
 782        ..Default::default()
 783    }
 784}
 785
 786/// An extension-based language model.
 787pub struct ExtensionLanguageModel {
 788    extension: WasmExtension,
 789    model_info: LlmModelInfo,
 790    provider_id: LanguageModelProviderId,
 791    provider_name: LanguageModelProviderName,
 792    provider_info: LlmProviderInfo,
 793}
 794
 795impl LanguageModel for ExtensionLanguageModel {
 796    fn id(&self) -> LanguageModelId {
 797        LanguageModelId::from(self.model_info.id.clone())
 798    }
 799
 800    fn name(&self) -> LanguageModelName {
 801        LanguageModelName::from(self.model_info.name.clone())
 802    }
 803
 804    fn provider_id(&self) -> LanguageModelProviderId {
 805        self.provider_id.clone()
 806    }
 807
 808    fn provider_name(&self) -> LanguageModelProviderName {
 809        self.provider_name.clone()
 810    }
 811
 812    fn telemetry_id(&self) -> String {
 813        format!("extension-{}", self.model_info.id)
 814    }
 815
 816    fn supports_images(&self) -> bool {
 817        self.model_info.capabilities.supports_images
 818    }
 819
 820    fn supports_tools(&self) -> bool {
 821        self.model_info.capabilities.supports_tools
 822    }
 823
 824    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 825        match choice {
 826            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
 827            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
 828            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
 829        }
 830    }
 831
 832    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 833        match self.model_info.capabilities.tool_input_format {
 834            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
 835            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
 836        }
 837    }
 838
 839    fn max_token_count(&self) -> u64 {
 840        self.model_info.max_token_count
 841    }
 842
 843    fn max_output_tokens(&self) -> Option<u64> {
 844        self.model_info.max_output_tokens
 845    }
 846
 847    fn count_tokens(
 848        &self,
 849        request: LanguageModelRequest,
 850        cx: &App,
 851    ) -> BoxFuture<'static, Result<u64>> {
 852        let extension = self.extension.clone();
 853        let provider_id = self.provider_info.id.clone();
 854        let model_id = self.model_info.id.clone();
 855
 856        let wit_request = convert_request_to_wit(request);
 857
 858        cx.background_spawn(async move {
 859            extension
 860                .call({
 861                    let provider_id = provider_id.clone();
 862                    let model_id = model_id.clone();
 863                    let wit_request = wit_request.clone();
 864                    |ext, store| {
 865                        async move {
 866                            let count = ext
 867                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
 868                                .await?
 869                                .map_err(|e| anyhow!("{}", e))?;
 870                            Ok(count)
 871                        }
 872                        .boxed()
 873                    }
 874                })
 875                .await?
 876        })
 877        .boxed()
 878    }
 879
 880    fn stream_completion(
 881        &self,
 882        request: LanguageModelRequest,
 883        _cx: &AsyncApp,
 884    ) -> BoxFuture<
 885        'static,
 886        Result<
 887            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 888            LanguageModelCompletionError,
 889        >,
 890    > {
 891        let extension = self.extension.clone();
 892        let provider_id = self.provider_info.id.clone();
 893        let model_id = self.model_info.id.clone();
 894
 895        let wit_request = convert_request_to_wit(request);
 896
 897        async move {
 898            // Start the stream
 899            let stream_id_result = extension
 900                .call({
 901                    let provider_id = provider_id.clone();
 902                    let model_id = model_id.clone();
 903                    let wit_request = wit_request.clone();
 904                    |ext, store| {
 905                        async move {
 906                            let id = ext
 907                                .call_llm_stream_completion_start(
 908                                    store,
 909                                    &provider_id,
 910                                    &model_id,
 911                                    &wit_request,
 912                                )
 913                                .await?
 914                                .map_err(|e| anyhow!("{}", e))?;
 915                            Ok(id)
 916                        }
 917                        .boxed()
 918                    }
 919                })
 920                .await;
 921
 922            let stream_id = stream_id_result
 923                .map_err(LanguageModelCompletionError::Other)?
 924                .map_err(LanguageModelCompletionError::Other)?;
 925
 926            // Create a stream that polls for events
 927            let stream = futures::stream::unfold(
 928                (extension.clone(), stream_id, false),
 929                move |(extension, stream_id, done)| async move {
 930                    if done {
 931                        return None;
 932                    }
 933
 934                    let result = extension
 935                        .call({
 936                            let stream_id = stream_id.clone();
 937                            |ext, store| {
 938                                async move {
 939                                    let event = ext
 940                                        .call_llm_stream_completion_next(store, &stream_id)
 941                                        .await?
 942                                        .map_err(|e| anyhow!("{}", e))?;
 943                                    Ok(event)
 944                                }
 945                                .boxed()
 946                            }
 947                        })
 948                        .await
 949                        .and_then(|inner| inner);
 950
 951                    match result {
 952                        Ok(Some(event)) => {
 953                            let converted = convert_completion_event(event);
 954                            let is_done =
 955                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
 956                            Some((converted, (extension, stream_id, is_done)))
 957                        }
 958                        Ok(None) => {
 959                            // Stream complete, close it
 960                            let _ = extension
 961                                .call({
 962                                    let stream_id = stream_id.clone();
 963                                    |ext, store| {
 964                                        async move {
 965                                            ext.call_llm_stream_completion_close(store, &stream_id)
 966                                                .await?;
 967                                            Ok::<(), anyhow::Error>(())
 968                                        }
 969                                        .boxed()
 970                                    }
 971                                })
 972                                .await;
 973                            None
 974                        }
 975                        Err(e) => Some((
 976                            Err(LanguageModelCompletionError::Other(e)),
 977                            (extension, stream_id, true),
 978                        )),
 979                    }
 980                },
 981            );
 982
 983            Ok(stream.boxed())
 984        }
 985        .boxed()
 986    }
 987
 988    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 989        // Extensions can implement this via llm_cache_configuration
 990        None
 991    }
 992}
 993
 994fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
 995    use language_model::{MessageContent, Role};
 996
 997    let messages: Vec<LlmRequestMessage> = request
 998        .messages
 999        .into_iter()
1000        .map(|msg| {
1001            let role = match msg.role {
1002                Role::User => LlmMessageRole::User,
1003                Role::Assistant => LlmMessageRole::Assistant,
1004                Role::System => LlmMessageRole::System,
1005            };
1006
1007            let content: Vec<LlmMessageContent> = msg
1008                .content
1009                .into_iter()
1010                .map(|c| match c {
1011                    MessageContent::Text(text) => LlmMessageContent::Text(text),
1012                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1013                        source: image.source.to_string(),
1014                        width: Some(image.size.width.0 as u32),
1015                        height: Some(image.size.height.0 as u32),
1016                    }),
1017                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1018                        id: tool_use.id.to_string(),
1019                        name: tool_use.name.to_string(),
1020                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1021                        thought_signature: tool_use.thought_signature,
1022                    }),
1023                    MessageContent::ToolResult(tool_result) => {
1024                        let content = match tool_result.content {
1025                            language_model::LanguageModelToolResultContent::Text(text) => {
1026                                LlmToolResultContent::Text(text.to_string())
1027                            }
1028                            language_model::LanguageModelToolResultContent::Image(image) => {
1029                                LlmToolResultContent::Image(LlmImageData {
1030                                    source: image.source.to_string(),
1031                                    width: Some(image.size.width.0 as u32),
1032                                    height: Some(image.size.height.0 as u32),
1033                                })
1034                            }
1035                        };
1036                        LlmMessageContent::ToolResult(LlmToolResult {
1037                            tool_use_id: tool_result.tool_use_id.to_string(),
1038                            tool_name: tool_result.tool_name.to_string(),
1039                            is_error: tool_result.is_error,
1040                            content,
1041                        })
1042                    }
1043                    MessageContent::Thinking { text, signature } => {
1044                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1045                    }
1046                    MessageContent::RedactedThinking(data) => {
1047                        LlmMessageContent::RedactedThinking(data)
1048                    }
1049                })
1050                .collect();
1051
1052            LlmRequestMessage {
1053                role,
1054                content,
1055                cache: msg.cache,
1056            }
1057        })
1058        .collect();
1059
1060    let tools: Vec<LlmToolDefinition> = request
1061        .tools
1062        .into_iter()
1063        .map(|tool| LlmToolDefinition {
1064            name: tool.name,
1065            description: tool.description,
1066            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1067        })
1068        .collect();
1069
1070    let tool_choice = request.tool_choice.map(|tc| match tc {
1071        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1072        LanguageModelToolChoice::Any => LlmToolChoice::Any,
1073        LanguageModelToolChoice::None => LlmToolChoice::None,
1074    });
1075
1076    LlmCompletionRequest {
1077        messages,
1078        tools,
1079        tool_choice,
1080        stop_sequences: request.stop,
1081        temperature: request.temperature,
1082        thinking_allowed: false,
1083        max_tokens: None,
1084    }
1085}
1086
1087fn convert_completion_event(
1088    event: LlmCompletionEvent,
1089) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1090    match event {
1091        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1092            message_id: String::new(),
1093        }),
1094        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1095        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1096            text: thinking.text,
1097            signature: thinking.signature,
1098        }),
1099        LlmCompletionEvent::RedactedThinking(data) => {
1100            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1101        }
1102        LlmCompletionEvent::ToolUse(tool_use) => {
1103            let raw_input = tool_use.input.clone();
1104            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1105            Ok(LanguageModelCompletionEvent::ToolUse(
1106                LanguageModelToolUse {
1107                    id: LanguageModelToolUseId::from(tool_use.id),
1108                    name: tool_use.name.into(),
1109                    raw_input,
1110                    input,
1111                    is_input_complete: true,
1112                    thought_signature: tool_use.thought_signature,
1113                },
1114            ))
1115        }
1116        LlmCompletionEvent::ToolUseJsonParseError(error) => {
1117            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1118                id: LanguageModelToolUseId::from(error.id),
1119                tool_name: error.tool_name.into(),
1120                raw_input: error.raw_input.into(),
1121                json_parse_error: error.error,
1122            })
1123        }
1124        LlmCompletionEvent::Stop(reason) => {
1125            let stop_reason = match reason {
1126                LlmStopReason::EndTurn => StopReason::EndTurn,
1127                LlmStopReason::MaxTokens => StopReason::MaxTokens,
1128                LlmStopReason::ToolUse => StopReason::ToolUse,
1129                LlmStopReason::Refusal => StopReason::Refusal,
1130            };
1131            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1132        }
1133        LlmCompletionEvent::Usage(usage) => {
1134            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1135                input_tokens: usage.input_tokens,
1136                output_tokens: usage.output_tokens,
1137                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1138                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1139            }))
1140        }
1141        LlmCompletionEvent::ReasoningDetails(json) => {
1142            Ok(LanguageModelCompletionEvent::ReasoningDetails(
1143                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1144            ))
1145        }
1146    }
1147}