Add provider extension API key in settings

Richard Feldman created

Change summary

Cargo.lock                                          |   4 
crates/extension_api/src/extension_api.rs           |  10 
crates/extension_api/wit/since_v0.7.0/extension.wit |   4 
crates/extension_host/Cargo.toml                    |   5 
crates/extension_host/src/wasm_host/llm_provider.rs | 757 ++++++++++----
crates/extension_host/src/wasm_host/wit.rs          |  14 
6 files changed, 558 insertions(+), 236 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5905,9 +5905,11 @@ dependencies = [
  "async-trait",
  "client",
  "collections",
+ "credentials_provider",
  "criterion",
  "ctor",
  "dap",
+ "editor",
  "extension",
  "fs",
  "futures 0.3.31",
@@ -5919,6 +5921,8 @@ dependencies = [
  "language_model",
  "log",
  "lsp",
+ "markdown",
+ "menu",
  "moka",
  "node_runtime",
  "parking_lot",

crates/extension_api/src/extension_api.rs 🔗

@@ -288,6 +288,12 @@ pub trait Extension: Send + Sync {
         Ok(Vec::new())
     }
 
+    /// Returns markdown content to display in the provider's settings UI.
+    /// This can include setup instructions, links to documentation, etc.
+    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
+        None
+    }
+
     /// Check if the provider is authenticated.
     fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
         false
@@ -618,6 +624,10 @@ impl wit::Guest for Component {
         extension().llm_provider_models(&provider_id)
     }
 
+    fn llm_provider_settings_markdown(provider_id: String) -> Option<String> {
+        extension().llm_provider_settings_markdown(&provider_id)
+    }
+
     fn llm_provider_is_authenticated(provider_id: String) -> bool {
         extension().llm_provider_is_authenticated(&provider_id)
     }

crates/extension_api/wit/since_v0.7.0/extension.wit 🔗

@@ -180,6 +180,10 @@ world extension {
     /// Returns the models available for a provider.
     export llm-provider-models: func(provider-id: string) -> result<list<model-info>, string>;
 
+    /// Returns markdown content to display in the provider's settings UI.
+    /// This can include setup instructions, links to documentation, etc.
+    export llm-provider-settings-markdown: func(provider-id: string) -> option<string>;
+
     /// Check if the provider is authenticated.
     export llm-provider-is-authenticated: func(provider-id: string) -> bool;
 

crates/extension_host/Cargo.toml 🔗

@@ -22,7 +22,9 @@ async-tar.workspace = true
 async-trait.workspace = true
 client.workspace = true
 collections.workspace = true
+credentials_provider.workspace = true
 dap.workspace = true
+editor.workspace = true
 extension.workspace = true
 fs.workspace = true
 futures.workspace = true
@@ -32,7 +34,9 @@ http_client.workspace = true
 language.workspace = true
 language_model.workspace = true
 log.workspace = true
+markdown.workspace = true
 lsp.workspace = true
+menu.workspace = true
 moka.workspace = true
 node_runtime.workspace = true
 paths.workspace = true
@@ -47,6 +51,7 @@ settings.workspace = true
 task.workspace = true
 telemetry.workspace = true
 tempfile.workspace = true
+theme.workspace = true
 toml.workspace = true
 ui.workspace = true
 url.workspace = true

crates/extension_host/src/wasm_host/llm_provider.rs 🔗

@@ -7,10 +7,16 @@ use crate::wasm_host::wit::{
     LlmToolUse,
 };
 use anyhow::{Result, anyhow};
+use credentials_provider::CredentialsProvider;
+use editor::Editor;
 use futures::future::BoxFuture;
 use futures::stream::BoxStream;
 use futures::{FutureExt, StreamExt};
-use gpui::{AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, Window};
+use gpui::Focusable;
+use gpui::{
+    AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
+    TextStyleRefinement, UnderlineStyle, Window, px,
+};
 use language_model::tool_schema::LanguageModelToolSchemaFormat;
 use language_model::{
     AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
@@ -19,7 +25,12 @@ use language_model::{
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
 };
+use markdown::{Markdown, MarkdownElement, MarkdownStyle};
+use settings::Settings;
 use std::sync::Arc;
+use theme::ThemeSettings;
+use ui::{Label, LabelSize, prelude::*};
+use util::ResultExt as _;
 
 /// An extension-based language model provider.
 pub struct ExtensionLanguageModelProvider {
@@ -58,13 +69,16 @@ impl ExtensionLanguageModelProvider {
     fn provider_id_string(&self) -> String {
         format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
     }
+
+    /// The credential key used for storing the API key in the system keychain.
+    fn credential_key(&self) -> String {
+        format!("extension-llm-{}", self.provider_id_string())
+    }
 }
 
 impl LanguageModelProvider for ExtensionLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        let id = LanguageModelProviderId::from(self.provider_id_string());
-        eprintln!("ExtensionLanguageModelProvider::id() -> {:?}", id);
-        id
+        LanguageModelProviderId::from(self.provider_id_string())
     }
 
     fn name(&self) -> LanguageModelProviderName {
@@ -99,8 +113,6 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
             .available_models
             .iter()
             .find(|m| m.is_default_fast)
-            .or_else(|| state.available_models.iter().find(|m| m.is_default))
-            .or_else(|| state.available_models.first())
             .map(|model_info| {
                 Arc::new(ExtensionLanguageModel {
                     extension: self.extension.clone(),
@@ -114,16 +126,10 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let state = self.state.read(cx);
-        eprintln!(
-            "ExtensionLanguageModelProvider::provided_models called for {}, returning {} models",
-            self.provider_info.name,
-            state.available_models.len()
-        );
         state
             .available_models
             .iter()
             .map(|model_info| {
-                eprintln!("  - model: {}", model_info.name);
                 Arc::new(ExtensionLanguageModel {
                     extension: self.extension.clone(),
                     model_info: model_info.clone(),
@@ -175,18 +181,43 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
     fn configuration_view(
         &self,
         _target_agent: ConfigurationViewTargetAgent,
-        _window: &mut Window,
+        window: &mut Window,
         cx: &mut App,
     ) -> AnyView {
-        cx.new(|_| EmptyConfigView).into()
+        let credential_key = self.credential_key();
+        let extension = self.extension.clone();
+        let extension_provider_id = self.provider_info.id.clone();
+        let state = self.state.clone();
+
+        cx.new(|cx| {
+            ExtensionProviderConfigurationView::new(
+                credential_key,
+                extension,
+                extension_provider_id,
+                state,
+                window,
+                cx,
+            )
+        })
+        .into()
     }
 
     fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
         let extension = self.extension.clone();
         let provider_id = self.provider_info.id.clone();
         let state = self.state.clone();
+        let credential_key = self.credential_key();
+
+        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 
         cx.spawn(async move |cx| {
+            // Delete from system keychain
+            credentials_provider
+                .delete_credentials(&credential_key, cx)
+                .await
+                .log_err();
+
+            // Call extension's reset_credentials
             let result = extension
                 .call(|extension, store| {
                     async move {
@@ -198,15 +229,15 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
                 })
                 .await;
 
+            // Update state
+            cx.update(|cx| {
+                state.update(cx, |state, _| {
+                    state.is_authenticated = false;
+                });
+            })?;
+
             match result {
-                Ok(Ok(Ok(()))) => {
-                    cx.update(|cx| {
-                        state.update(cx, |state, _| {
-                            state.is_authenticated = false;
-                        });
-                    })?;
-                    Ok(())
-                }
+                Ok(Ok(Ok(()))) => Ok(()),
                 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
                 Ok(Err(e)) => Err(e),
                 Err(e) => Err(e),
@@ -226,20 +257,302 @@ impl LanguageModelProviderState for ExtensionLanguageModelProvider {
         &self,
         cx: &mut Context<T>,
         callback: impl Fn(&mut T, &mut Context<T>) + 'static,
-    ) -> Option<gpui::Subscription> {
+    ) -> Option<Subscription> {
         Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
     }
 }
 
-struct EmptyConfigView;
+/// Configuration view for extension-based LLM providers.
+struct ExtensionProviderConfigurationView {
+    credential_key: String,
+    extension: WasmExtension,
+    extension_provider_id: String,
+    state: Entity<ExtensionLlmProviderState>,
+    settings_markdown: Option<Entity<Markdown>>,
+    api_key_editor: Entity<Editor>,
+    loading_settings: bool,
+    loading_credentials: bool,
+    _subscriptions: Vec<Subscription>,
+}
+
+impl ExtensionProviderConfigurationView {
+    fn new(
+        credential_key: String,
+        extension: WasmExtension,
+        extension_provider_id: String,
+        state: Entity<ExtensionLlmProviderState>,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) -> Self {
+        // Subscribe to state changes
+        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
+            cx.notify();
+        });
+
+        // Create API key editor
+        let api_key_editor = cx.new(|cx| {
+            let mut editor = Editor::single_line(window, cx);
+            editor.set_placeholder_text("Enter API key...", window, cx);
+            editor
+        });
+
+        let mut this = Self {
+            credential_key,
+            extension,
+            extension_provider_id,
+            state,
+            settings_markdown: None,
+            api_key_editor,
+            loading_settings: true,
+            loading_credentials: true,
+            _subscriptions: vec![state_subscription],
+        };
+
+        // Load settings text from extension
+        this.load_settings_text(cx);
+
+        // Load existing credentials
+        this.load_credentials(cx);
+
+        this
+    }
+
+    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
+        let extension = self.extension.clone();
+        let provider_id = self.extension_provider_id.clone();
+
+        cx.spawn(async move |this, cx| {
+            let result = extension
+                .call({
+                    let provider_id = provider_id.clone();
+                    |ext, store| {
+                        async move {
+                            ext.call_llm_provider_settings_markdown(store, &provider_id)
+                                .await
+                        }
+                        .boxed()
+                    }
+                })
+                .await;
+
+            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
+
+            this.update(cx, |this, cx| {
+                this.loading_settings = false;
+                if let Some(text) = settings_text {
+                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
+                    this.settings_markdown = Some(markdown);
+                }
+                cx.notify();
+            })
+            .log_err();
+        })
+        .detach();
+    }
+
+    fn load_credentials(&mut self, cx: &mut Context<Self>) {
+        let credential_key = self.credential_key.clone();
+        let credentials_provider = <dyn CredentialsProvider>::global(cx);
+        let state = self.state.clone();
+
+        cx.spawn(async move |this, cx| {
+            let credentials = credentials_provider
+                .read_credentials(&credential_key, cx)
+                .await
+                .log_err()
+                .flatten();
+
+            let has_credentials = credentials.is_some();
+
+            // Update authentication state based on stored credentials
+            let _ = cx.update(|cx| {
+                state.update(cx, |state, cx| {
+                    state.is_authenticated = has_credentials;
+                    cx.notify();
+                });
+            });
+
+            this.update(cx, |this, cx| {
+                this.loading_credentials = false;
+                cx.notify();
+            })
+            .log_err();
+        })
+        .detach();
+    }
+
+    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
+        let api_key = self.api_key_editor.read(cx).text(cx);
+        if api_key.is_empty() {
+            return;
+        }
+
+        // Clear the editor
+        self.api_key_editor
+            .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+        let credential_key = self.credential_key.clone();
+        let credentials_provider = <dyn CredentialsProvider>::global(cx);
+        let state = self.state.clone();
+
+        cx.spawn(async move |_this, cx| {
+            // Store in system keychain
+            credentials_provider
+                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
+                .await
+                .log_err();
+
+            // Update state to authenticated
+            let _ = cx.update(|cx| {
+                state.update(cx, |state, cx| {
+                    state.is_authenticated = true;
+                    cx.notify();
+                });
+            });
+        })
+        .detach();
+    }
+
+    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+        // Clear the editor
+        self.api_key_editor
+            .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+        let credential_key = self.credential_key.clone();
+        let credentials_provider = <dyn CredentialsProvider>::global(cx);
+        let state = self.state.clone();
+
+        cx.spawn(async move |_this, cx| {
+            // Delete from system keychain
+            credentials_provider
+                .delete_credentials(&credential_key, cx)
+                .await
+                .log_err();
+
+            // Update state to unauthenticated
+            let _ = cx.update(|cx| {
+                state.update(cx, |state, cx| {
+                    state.is_authenticated = false;
+                    cx.notify();
+                });
+            });
+        })
+        .detach();
+    }
+
+    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
+        self.state.read(cx).is_authenticated
+    }
+}
+
+impl gpui::Render for ExtensionProviderConfigurationView {
+    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+        let is_loading = self.loading_settings || self.loading_credentials;
+        let is_authenticated = self.is_authenticated(cx);
+
+        if is_loading {
+            return v_flex()
+                .gap_2()
+                .child(Label::new("Loading...").color(Color::Muted))
+                .into_any_element();
+        }
+
+        let mut content = v_flex().gap_4().size_full();
+
+        // Render settings markdown if available
+        if let Some(markdown) = &self.settings_markdown {
+            let style = settings_markdown_style(_window, cx);
+            content = content.child(
+                div()
+                    .p_2()
+                    .rounded_md()
+                    .bg(cx.theme().colors().surface_background)
+                    .child(MarkdownElement::new(markdown.clone(), style)),
+            );
+        }
 
-impl gpui::Render for EmptyConfigView {
-    fn render(
-        &mut self,
-        _window: &mut Window,
-        _cx: &mut gpui::Context<Self>,
-    ) -> impl gpui::IntoElement {
-        gpui::Empty
+        // Render API key section
+        if is_authenticated {
+            content = content.child(
+                v_flex()
+                    .gap_2()
+                    .child(
+                        h_flex()
+                            .gap_2()
+                            .child(
+                                ui::Icon::new(ui::IconName::Check)
+                                    .color(Color::Success)
+                                    .size(ui::IconSize::Small),
+                            )
+                            .child(Label::new("API key configured").color(Color::Success)),
+                    )
+                    .child(
+                        ui::Button::new("reset-api-key", "Reset API Key")
+                            .style(ui::ButtonStyle::Subtle)
+                            .on_click(cx.listener(|this, _, window, cx| {
+                                this.reset_api_key(window, cx);
+                            })),
+                    ),
+            );
+        } else {
+            content = content.child(
+                v_flex()
+                    .gap_2()
+                    .on_action(cx.listener(Self::save_api_key))
+                    .child(
+                        Label::new("API Key")
+                            .size(LabelSize::Small)
+                            .color(Color::Muted),
+                    )
+                    .child(self.api_key_editor.clone())
+                    .child(
+                        Label::new("Enter your API key and press Enter to save")
+                            .size(LabelSize::Small)
+                            .color(Color::Muted),
+                    ),
+            );
+        }
+
+        content.into_any_element()
+    }
+}
+
+impl Focusable for ExtensionProviderConfigurationView {
+    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
+        self.api_key_editor.focus_handle(cx)
+    }
+}
+
+fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
+    let theme_settings = ThemeSettings::get_global(cx);
+    let colors = cx.theme().colors();
+    let mut text_style = window.text_style();
+    text_style.refine(&TextStyleRefinement {
+        font_family: Some(theme_settings.ui_font.family.clone()),
+        font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
+        font_features: Some(theme_settings.ui_font.features.clone()),
+        color: Some(colors.text),
+        ..Default::default()
+    });
+
+    MarkdownStyle {
+        base_text_style: text_style,
+        selection_background_color: colors.element_selection_background,
+        inline_code: TextStyleRefinement {
+            background_color: Some(colors.editor_background),
+            ..Default::default()
+        },
+        link: TextStyleRefinement {
+            color: Some(colors.text_accent),
+            underline: Some(UnderlineStyle {
+                color: Some(colors.text_accent.opacity(0.5)),
+                thickness: px(1.),
+                ..Default::default()
+            }),
+            ..Default::default()
+        },
+        syntax: cx.theme().syntax().clone(),
+        ..Default::default()
     }
 }
 
@@ -254,7 +567,7 @@ pub struct ExtensionLanguageModel {
 
 impl LanguageModel for ExtensionLanguageModel {
     fn id(&self) -> LanguageModelId {
-        LanguageModelId::from(format!("{}:{}", self.provider_id.0, self.model_info.id))
+        LanguageModelId::from(self.model_info.id.clone())
     }
 
     fn name(&self) -> LanguageModelName {
@@ -270,7 +583,7 @@ impl LanguageModel for ExtensionLanguageModel {
     }
 
     fn telemetry_id(&self) -> String {
-        format!("extension:{}", self.model_info.id)
+        format!("extension-{}", self.model_info.id)
     }
 
     fn supports_images(&self) -> bool {
@@ -307,31 +620,33 @@ impl LanguageModel for ExtensionLanguageModel {
     fn count_tokens(
         &self,
         request: LanguageModelRequest,
-        _cx: &App,
+        cx: &App,
     ) -> BoxFuture<'static, Result<u64>> {
         let extension = self.extension.clone();
         let provider_id = self.provider_info.id.clone();
         let model_id = self.model_info.id.clone();
 
-        async move {
-            let wit_request = convert_request_to_wit(&request);
-
-            let result = extension
-                .call(|ext, store| {
-                    async move {
-                        ext.call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
-                            .await
+        let wit_request = convert_request_to_wit(request);
+
+        cx.background_spawn(async move {
+            extension
+                .call({
+                    let provider_id = provider_id.clone();
+                    let model_id = model_id.clone();
+                    let wit_request = wit_request.clone();
+                    |ext, store| {
+                        async move {
+                            let count = ext
+                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
+                                .await?
+                                .map_err(|e| anyhow!("{}", e))?;
+                            Ok(count)
+                        }
+                        .boxed()
                     }
-                    .boxed()
                 })
-                .await?;
-
-            match result {
-                Ok(Ok(count)) => Ok(count),
-                Ok(Err(e)) => Err(anyhow!("{}", e)),
-                Err(e) => Err(e),
-            }
-        }
+                .await?
+        })
         .boxed()
     }
 
@@ -350,68 +665,77 @@ impl LanguageModel for ExtensionLanguageModel {
         let provider_id = self.provider_info.id.clone();
         let model_id = self.model_info.id.clone();
 
-        async move {
-            let wit_request = convert_request_to_wit(&request);
+        let wit_request = convert_request_to_wit(request);
 
-            // Start the stream and get a stream ID
-            let outer_result = extension
-                .call(|ext, store| {
-                    async move {
-                        ext.call_llm_stream_completion_start(
-                            store,
-                            &provider_id,
-                            &model_id,
-                            &wit_request,
-                        )
-                        .await
+        async move {
+            // Start the stream
+            let stream_id = extension
+                .call({
+                    let provider_id = provider_id.clone();
+                    let model_id = model_id.clone();
+                    let wit_request = wit_request.clone();
+                    |ext, store| {
+                        async move {
+                            let id = ext
+                                .call_llm_stream_completion_start(
+                                    store,
+                                    &provider_id,
+                                    &model_id,
+                                    &wit_request,
+                                )
+                                .await?
+                                .map_err(|e| anyhow!("{}", e))?;
+                            Ok(id)
+                        }
+                        .boxed()
                     }
-                    .boxed()
                 })
                 .await
-                .map_err(|e| LanguageModelCompletionError::Other(e))?;
-
-            // Unwrap the inner Result<Result<String, String>>
-            let inner_result =
-                outer_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;
-
-            // Get the stream ID
-            let stream_id =
-                inner_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;
+                .map_err(LanguageModelCompletionError::Other)?
+                .map_err(LanguageModelCompletionError::Other)?;
 
             // Create a stream that polls for events
             let stream = futures::stream::unfold(
-                (extension, stream_id, false),
-                |(ext, stream_id, done)| async move {
+                (extension.clone(), stream_id, false),
+                move |(extension, stream_id, done)| async move {
                     if done {
                         return None;
                     }
 
-                    let result = ext
+                    let result = extension
                         .call({
                             let stream_id = stream_id.clone();
-                            move |ext, store| {
+                            |ext, store| {
                                 async move {
-                                    ext.call_llm_stream_completion_next(store, &stream_id).await
+                                    let event = ext
+                                        .call_llm_stream_completion_next(store, &stream_id)
+                                        .await?
+                                        .map_err(|e| anyhow!("{}", e))?;
+                                    Ok(event)
                                 }
                                 .boxed()
                             }
                         })
-                        .await;
+                        .await
+                        .and_then(|inner| inner);
 
                     match result {
-                        Ok(Ok(Ok(Some(event)))) => {
+                        Ok(Some(event)) => {
                             let converted = convert_completion_event(event);
-                            Some((Ok(converted), (ext, stream_id, false)))
+                            let is_done =
+                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
+                            Some((converted, (extension, stream_id, is_done)))
                         }
-                        Ok(Ok(Ok(None))) => {
-                            // Stream complete - close it
-                            let _ = ext
+                        Ok(None) => {
+                            // Stream complete, close it
+                            let _ = extension
                                 .call({
                                     let stream_id = stream_id.clone();
-                                    move |ext, store| {
+                                    |ext, store| {
                                         async move {
                                             ext.call_llm_stream_completion_close(store, &stream_id)
-                                                .await
+                                                .await?;
+                                            Ok::<(), anyhow::Error>(())
                                         }
                                         .boxed()
                                     }
@@ -419,63 +743,10 @@ impl LanguageModel for ExtensionLanguageModel {
                                 .await;
                             None
                         }
-                        Ok(Ok(Err(e))) => {
-                            // Extension returned an error - close stream and return error
-                            let _ = ext
-                                .call({
-                                    let stream_id = stream_id.clone();
-                                    move |ext, store| {
-                                        async move {
-                                            ext.call_llm_stream_completion_close(store, &stream_id)
-                                                .await
-                                        }
-                                        .boxed()
-                                    }
-                                })
-                                .await;
-                            Some((
-                                Err(LanguageModelCompletionError::Other(anyhow!("{}", e))),
-                                (ext, stream_id, true),
-                            ))
-                        }
-                        Ok(Err(e)) => {
-                            // WASM call error - close stream and return error
-                            let _ = ext
-                                .call({
-                                    let stream_id = stream_id.clone();
-                                    move |ext, store| {
-                                        async move {
-                                            ext.call_llm_stream_completion_close(store, &stream_id)
-                                                .await
-                                        }
-                                        .boxed()
-                                    }
-                                })
-                                .await;
-                            Some((
-                                Err(LanguageModelCompletionError::Other(e)),
-                                (ext, stream_id, true),
-                            ))
-                        }
-                        Err(e) => {
-                            // Channel error - close stream and return error
-                            let _ = ext
-                                .call({
-                                    let stream_id = stream_id.clone();
-                                    move |ext, store| {
-                                        async move {
-                                            ext.call_llm_stream_completion_close(store, &stream_id)
-                                                .await
-                                        }
-                                        .boxed()
-                                    }
-                                })
-                                .await;
-                            Some((
-                                Err(LanguageModelCompletionError::Other(e)),
-                                (ext, stream_id, true),
-                            ))
-                        }
+                        Err(e) => Some((
+                            Err(LanguageModelCompletionError::Other(e)),
+                            (extension, stream_id, true),
+                        )),
                     }
                 },
             );
@@ -486,87 +757,88 @@ impl LanguageModel for ExtensionLanguageModel {
     }
 
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
+        // Extensions can implement this via llm_cache_configuration
         None
     }
 }
 
-fn convert_request_to_wit(request: &LanguageModelRequest) -> LlmCompletionRequest {
-    let messages = request
+fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
+    use language_model::{MessageContent, Role};
+
+    let messages: Vec<LlmRequestMessage> = request
         .messages
-        .iter()
-        .map(|msg| LlmRequestMessage {
-            role: match msg.role {
-                language_model::Role::User => LlmMessageRole::User,
-                language_model::Role::Assistant => LlmMessageRole::Assistant,
-                language_model::Role::System => LlmMessageRole::System,
-            },
-            content: msg
+        .into_iter()
+        .map(|msg| {
+            let role = match msg.role {
+                Role::User => LlmMessageRole::User,
+                Role::Assistant => LlmMessageRole::Assistant,
+                Role::System => LlmMessageRole::System,
+            };
+
+            let content: Vec<LlmMessageContent> = msg
                 .content
-                .iter()
-                .map(|content| match content {
-                    language_model::MessageContent::Text(text) => {
-                        LlmMessageContent::Text(text.clone())
-                    }
-                    language_model::MessageContent::Image(image) => {
-                        LlmMessageContent::Image(LlmImageData {
-                            source: image.source.to_string(),
-                            width: Some(image.size.width.0 as u32),
-                            height: Some(image.size.height.0 as u32),
-                        })
-                    }
-                    language_model::MessageContent::ToolUse(tool_use) => {
-                        LlmMessageContent::ToolUse(LlmToolUse {
-                            id: tool_use.id.to_string(),
-                            name: tool_use.name.to_string(),
-                            input: tool_use.raw_input.clone(),
-                            thought_signature: tool_use.thought_signature.clone(),
-                        })
-                    }
-                    language_model::MessageContent::ToolResult(result) => {
+                .into_iter()
+                .map(|c| match c {
+                    MessageContent::Text(text) => LlmMessageContent::Text(text),
+                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
+                        source: image.source.to_string(),
+                        width: Some(image.size.width.0 as u32),
+                        height: Some(image.size.height.0 as u32),
+                    }),
+                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
+                        id: tool_use.id.to_string(),
+                        name: tool_use.name.to_string(),
+                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
+                        thought_signature: tool_use.thought_signature,
+                    }),
+                    MessageContent::ToolResult(tool_result) => {
+                        let content = match tool_result.content {
+                            language_model::LanguageModelToolResultContent::Text(text) => {
+                                LlmToolResultContent::Text(text.to_string())
+                            }
+                            language_model::LanguageModelToolResultContent::Image(image) => {
+                                LlmToolResultContent::Image(LlmImageData {
+                                    source: image.source.to_string(),
+                                    width: Some(image.size.width.0 as u32),
+                                    height: Some(image.size.height.0 as u32),
+                                })
+                            }
+                        };
                         LlmMessageContent::ToolResult(LlmToolResult {
-                            tool_use_id: result.tool_use_id.to_string(),
-                            tool_name: result.tool_name.to_string(),
-                            is_error: result.is_error,
-                            content: match &result.content {
-                                language_model::LanguageModelToolResultContent::Text(t) => {
-                                    LlmToolResultContent::Text(t.to_string())
-                                }
-                                language_model::LanguageModelToolResultContent::Image(img) => {
-                                    LlmToolResultContent::Image(LlmImageData {
-                                        source: img.source.to_string(),
-                                        width: Some(img.size.width.0 as u32),
-                                        height: Some(img.size.height.0 as u32),
-                                    })
-                                }
-                            },
+                            tool_use_id: tool_result.tool_use_id.to_string(),
+                            tool_name: tool_result.tool_name.to_string(),
+                            is_error: tool_result.is_error,
+                            content,
                         })
                     }
-                    language_model::MessageContent::Thinking { text, signature } => {
-                        LlmMessageContent::Thinking(LlmThinkingContent {
-                            text: text.clone(),
-                            signature: signature.clone(),
-                        })
+                    MessageContent::Thinking { text, signature } => {
+                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
                     }
-                    language_model::MessageContent::RedactedThinking(data) => {
-                        LlmMessageContent::RedactedThinking(data.clone())
+                    MessageContent::RedactedThinking(data) => {
+                        LlmMessageContent::RedactedThinking(data)
                     }
                 })
-                .collect(),
-            cache: msg.cache,
+                .collect();
+
+            LlmRequestMessage {
+                role,
+                content,
+                cache: msg.cache,
+            }
         })
         .collect();
 
-    let tools = request
+    let tools: Vec<LlmToolDefinition> = request
         .tools
-        .iter()
+        .into_iter()
         .map(|tool| LlmToolDefinition {
-            name: tool.name.clone(),
-            description: tool.description.clone(),
+            name: tool.name,
+            description: tool.description,
             input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
         })
         .collect();
 
-    let tool_choice = request.tool_choice.as_ref().map(|choice| match choice {
+    let tool_choice = request.tool_choice.map(|tc| match tc {
         LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
         LanguageModelToolChoice::Any => LlmToolChoice::Any,
         LanguageModelToolChoice::None => LlmToolChoice::None,
@@ -576,58 +848,71 @@ fn convert_request_to_wit(request: &LanguageModelRequest) -> LlmCompletionReques
         messages,
         tools,
         tool_choice,
-        stop_sequences: request.stop.clone(),
+        stop_sequences: request.stop,
         temperature: request.temperature,
-        thinking_allowed: request.thinking_allowed,
+        thinking_allowed: false,
         max_tokens: None,
     }
 }
 
-fn convert_completion_event(event: LlmCompletionEvent) -> LanguageModelCompletionEvent {
+fn convert_completion_event(
+    event: LlmCompletionEvent,
+) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
     match event {
-        LlmCompletionEvent::Started => LanguageModelCompletionEvent::Started,
-        LlmCompletionEvent::Text(text) => LanguageModelCompletionEvent::Text(text),
-        LlmCompletionEvent::Thinking(thinking) => LanguageModelCompletionEvent::Thinking {
+        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
+            message_id: String::new(),
+        }),
+        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
+        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
             text: thinking.text,
             signature: thinking.signature,
-        },
+        }),
         LlmCompletionEvent::RedactedThinking(data) => {
-            LanguageModelCompletionEvent::RedactedThinking { data }
+            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
         }
         LlmCompletionEvent::ToolUse(tool_use) => {
-            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
-                id: LanguageModelToolUseId::from(tool_use.id),
-                name: tool_use.name.into(),
-                raw_input: tool_use.input.clone(),
-                input: serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null),
-                is_input_complete: true,
-                thought_signature: tool_use.thought_signature,
-            })
+            let raw_input = tool_use.input.clone();
+            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
+            Ok(LanguageModelCompletionEvent::ToolUse(
+                LanguageModelToolUse {
+                    id: LanguageModelToolUseId::from(tool_use.id),
+                    name: tool_use.name.into(),
+                    raw_input,
+                    input,
+                    is_input_complete: true,
+                    thought_signature: tool_use.thought_signature,
+                },
+            ))
         }
         LlmCompletionEvent::ToolUseJsonParseError(error) => {
-            LanguageModelCompletionEvent::ToolUseJsonParseError {
+            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                 id: LanguageModelToolUseId::from(error.id),
                 tool_name: error.tool_name.into(),
                 raw_input: error.raw_input.into(),
                 json_parse_error: error.error,
-            }
+            })
+        }
+        LlmCompletionEvent::Stop(reason) => {
+            let stop_reason = match reason {
+                LlmStopReason::EndTurn => StopReason::EndTurn,
+                LlmStopReason::MaxTokens => StopReason::MaxTokens,
+                LlmStopReason::ToolUse => StopReason::ToolUse,
+                LlmStopReason::Refusal => StopReason::Refusal,
+            };
+            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
+        }
+        LlmCompletionEvent::Usage(usage) => {
+            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.input_tokens,
+                output_tokens: usage.output_tokens,
+                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
+                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
+            }))
         }
-        LlmCompletionEvent::Stop(reason) => LanguageModelCompletionEvent::Stop(match reason {
-            LlmStopReason::EndTurn => StopReason::EndTurn,
-            LlmStopReason::MaxTokens => StopReason::MaxTokens,
-            LlmStopReason::ToolUse => StopReason::ToolUse,
-            LlmStopReason::Refusal => StopReason::Refusal,
-        }),
-        LlmCompletionEvent::Usage(usage) => LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
-            input_tokens: usage.input_tokens,
-            output_tokens: usage.output_tokens,
-            cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
-            cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
-        }),
         LlmCompletionEvent::ReasoningDetails(json) => {
-            LanguageModelCompletionEvent::ReasoningDetails(
+            Ok(LanguageModelCompletionEvent::ReasoningDetails(
                 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
-            )
+            ))
         }
     }
 }

crates/extension_host/src/wasm_host/wit.rs 🔗

@@ -1199,6 +1199,20 @@ impl Extension {
         }
     }
 
+    pub async fn call_llm_provider_settings_markdown(
+        &self,
+        store: &mut Store<WasmState>,
+        provider_id: &str,
+    ) -> Result<Option<String>> {
+        match self {
+            Extension::V0_7_0(ext) => {
+                ext.call_llm_provider_settings_markdown(store, provider_id)
+                    .await
+            }
+            _ => Ok(None),
+        }
+    }
+
     pub async fn call_llm_provider_is_authenticated(
         &self,
         store: &mut Store<WasmState>,