diff --git a/Cargo.lock b/Cargo.lock
index 4a1cf2518639b2a616bd6fb4cf4c44b3b25ec1e2..8ac867d8a78cfcf2048abdef882e778810a282b4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5905,9 +5905,11 @@ dependencies = [
  "async-trait",
  "client",
  "collections",
+ "credentials_provider",
  "criterion",
  "ctor",
  "dap",
+ "editor",
  "extension",
  "fs",
  "futures 0.3.31",
@@ -5919,6 +5921,8 @@ dependencies = [
  "language_model",
  "log",
  "lsp",
+ "markdown",
+ "menu",
  "moka",
  "node_runtime",
  "parking_lot",
diff --git a/crates/extension_api/src/extension_api.rs b/crates/extension_api/src/extension_api.rs
index 5b548b8e45f282e63e28b0931a2793d4fd8bace7..daafb63c278cacca0a9275d8e4e9db22cef209d0 100644
--- a/crates/extension_api/src/extension_api.rs
+++ b/crates/extension_api/src/extension_api.rs
@@ -288,6 +288,12 @@ pub trait Extension: Send + Sync {
         Ok(Vec::new())
     }
 
+    /// Returns markdown content to display in the provider's settings UI.
+    /// This can include setup instructions, links to documentation, etc.
+    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
+        None
+    }
+
     /// Check if the provider is authenticated.
     fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
         false
@@ -618,6 +624,10 @@ impl wit::Guest for Component {
         extension().llm_provider_models(&provider_id)
     }
 
+    fn llm_provider_settings_markdown(provider_id: String) -> Option<String> {
+        extension().llm_provider_settings_markdown(&provider_id)
+    }
+
     fn llm_provider_is_authenticated(provider_id: String) -> bool {
         extension().llm_provider_is_authenticated(&provider_id)
     }
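For reference, a minimal sketch of how an extension could implement the new hook on the `zed_extension_api::Extension` trait above. Only the method signature comes from this diff; the extension struct, provider id, and markdown text are hypothetical.

```rust
use zed_extension_api as zed;

struct AcmeExtension;

impl zed::Extension for AcmeExtension {
    fn new() -> Self {
        AcmeExtension
    }

    // Setup instructions surfaced in Zed's provider settings UI.
    fn llm_provider_settings_markdown(&self, provider_id: &str) -> Option<String> {
        (provider_id == "acme").then(|| {
            "## Acme setup\n\n\
             1. Create an API key at <https://example.com/keys>.\n\
             2. Paste it into the field below and press Enter.\n"
                .to_string()
        })
    }
}

zed::register_extension!(AcmeExtension);
```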
diff --git a/crates/extension_api/wit/since_v0.7.0/extension.wit b/crates/extension_api/wit/since_v0.7.0/extension.wit
index 265bb922a43c03e3d7bfb5b688c919022a0a6dd1..92979a8780039776853fa250be2afdb204ae5d55 100644
--- a/crates/extension_api/wit/since_v0.7.0/extension.wit
+++ b/crates/extension_api/wit/since_v0.7.0/extension.wit
@@ -180,6 +180,10 @@ world extension {
     /// Returns the models available for a provider.
     export llm-provider-models: func(provider-id: string) -> result<list<llm-model>, string>;
 
+    /// Returns markdown content to display in the provider's settings UI.
+    /// This can include setup instructions, links to documentation, etc.
+    export llm-provider-settings-markdown: func(provider-id: string) -> option<string>;
+
     /// Check if the provider is authenticated.
     export llm-provider-is-authenticated: func(provider-id: string) -> bool;
 
diff --git a/crates/extension_host/Cargo.toml b/crates/extension_host/Cargo.toml
index 46c481ee53babe32598f434e4c0e346bd7b1ab09..a5c9357b9c80b70f0bf362ba04cd581d52f67828 100644
--- a/crates/extension_host/Cargo.toml
+++ b/crates/extension_host/Cargo.toml
@@ -22,7 +22,9 @@ async-tar.workspace = true
 async-trait.workspace = true
 client.workspace = true
 collections.workspace = true
+credentials_provider.workspace = true
 dap.workspace = true
+editor.workspace = true
 extension.workspace = true
 fs.workspace = true
 futures.workspace = true
@@ -32,7 +34,9 @@ http_client.workspace = true
 language.workspace = true
 language_model.workspace = true
 log.workspace = true
+markdown.workspace = true
 lsp.workspace = true
+menu.workspace = true
 moka.workspace = true
 node_runtime.workspace = true
 paths.workspace = true
@@ -47,6 +51,7 @@ settings.workspace = true
 task.workspace = true
 telemetry.workspace = true
 tempfile.workspace = true
+theme.workspace = true
 toml.workspace = true
 ui.workspace = true
 url.workspace = true
diff --git a/crates/extension_host/src/wasm_host/llm_provider.rs b/crates/extension_host/src/wasm_host/llm_provider.rs
index 02cc9722f3ca8af3537a7977b37a39895cc0e278..e55e8b05938f4a49e625165f893c82e51714bd34 100644
--- a/crates/extension_host/src/wasm_host/llm_provider.rs
+++ b/crates/extension_host/src/wasm_host/llm_provider.rs
@@ -7,10 +7,16 @@ use crate::wasm_host::wit::{
     LlmToolUse,
 };
 use anyhow::{Result, anyhow};
+use credentials_provider::CredentialsProvider;
+use editor::Editor;
 use futures::future::BoxFuture;
 use futures::stream::BoxStream;
 use futures::{FutureExt, StreamExt};
-use gpui::{AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, Window};
+use gpui::Focusable;
+use gpui::{
+    AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
+    TextStyleRefinement, UnderlineStyle, Window, px,
+};
 use language_model::tool_schema::LanguageModelToolSchemaFormat;
 use language_model::{
     AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
@@ -19,7 +25,12 @@ use language_model::{
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
 };
+use markdown::{Markdown, MarkdownElement, MarkdownStyle};
+use settings::Settings;
 use std::sync::Arc;
+use theme::ThemeSettings;
+use ui::{Label, LabelSize, prelude::*};
+use util::ResultExt as _;
 
 /// An extension-based language model provider.
 pub struct ExtensionLanguageModelProvider {
@@ -58,13 +69,16 @@ impl ExtensionLanguageModelProvider {
     fn provider_id_string(&self) -> String {
         format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
     }
+
+    /// The credential key used for storing the API key in the system keychain.
+    fn credential_key(&self) -> String {
+        format!("extension-llm-{}", self.provider_id_string())
+    }
 }
 
 impl LanguageModelProvider for ExtensionLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        let id = LanguageModelProviderId::from(self.provider_id_string());
-        eprintln!("ExtensionLanguageModelProvider::id() -> {:?}", id);
-        id
+        LanguageModelProviderId::from(self.provider_id_string())
     }
 
     fn name(&self) -> LanguageModelProviderName {
@@ -99,8 +113,6 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
             .available_models
             .iter()
             .find(|m| m.is_default_fast)
-            .or_else(|| state.available_models.iter().find(|m| m.is_default))
-            .or_else(|| state.available_models.first())
             .map(|model_info| {
                 Arc::new(ExtensionLanguageModel {
                     extension: self.extension.clone(),
@@ -114,16 +126,10 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let state = self.state.read(cx);
-        eprintln!(
-            "ExtensionLanguageModelProvider::provided_models called for {}, returning {} models",
-            self.provider_info.name,
-            state.available_models.len()
-        );
         state
             .available_models
             .iter()
             .map(|model_info| {
-                eprintln!("  - model: {}", model_info.name);
                 Arc::new(ExtensionLanguageModel {
                     extension: self.extension.clone(),
                     model_info: model_info.clone(),
@@ -175,18 +181,43 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
     fn configuration_view(
         &self,
         _target_agent: ConfigurationViewTargetAgent,
-        _window: &mut Window,
+        window: &mut Window,
         cx: &mut App,
     ) -> AnyView {
-        cx.new(|_| EmptyConfigView).into()
+        let credential_key = self.credential_key();
+        let extension = self.extension.clone();
+        let extension_provider_id = self.provider_info.id.clone();
+        let state = self.state.clone();
+
+        cx.new(|cx| {
+            ExtensionProviderConfigurationView::new(
+                credential_key,
+                extension,
+                extension_provider_id,
+                state,
+                window,
+                cx,
+            )
+        })
+        .into()
     }
 
     fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
         let extension = self.extension.clone();
         let provider_id = self.provider_info.id.clone();
         let state = self.state.clone();
+        let credential_key = self.credential_key();
+
+        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 
         cx.spawn(async move |cx| {
+            // Delete from system keychain
+            credentials_provider
+                .delete_credentials(&credential_key, cx)
+                .await
+                .log_err();
+
+            // Call extension's reset_credentials
             let result = extension
                 .call(|extension, store| {
                     async move {
@@ -198,15 +229,15 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
                 })
                 .await;
 
+            // Update state
+            cx.update(|cx| {
+                state.update(cx, |state, _| {
+                    state.is_authenticated = false;
+                });
+            })?;
+
             match result {
-                Ok(Ok(Ok(()))) => {
-                    cx.update(|cx| {
-                        state.update(cx, |state, _| {
-                            state.is_authenticated = false;
-                        });
-                    })?;
-                    Ok(())
-                }
+                Ok(Ok(Ok(()))) => Ok(()),
                 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
                 Ok(Err(e)) => Err(e),
                 Err(e) => Err(e),
@@ -226,20 +257,302 @@ impl LanguageModelProviderState for ExtensionLanguageModelProvider {
         &self,
         cx: &mut Context<T>,
         callback: impl Fn(&mut T, &mut Context<T>) + 'static,
-    ) -> Option<gpui::Subscription> {
+    ) -> Option<Subscription> {
         Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
     }
 }
 
-struct EmptyConfigView;
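To make the namespacing concrete: for a hypothetical extension whose manifest id is `acme-llm` exposing a provider with id `acme`, the two helpers above produce the following values (illustrative only, derived from the `format!` calls shown):

```rust
// Values produced by provider_id_string() / credential_key() for a
// hypothetical extension id "acme-llm" and provider id "acme".
let provider_id_string = "acme-llm:acme"; // also used as the LanguageModelProviderId
let credential_key = "extension-llm-acme-llm:acme"; // keychain entry written by the view below
```

The `extension-llm-` prefix keeps these keychain entries in their own namespace, separate from credentials stored by other providers.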
+/// Configuration view for extension-based LLM providers.
+struct ExtensionProviderConfigurationView {
+    credential_key: String,
+    extension: WasmExtension,
+    extension_provider_id: String,
+    state: Entity,
+    settings_markdown: Option<Entity<Markdown>>,
+    api_key_editor: Entity<Editor>,
+    loading_settings: bool,
+    loading_credentials: bool,
+    _subscriptions: Vec<Subscription>,
+}
+
+impl ExtensionProviderConfigurationView {
+    fn new(
+        credential_key: String,
+        extension: WasmExtension,
+        extension_provider_id: String,
+        state: Entity,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) -> Self {
+        // Subscribe to state changes
+        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
+            cx.notify();
+        });
+
+        // Create API key editor
+        let api_key_editor = cx.new(|cx| {
+            let mut editor = Editor::single_line(window, cx);
+            editor.set_placeholder_text("Enter API key...", window, cx);
+            editor
+        });
+
+        let mut this = Self {
+            credential_key,
+            extension,
+            extension_provider_id,
+            state,
+            settings_markdown: None,
+            api_key_editor,
+            loading_settings: true,
+            loading_credentials: true,
+            _subscriptions: vec![state_subscription],
+        };
+
+        // Load settings text from extension
+        this.load_settings_text(cx);
+
+        // Load existing credentials
+        this.load_credentials(cx);
+
+        this
+    }
+
+    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
+        let extension = self.extension.clone();
+        let provider_id = self.extension_provider_id.clone();
+
+        cx.spawn(async move |this, cx| {
+            let result = extension
+                .call({
+                    let provider_id = provider_id.clone();
+                    |ext, store| {
+                        async move {
+                            ext.call_llm_provider_settings_markdown(store, &provider_id)
+                                .await
+                        }
+                        .boxed()
+                    }
+                })
+                .await;
+
+            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
+
+            this.update(cx, |this, cx| {
+                this.loading_settings = false;
+                if let Some(text) = settings_text {
+                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
+                    this.settings_markdown = Some(markdown);
+                }
+                cx.notify();
+            })
+            .log_err();
+        })
+        .detach();
+    }
+
+    fn load_credentials(&mut self, cx: &mut Context<Self>) {
+        let credential_key = self.credential_key.clone();
+        let credentials_provider = <dyn CredentialsProvider>::global(cx);
+        let state = self.state.clone();
+
+        cx.spawn(async move |this, cx| {
+            let credentials = credentials_provider
+                .read_credentials(&credential_key, cx)
+                .await
+                .log_err()
+                .flatten();
+
+            let has_credentials = credentials.is_some();
+
+            // Update authentication state based on stored credentials
+            let _ = cx.update(|cx| {
+                state.update(cx, |state, cx| {
+                    state.is_authenticated = has_credentials;
+                    cx.notify();
+                });
+            });
+
+            this.update(cx, |this, cx| {
+                this.loading_credentials = false;
+                cx.notify();
+            })
+            .log_err();
+        })
+        .detach();
+    }
+
+    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
+        let api_key = self.api_key_editor.read(cx).text(cx);
+        if api_key.is_empty() {
+            return;
+        }
+
+        // Clear the editor
+        self.api_key_editor
+            .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+        let credential_key = self.credential_key.clone();
+        let credentials_provider = <dyn CredentialsProvider>::global(cx);
+        let state = self.state.clone();
+
+        cx.spawn(async move |_this, cx| {
+            // Store in system keychain
+            credentials_provider
+                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
+                .await
+                .log_err();
+
+            // Update state to authenticated
+            let _ = cx.update(|cx| {
+                state.update(cx, |state, cx| {
+                    state.is_authenticated = true;
+                    cx.notify();
+                });
+            });
+        })
+        .detach();
+    }
+
+    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+        // Clear the editor
+        self.api_key_editor
+            .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+        let credential_key = self.credential_key.clone();
+        let credentials_provider = <dyn CredentialsProvider>::global(cx);
+        let state = self.state.clone();
+
+        cx.spawn(async move |_this, cx| {
+            // Delete from system keychain
+            credentials_provider
+                .delete_credentials(&credential_key, cx)
+                .await
+                .log_err();
+
+            // Update state to unauthenticated
+            let _ = cx.update(|cx| {
+                state.update(cx, |state, cx| {
+                    state.is_authenticated = false;
+                    cx.notify();
+                });
+            });
+        })
+        .detach();
+    }
+
+    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
+        self.state.read(cx).is_authenticated
+    }
+}
+
+impl gpui::Render for ExtensionProviderConfigurationView {
+    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+        let is_loading = self.loading_settings || self.loading_credentials;
+        let is_authenticated = self.is_authenticated(cx);
+
+        if is_loading {
+            return v_flex()
+                .gap_2()
+                .child(Label::new("Loading...").color(Color::Muted))
+                .into_any_element();
+        }
+
+        let mut content = v_flex().gap_4().size_full();
+
+        // Render settings markdown if available
+        if let Some(markdown) = &self.settings_markdown {
+            let style = settings_markdown_style(window, cx);
+            content = content.child(
+                div()
+                    .p_2()
+                    .rounded_md()
+                    .bg(cx.theme().colors().surface_background)
+                    .child(MarkdownElement::new(markdown.clone(), style)),
+            );
+        }
 
-impl gpui::Render for EmptyConfigView {
-    fn render(
-        &mut self,
-        _window: &mut Window,
-        _cx: &mut gpui::Context<Self>,
-    ) -> impl gpui::IntoElement {
-        gpui::Empty
+        // Render API key section
+        if is_authenticated {
+            content = content.child(
+                v_flex()
+                    .gap_2()
+                    .child(
+                        h_flex()
+                            .gap_2()
+                            .child(
+                                ui::Icon::new(ui::IconName::Check)
+                                    .color(Color::Success)
+                                    .size(ui::IconSize::Small),
+                            )
+                            .child(Label::new("API key configured").color(Color::Success)),
+                    )
+                    .child(
+                        ui::Button::new("reset-api-key", "Reset API Key")
+                            .style(ui::ButtonStyle::Subtle)
+                            .on_click(cx.listener(|this, _, window, cx| {
+                                this.reset_api_key(window, cx);
+                            })),
+                    ),
+            );
+        } else {
+            content = content.child(
+                v_flex()
+                    .gap_2()
+                    .on_action(cx.listener(Self::save_api_key))
+                    .child(
+                        Label::new("API Key")
+                            .size(LabelSize::Small)
+                            .color(Color::Muted),
+                    )
+                    .child(self.api_key_editor.clone())
+                    .child(
+                        Label::new("Enter your API key and press Enter to save")
+                            .size(LabelSize::Small)
+                            .color(Color::Muted),
+                    ),
+            );
+        }
+
+        content.into_any_element()
+    }
+}
+
+impl Focusable for ExtensionProviderConfigurationView {
+    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
+        self.api_key_editor.focus_handle(cx)
+    }
+}
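For illustration, the kind of markdown an extension might return from `llm_provider_settings_markdown` and that the view above renders through `MarkdownElement`; the content itself is hypothetical.

```rust
// Hypothetical settings text; headings, lists, links, and inline code all
// render through the MarkdownStyle returned by settings_markdown_style().
const ACME_SETTINGS_MARKDOWN: &str = "\
# Acme

1. Create an API key at <https://example.com/keys>.
2. Paste the key below and press Enter; it is stored in the system keychain.

Use `acme --status` to verify your account from the command line.
";
```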
+
+fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
+    let theme_settings = ThemeSettings::get_global(cx);
+    let colors = cx.theme().colors();
+    let mut text_style = window.text_style();
+    text_style.refine(&TextStyleRefinement {
+        font_family: Some(theme_settings.ui_font.family.clone()),
+        font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
+        font_features: Some(theme_settings.ui_font.features.clone()),
+        color: Some(colors.text),
+        ..Default::default()
+    });
+
+    MarkdownStyle {
+        base_text_style: text_style,
+        selection_background_color: colors.element_selection_background,
+        inline_code: TextStyleRefinement {
+            background_color: Some(colors.editor_background),
+            ..Default::default()
+        },
+        link: TextStyleRefinement {
+            color: Some(colors.text_accent),
+            underline: Some(UnderlineStyle {
+                color: Some(colors.text_accent.opacity(0.5)),
+                thickness: px(1.),
+                ..Default::default()
+            }),
+            ..Default::default()
+        },
+        syntax: cx.theme().syntax().clone(),
+        ..Default::default()
     }
 }
 
@@ -254,7 +567,7 @@ pub struct ExtensionLanguageModel {
 
 impl LanguageModel for ExtensionLanguageModel {
     fn id(&self) -> LanguageModelId {
-        LanguageModelId::from(format!("{}:{}", self.provider_id.0, self.model_info.id))
+        LanguageModelId::from(self.model_info.id.clone())
     }
 
     fn name(&self) -> LanguageModelName {
@@ -270,7 +583,7 @@ impl LanguageModel for ExtensionLanguageModel {
     }
 
     fn telemetry_id(&self) -> String {
-        format!("extension:{}", self.model_info.id)
+        format!("extension-{}", self.model_info.id)
     }
 
     fn supports_images(&self) -> bool {
@@ -307,31 +620,33 @@ impl LanguageModel for ExtensionLanguageModel {
     fn count_tokens(
         &self,
         request: LanguageModelRequest,
-        _cx: &App,
+        cx: &App,
     ) -> BoxFuture<'static, Result<u64>> {
         let extension = self.extension.clone();
         let provider_id = self.provider_info.id.clone();
         let model_id = self.model_info.id.clone();
 
-        async move {
-            let wit_request = convert_request_to_wit(&request);
-
-            let result = extension
-                .call(|ext, store| {
-                    async move {
-                        ext.call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
-                            .await
+        let wit_request = convert_request_to_wit(request);
+
+        cx.background_spawn(async move {
+            extension
+                .call({
+                    let provider_id = provider_id.clone();
+                    let model_id = model_id.clone();
+                    let wit_request = wit_request.clone();
+                    |ext, store| {
+                        async move {
+                            let count = ext
+                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
+                                .await?
+                                .map_err(|e| anyhow!("{}", e))?;
+                            Ok(count)
+                        }
+                        .boxed()
                     }
-                    .boxed()
                 })
-                .await?;
-
-            match result {
-                Ok(Ok(count)) => Ok(count),
-                Ok(Err(e)) => Err(anyhow!("{}", e)),
-                Err(e) => Err(e),
-            }
-        }
+                .await?
+        })
         .boxed()
     }
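How the extension produces a count is up to the provider. A rough fallback an extension might use when the upstream API exposes no tokenizer endpoint is sketched below, written against the host-side WIT types from this file purely for illustration (the guest-side types mirror them):

```rust
// Heuristic token estimate (~4 characters per token) over the text parts of a
// request; an extension would normally prefer its provider's own tokenizer.
fn approximate_token_count(request: &LlmCompletionRequest) -> u64 {
    let chars: usize = request
        .messages
        .iter()
        .map(|message| {
            message
                .content
                .iter()
                .map(|content| match content {
                    LlmMessageContent::Text(text) => text.len(),
                    _ => 0,
                })
                .sum::<usize>()
        })
        .sum();
    (chars / 4) as u64
}
```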
@@ -350,68 +665,77 @@ impl LanguageModel for ExtensionLanguageModel {
         let provider_id = self.provider_info.id.clone();
         let model_id = self.model_info.id.clone();
 
-        async move {
-            let wit_request = convert_request_to_wit(&request);
+        let wit_request = convert_request_to_wit(request);
 
-            // Start the stream and get a stream ID
-            let outer_result = extension
-                .call(|ext, store| {
-                    async move {
-                        ext.call_llm_stream_completion_start(
-                            store,
-                            &provider_id,
-                            &model_id,
-                            &wit_request,
-                        )
-                        .await
+        async move {
+            // Start the stream
+            let stream_id = extension
+                .call({
+                    let provider_id = provider_id.clone();
+                    let model_id = model_id.clone();
+                    let wit_request = wit_request.clone();
+                    |ext, store| {
+                        async move {
+                            let id = ext
+                                .call_llm_stream_completion_start(
+                                    store,
+                                    &provider_id,
+                                    &model_id,
+                                    &wit_request,
+                                )
+                                .await?
+                                .map_err(|e| anyhow!("{}", e))?;
+                            Ok(id)
+                        }
+                        .boxed()
                     }
-                    .boxed()
                 })
                 .await
-                .map_err(|e| LanguageModelCompletionError::Other(e))?;
-
-            // Unwrap the inner Result>
-            let inner_result =
-                outer_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;
-
-            // Get the stream ID
-            let stream_id =
-                inner_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;
+                .map_err(LanguageModelCompletionError::Other)?
+                .map_err(LanguageModelCompletionError::Other)?;
 
             // Create a stream that polls for events
             let stream = futures::stream::unfold(
-                (extension, stream_id, false),
-                |(ext, stream_id, done)| async move {
+                (extension.clone(), stream_id, false),
+                move |(extension, stream_id, done)| async move {
                     if done {
                         return None;
                     }
 
-                    let result = ext
+                    let result = extension
                         .call({
                             let stream_id = stream_id.clone();
-                            move |ext, store| {
+                            |ext, store| {
                                 async move {
-                                    ext.call_llm_stream_completion_next(store, &stream_id).await
+                                    let event = ext
+                                        .call_llm_stream_completion_next(store, &stream_id)
+                                        .await?
+                                        .map_err(|e| anyhow!("{}", e))?;
+                                    Ok(event)
                                 }
                                 .boxed()
                             }
                         })
-                        .await;
+                        .await
+                        .and_then(|inner| inner);
 
                     match result {
-                        Ok(Ok(Ok(Some(event)))) => {
+                        Ok(Some(event)) => {
                             let converted = convert_completion_event(event);
-                            Some((Ok(converted), (ext, stream_id, false)))
+                            let is_done =
+                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
+                            Some((converted, (extension, stream_id, is_done)))
                         }
-                        Ok(Ok(Ok(None))) => {
-                            // Stream complete - close it
-                            let _ = ext
+                        Ok(None) => {
+                            // Stream complete, close it
+                            let _ = extension
                                 .call({
                                     let stream_id = stream_id.clone();
-                                    move |ext, store| {
+                                    |ext, store| {
                                         async move {
                                             ext.call_llm_stream_completion_close(store, &stream_id)
-                                                .await
+                                                .await?;
+                                            Ok::<(), anyhow::Error>(())
                                         }
                                         .boxed()
                                     }
                                 })
@@ -419,63 +743,10 @@ impl LanguageModel for ExtensionLanguageModel {
                                 .await;
                             None
                         }
-                        Ok(Ok(Err(e))) => {
-                            // Extension returned an error - close stream and return error
-                            let _ = ext
-                                .call({
-                                    let stream_id = stream_id.clone();
-                                    move |ext, store| {
-                                        async move {
-                                            ext.call_llm_stream_completion_close(store, &stream_id)
-                                                .await
-                                        }
-                                        .boxed()
-                                    }
-                                })
-                                .await;
-                            Some((
-                                Err(LanguageModelCompletionError::Other(anyhow!("{}", e))),
-                                (ext, stream_id, true),
-                            ))
-                        }
-                        Ok(Err(e)) => {
-                            // WASM call error - close stream and return error
-                            let _ = ext
-                                .call({
-                                    let stream_id = stream_id.clone();
-                                    move |ext, store| {
-                                        async move {
-                                            ext.call_llm_stream_completion_close(store, &stream_id)
-                                                .await
-                                        }
-                                        .boxed()
-                                    }
-                                })
-                                .await;
-                            Some((
-                                Err(LanguageModelCompletionError::Other(e)),
-                                (ext, stream_id, true),
-                            ))
-                        }
-                        Err(e) => {
-                            // Channel error - close stream and return error
-                            let _ = ext
-                                .call({
-                                    let stream_id = stream_id.clone();
-                                    move |ext, store| {
-                                        async move {
-                                            ext.call_llm_stream_completion_close(store, &stream_id)
-                                                .await
-                                        }
-                                        .boxed()
-                                    }
-                                })
-                                .await;
-                            Some((
-                                Err(LanguageModelCompletionError::Other(e)),
-                                (ext, stream_id, true),
-                            ))
-                        }
+                        Err(e) => Some((
+                            Err(LanguageModelCompletionError::Other(e)),
+                            (extension, stream_id, true),
+                        )),
                     }
                 },
             );
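The polling loop above implies a simple per-stream contract with the extension: `llm-stream-completion-start` hands back a stream id, `llm-stream-completion-next` is called repeatedly until it reports no further event, and `llm-stream-completion-close` releases the stream. A stand-in trait, only to make that contract explicit; it is not part of the generated bindings:

```rust
// Illustrative stand-in for the guest side of the streaming calls used above.
// The real interface is the generated WIT binding, not this trait.
trait LlmCompletionStreamContract {
    /// Begin a completion and return an opaque stream id.
    fn start(&mut self, request: LlmCompletionRequest) -> Result<String, String>;
    /// Return the next event, or `None` once the stream has finished.
    fn next(&mut self, stream_id: &str) -> Result<Option<LlmCompletionEvent>, String>;
    /// Release any state held for the stream id; the host calls this after `None`.
    fn close(&mut self, stream_id: &str);
}
```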
@@ -486,87 +757,88 @@ }
 
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
+        // Extensions can implement this via llm_cache_configuration
         None
     }
 }
 
-fn convert_request_to_wit(request: &LanguageModelRequest) -> LlmCompletionRequest {
-    let messages = request
+fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
+    use language_model::{MessageContent, Role};
+
+    let messages: Vec<LlmRequestMessage> = request
         .messages
-        .iter()
-        .map(|msg| LlmRequestMessage {
-            role: match msg.role {
-                language_model::Role::User => LlmMessageRole::User,
-                language_model::Role::Assistant => LlmMessageRole::Assistant,
-                language_model::Role::System => LlmMessageRole::System,
-            },
-            content: msg
+        .into_iter()
+        .map(|msg| {
+            let role = match msg.role {
+                Role::User => LlmMessageRole::User,
+                Role::Assistant => LlmMessageRole::Assistant,
+                Role::System => LlmMessageRole::System,
+            };
+
+            let content: Vec<LlmMessageContent> = msg
                 .content
-                .iter()
-                .map(|content| match content {
-                    language_model::MessageContent::Text(text) => {
-                        LlmMessageContent::Text(text.clone())
-                    }
-                    language_model::MessageContent::Image(image) => {
-                        LlmMessageContent::Image(LlmImageData {
-                            source: image.source.to_string(),
-                            width: Some(image.size.width.0 as u32),
-                            height: Some(image.size.height.0 as u32),
-                        })
-                    }
-                    language_model::MessageContent::ToolUse(tool_use) => {
-                        LlmMessageContent::ToolUse(LlmToolUse {
-                            id: tool_use.id.to_string(),
-                            name: tool_use.name.to_string(),
-                            input: tool_use.raw_input.clone(),
-                            thought_signature: tool_use.thought_signature.clone(),
-                        })
-                    }
-                    language_model::MessageContent::ToolResult(result) => {
+                .into_iter()
+                .map(|c| match c {
+                    MessageContent::Text(text) => LlmMessageContent::Text(text),
+                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
+                        source: image.source.to_string(),
+                        width: Some(image.size.width.0 as u32),
+                        height: Some(image.size.height.0 as u32),
+                    }),
+                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
+                        id: tool_use.id.to_string(),
+                        name: tool_use.name.to_string(),
+                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
+                        thought_signature: tool_use.thought_signature,
+                    }),
+                    MessageContent::ToolResult(tool_result) => {
+                        let content = match tool_result.content {
+                            language_model::LanguageModelToolResultContent::Text(text) => {
+                                LlmToolResultContent::Text(text.to_string())
+                            }
+                            language_model::LanguageModelToolResultContent::Image(image) => {
+                                LlmToolResultContent::Image(LlmImageData {
+                                    source: image.source.to_string(),
+                                    width: Some(image.size.width.0 as u32),
+                                    height: Some(image.size.height.0 as u32),
+                                })
+                            }
+                        };
                         LlmMessageContent::ToolResult(LlmToolResult {
-                            tool_use_id: result.tool_use_id.to_string(),
-                            tool_name: result.tool_name.to_string(),
-                            is_error: result.is_error,
-                            content: match &result.content {
-                                language_model::LanguageModelToolResultContent::Text(t) => {
-                                    LlmToolResultContent::Text(t.to_string())
-                                }
-                                language_model::LanguageModelToolResultContent::Image(img) => {
-                                    LlmToolResultContent::Image(LlmImageData {
-                                        source: img.source.to_string(),
-                                        width: Some(img.size.width.0 as u32),
-                                        height: Some(img.size.height.0 as u32),
-                                    })
-                                }
-                            },
+                            tool_use_id: tool_result.tool_use_id.to_string(),
+                            tool_name: tool_result.tool_name.to_string(),
+                            is_error: tool_result.is_error,
+                            content,
                         })
                     }
-                    language_model::MessageContent::Thinking { text, signature } => {
-                        LlmMessageContent::Thinking(LlmThinkingContent {
-                            text: text.clone(),
-                            signature: signature.clone(),
-                        })
+                    MessageContent::Thinking { text, signature } => {
+                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
                     }
-                    language_model::MessageContent::RedactedThinking(data) => {
-                        LlmMessageContent::RedactedThinking(data.clone())
+                    MessageContent::RedactedThinking(data) => {
+                        LlmMessageContent::RedactedThinking(data)
                    }
                 })
-                .collect(),
-            cache: msg.cache,
+                .collect();
+
+            LlmRequestMessage {
+                role,
+                content,
+                cache: msg.cache,
+            }
         })
         .collect();
 
-    let tools = request
+    let tools: Vec<LlmToolDefinition> = request
         .tools
-        .iter()
+        .into_iter()
         .map(|tool| LlmToolDefinition {
-            name: tool.name.clone(),
-            description: tool.description.clone(),
+            name: tool.name,
+            description: tool.description,
             input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
         })
         .collect();
 
-    let tool_choice = request.tool_choice.as_ref().map(|choice| match choice {
+    let tool_choice = request.tool_choice.map(|tc| match tc {
         LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
         LanguageModelToolChoice::Any => LlmToolChoice::Any,
         LanguageModelToolChoice::None => LlmToolChoice::None,
@@ -576,58 +848,71 @@ fn convert_request_to_wit(request: &LanguageModelRequest) -> LlmCompletionReques
         messages,
         tools,
         tool_choice,
-        stop_sequences: request.stop.clone(),
+        stop_sequences: request.stop,
         temperature: request.temperature,
-        thinking_allowed: request.thinking_allowed,
+        thinking_allowed: false,
         max_tokens: None,
     }
 }
 
-fn convert_completion_event(event: LlmCompletionEvent) -> LanguageModelCompletionEvent {
+fn convert_completion_event(
+    event: LlmCompletionEvent,
+) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
     match event {
-        LlmCompletionEvent::Started => LanguageModelCompletionEvent::Started,
-        LlmCompletionEvent::Text(text) => LanguageModelCompletionEvent::Text(text),
-        LlmCompletionEvent::Thinking(thinking) => LanguageModelCompletionEvent::Thinking {
+        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
+            message_id: String::new(),
+        }),
+        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
+        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
             text: thinking.text,
             signature: thinking.signature,
-        },
+        }),
         LlmCompletionEvent::RedactedThinking(data) => {
-            LanguageModelCompletionEvent::RedactedThinking { data }
+            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
         }
         LlmCompletionEvent::ToolUse(tool_use) => {
-            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
-                id: LanguageModelToolUseId::from(tool_use.id),
-                name: tool_use.name.into(),
-                raw_input: tool_use.input.clone(),
-                input: serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null),
-                is_input_complete: true,
-                thought_signature: tool_use.thought_signature,
-            })
+            let raw_input = tool_use.input.clone();
+            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
+            Ok(LanguageModelCompletionEvent::ToolUse(
+                LanguageModelToolUse {
+                    id: LanguageModelToolUseId::from(tool_use.id),
+                    name: tool_use.name.into(),
+                    raw_input,
+                    input,
+                    is_input_complete: true,
+                    thought_signature: tool_use.thought_signature,
+                },
+            ))
         }
         LlmCompletionEvent::ToolUseJsonParseError(error) => {
-            LanguageModelCompletionEvent::ToolUseJsonParseError {
+            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                 id: LanguageModelToolUseId::from(error.id),
                 tool_name: error.tool_name.into(),
                 raw_input: error.raw_input.into(),
                 json_parse_error: error.error,
-            }
+            })
+        }
+        LlmCompletionEvent::Stop(reason) => {
+            let stop_reason = match reason {
+                LlmStopReason::EndTurn => StopReason::EndTurn,
+                LlmStopReason::MaxTokens => StopReason::MaxTokens,
+                LlmStopReason::ToolUse => StopReason::ToolUse,
+                LlmStopReason::Refusal => StopReason::Refusal,
+            };
+            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
+        }
+        LlmCompletionEvent::Usage(usage) => {
+            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.input_tokens,
+                output_tokens: usage.output_tokens,
+                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
+                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
+            }))
         }
-        LlmCompletionEvent::Stop(reason) => LanguageModelCompletionEvent::Stop(match reason {
-            LlmStopReason::EndTurn => StopReason::EndTurn,
-            LlmStopReason::MaxTokens => StopReason::MaxTokens,
-            LlmStopReason::ToolUse => StopReason::ToolUse,
-            LlmStopReason::Refusal => StopReason::Refusal,
-        }),
-        LlmCompletionEvent::Usage(usage) => LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
-            input_tokens: usage.input_tokens,
-            output_tokens: usage.output_tokens,
-            cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
-            cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
-        }),
         LlmCompletionEvent::ReasoningDetails(json) => {
-            LanguageModelCompletionEvent::ReasoningDetails(
+            Ok(LanguageModelCompletionEvent::ReasoningDetails(
                 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
-            )
+            ))
         }
     }
 }
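Since `convert_completion_event` is now fallible, a couple of test-style checks of the mapping may be worth keeping alongside it; a sketch using only variants that appear in the match above:

```rust
#[cfg(test)]
mod event_conversion_tests {
    use super::*;

    #[test]
    fn text_events_pass_through() {
        let event = convert_completion_event(LlmCompletionEvent::Text("hi".into()));
        assert!(matches!(event, Ok(LanguageModelCompletionEvent::Text(text)) if text == "hi"));
    }

    #[test]
    fn stop_reasons_are_mapped() {
        let event = convert_completion_event(LlmCompletionEvent::Stop(LlmStopReason::EndTurn));
        assert!(matches!(
            event,
            Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))
        ));
    }
}
```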
diff --git a/crates/extension_host/src/wasm_host/wit.rs b/crates/extension_host/src/wasm_host/wit.rs
index a18ad1a10803bbcec217ff72ef5847432c41ebbe..ec178b035c50e586a0278762844acef16ea424ff 100644
--- a/crates/extension_host/src/wasm_host/wit.rs
+++ b/crates/extension_host/src/wasm_host/wit.rs
@@ -1199,6 +1199,20 @@ impl Extension {
         }
     }
 
+    pub async fn call_llm_provider_settings_markdown(
+        &self,
+        store: &mut Store<WasmState>,
+        provider_id: &str,
+    ) -> Result<Option<String>> {
+        match self {
+            Extension::V0_7_0(ext) => {
+                ext.call_llm_provider_settings_markdown(store, provider_id)
+                    .await
+            }
+            _ => Ok(None),
+        }
+    }
+
     pub async fn call_llm_provider_is_authenticated(
         &self,
         store: &mut Store<WasmState>,