From 488fa0254772b72709875e37802cef0955f67e26 Mon Sep 17 00:00:00 2001 From: Michael Benfield Date: Sat, 13 Dec 2025 19:22:20 -0800 Subject: [PATCH] Streaming tool use for inline assistant (#44751) Depends on: https://github.com/zed-industries/zed/pull/44753 Release Notes: - N/A --------- Co-authored-by: Mikayla Maki --- assets/prompts/content_prompt_v2.hbs | 3 +- assets/settings/default.json | 2 + crates/agent_settings/src/agent_settings.rs | 4 + crates/agent_ui/src/agent_ui.rs | 1 + crates/agent_ui/src/buffer_codegen.rs | 304 +++++++++++++----- crates/agent_ui/src/inline_assistant.rs | 52 --- crates/anthropic/src/anthropic.rs | 14 + crates/feature_flags/src/flags.rs | 6 +- crates/language_model/src/language_model.rs | 20 ++ .../language_models/src/provider/anthropic.rs | 4 + crates/language_models/src/provider/cloud.rs | 4 + crates/prompt_store/src/prompts.rs | 2 +- crates/settings/src/settings_content/agent.rs | 11 +- 13 files changed, 282 insertions(+), 145 deletions(-) diff --git a/assets/prompts/content_prompt_v2.hbs b/assets/prompts/content_prompt_v2.hbs index e1b6ddc6f023e9e97c9bb851473ac02e989c8feb..87376f49f12f0e27cc61e9f9747d9de6bfde43cb 100644 --- a/assets/prompts/content_prompt_v2.hbs +++ b/assets/prompts/content_prompt_v2.hbs @@ -39,6 +39,5 @@ Only make changes that are necessary to fulfill the prompt, leave everything els Start at the indentation level in the original file in the rewritten {{content_type}}. -You must use one of the provided tools to make the rewrite or to provide an explanation as to why the user's request cannot be fulfilled. It is an error if -you simply send back unstructured text. If you need to make a statement or ask a question you must use one of the tools to do so. +IMPORTANT: You MUST use one of the provided tools to make the rewrite or to provide an explanation as to why the user's request cannot be fulfilled. You MUST NOT send back unstructured text. If you need to make a statement or ask a question you MUST use one of the tools to do so. It is an error if you try to make a change that cannot be made simply by editing the rewrite_section. diff --git a/assets/settings/default.json b/assets/settings/default.json index 58564138227f361e5432d377358b18734f250d72..a5180c9e2eaca9be49fa832e32e001d15d65df8f 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -896,6 +896,8 @@ "default_width": 380, }, "agent": { + // Whether the inline assistant should use streaming tools, when available + "inline_assistant_use_streaming_tools": true, // Whether the agent is enabled. "enabled": true, // What completion mode to start new threads in, if available. Can be 'normal' or 'burn'. diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 084ac7c3e7a1be4920126f857145e64b65a255dd..5dab085a255fe399d5f529791614d51f8b4cc78b 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -28,6 +28,7 @@ pub struct AgentSettings { pub default_height: Pixels, pub default_model: Option, pub inline_assistant_model: Option, + pub inline_assistant_use_streaming_tools: bool, pub commit_message_model: Option, pub thread_summary_model: Option, pub inline_alternatives: Vec, @@ -155,6 +156,9 @@ impl Settings for AgentSettings { default_height: px(agent.default_height.unwrap()), default_model: Some(agent.default_model.unwrap()), inline_assistant_model: agent.inline_assistant_model, + inline_assistant_use_streaming_tools: agent + .inline_assistant_use_streaming_tools + .unwrap_or(true), commit_message_model: agent.commit_message_model, thread_summary_model: agent.thread_summary_model, inline_alternatives: agent.inline_alternatives.unwrap_or_default(), diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index b6f7517ed934cf6cac8eefc262233b845169de9f..eb7785fad59894012251c84319af7fca306f2882 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -445,6 +445,7 @@ mod tests { default_height: px(600.), default_model: None, inline_assistant_model: None, + inline_assistant_use_streaming_tools: false, commit_message_model: None, thread_summary_model: None, inline_alternatives: vec![], diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 1cd7bec7b5b2c24cfbcf01a20091e8a07608e73a..e2c67a04167d7080a6f94b9ee2a8fae516d487d7 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -1,23 +1,26 @@ use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus}; use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; + use client::telemetry::Telemetry; use cloud_llm_client::CompletionIntent; use collections::HashSet; use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; -use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag}; +use feature_flags::{FeatureFlagAppExt as _, InlineAssistantUseToolFeatureFlag}; use futures::{ SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::{LocalBoxFuture, Shared}, join, + stream::BoxStream, }; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task}; use language::{Buffer, IndentKind, Point, TransactionId, line_diff}; use language_model::{ - LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role, - report_assistant_event, + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, + LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice, + LanguageModelToolUse, Role, TokenUsage, report_assistant_event, }; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; @@ -25,6 +28,7 @@ use prompt_store::PromptBuilder; use rope::Rope; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use settings::Settings as _; use smol::future::FutureExt; use std::{ cmp, @@ -46,6 +50,7 @@ pub struct FailureMessageInput { /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request. /// /// The message may use markdown formatting if you wish. + #[serde(default)] pub message: String, } @@ -56,9 +61,11 @@ pub struct RewriteSectionInput { /// /// The description may use markdown formatting if you wish. /// This is optional - if the edit is simple or obvious, you should leave it empty. + #[serde(default)] pub description: String, /// The text to replace the section with. + #[serde(default)] pub replacement_text: String, } @@ -379,6 +386,12 @@ impl CodegenAlternative { &self.last_equal_ranges } + fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool { + model.supports_streaming_tools() + && cx.has_flag::() + && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools + } + pub fn start( &mut self, user_prompt: String, @@ -398,11 +411,17 @@ impl CodegenAlternative { let telemetry_id = model.telemetry_id(); let provider_id = model.provider_id(); - if cx.has_flag::() { + if Self::use_streaming_tools(model.as_ref(), cx) { let request = self.build_request(&model, user_prompt, context_task, cx)?; - let tool_use = - cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await); - self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx); + let completion_events = + cx.spawn(async move |_, cx| model.stream_completion(request.await, cx).await); + self.generation = self.handle_completion( + telemetry_id, + provider_id.to_string(), + api_key, + completion_events, + cx, + ); } else { let stream: LocalBoxFuture> = if user_prompt.trim().to_lowercase() == "delete" { @@ -414,13 +433,14 @@ impl CodegenAlternative { }) .boxed_local() }; - self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx); + self.generation = + self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx); } Ok(()) } - fn build_request_v2( + fn build_request_tools( &self, model: &Arc, user_prompt: String, @@ -456,7 +476,7 @@ impl CodegenAlternative { let system_prompt = self .builder - .generate_inline_transformation_prompt_v2( + .generate_inline_transformation_prompt_tools( language_name, buffer, range.start.0..range.end.0, @@ -466,6 +486,9 @@ impl CodegenAlternative { let temperature = AgentSettings::temperature_for_model(model, cx); let tool_input_format = model.tool_input_format(); + let tool_choice = model + .supports_tool_choice(LanguageModelToolChoice::Any) + .then_some(LanguageModelToolChoice::Any); Ok(cx.spawn(async move |_cx| { let mut messages = vec![LanguageModelRequestMessage { @@ -508,7 +531,7 @@ impl CodegenAlternative { intent: Some(CompletionIntent::InlineAssist), mode: None, tools, - tool_choice: None, + tool_choice, stop: Vec::new(), temperature, messages, @@ -524,8 +547,8 @@ impl CodegenAlternative { context_task: Shared>>, cx: &mut App, ) -> Result> { - if cx.has_flag::() { - return self.build_request_v2(model, user_prompt, context_task, cx); + if Self::use_streaming_tools(model.as_ref(), cx) { + return self.build_request_tools(model, user_prompt, context_task, cx); } let buffer = self.buffer.read(cx).snapshot(cx); @@ -603,7 +626,7 @@ impl CodegenAlternative { model_api_key: Option, stream: impl 'static + Future>, cx: &mut Context, - ) { + ) -> Task<()> { let start_time = Instant::now(); // Make a new snapshot and re-resolve anchor in case the document was modified. @@ -659,7 +682,8 @@ impl CodegenAlternative { let completion = Arc::new(Mutex::new(String::new())); let completion_clone = completion.clone(); - self.generation = cx.spawn(async move |codegen, cx| { + cx.notify(); + cx.spawn(async move |codegen, cx| { let stream = stream.await; let token_usage = stream @@ -685,6 +709,7 @@ impl CodegenAlternative { stream?.stream.map_err(|error| error.into()), ); futures::pin_mut!(chunks); + let mut diff = StreamingDiff::new(selected_text.to_string()); let mut line_diff = LineDiff::default(); @@ -876,8 +901,7 @@ impl CodegenAlternative { cx.notify(); }) .ok(); - }); - cx.notify(); + }) } pub fn current_completion(&self) -> Option { @@ -1060,21 +1084,29 @@ impl CodegenAlternative { }) } - fn handle_tool_use( + fn handle_completion( &mut self, - _telemetry_id: String, - _provider_id: String, - _api_key: Option, - tool_use: impl 'static - + Future< - Output = Result, + telemetry_id: String, + provider_id: String, + api_key: Option, + completion_stream: Task< + Result< + BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, >, cx: &mut Context, - ) { + ) -> Task<()> { self.diff = Diff::default(); self.status = CodegenStatus::Pending; - self.generation = cx.spawn(async move |codegen, cx| { + cx.notify(); + // Leaving this in generation so that STOP equivalent events are respected even + // while we're still pre-processing the completion event + cx.spawn(async move |codegen, cx| { let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| { let _ = codegen.update(cx, |this, cx| { this.status = status; @@ -1083,76 +1115,176 @@ impl CodegenAlternative { }); }; - let tool_use = tool_use.await; - - match tool_use { - Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => { - // Parse the input JSON into RewriteSectionInput - match serde_json::from_value::(tool_use.input) { - Ok(input) => { - // Store the description if non-empty - let description = if !input.description.trim().is_empty() { - Some(input.description.clone()) - } else { - None + let mut completion_events = match completion_stream.await { + Ok(events) => events, + Err(err) => { + finish_with_status(CodegenStatus::Error(err.into()), cx); + return; + } + }; + + let chars_read_so_far = Arc::new(Mutex::new(0usize)); + let tool_to_text_and_message = + move |tool_use: LanguageModelToolUse| -> (Option, Option) { + let mut chars_read_so_far = chars_read_so_far.lock(); + match tool_use.name.as_ref() { + "rewrite_section" => { + let Ok(mut input) = + serde_json::from_value::(tool_use.input) + else { + return (None, None); }; + let value = input.replacement_text[*chars_read_so_far..].to_string(); + *chars_read_so_far = input.replacement_text.len(); + (Some(value), Some(std::mem::take(&mut input.description))) + } + "failure_message" => { + let Ok(mut input) = + serde_json::from_value::(tool_use.input) + else { + return (None, None); + }; + (None, Some(std::mem::take(&mut input.message))) + } + _ => (None, None), + } + }; - // Apply the replacement text to the buffer and compute diff - let batch_diff_task = codegen - .update(cx, |this, cx| { - this.model_explanation = description.map(Into::into); - let range = this.range.clone(); - this.apply_edits( - std::iter::once((range, input.replacement_text)), - cx, - ); - this.reapply_batch_diff(cx) - }) - .ok(); - - // Wait for the diff computation to complete - if let Some(diff_task) = batch_diff_task { - diff_task.await; - } + let mut message_id = None; + let mut first_text = None; + let last_token_usage = Arc::new(Mutex::new(TokenUsage::default())); + let total_text = Arc::new(Mutex::new(String::new())); - finish_with_status(CodegenStatus::Done, cx); - return; + loop { + if let Some(first_event) = completion_events.next().await { + match first_event { + Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => { + message_id = Some(id); } - Err(e) => { - finish_with_status(CodegenStatus::Error(e.into()), cx); - return; + Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) + if matches!( + tool_use.name.as_ref(), + "rewrite_section" | "failure_message" + ) => + { + let is_complete = tool_use.is_input_complete; + let (text, message) = tool_to_text_and_message(tool_use); + // Only update the model explanation if the tool use is complete. + // Otherwise the UI element bounces around as it's updated. + if is_complete { + let _ = codegen.update(cx, |this, _cx| { + this.model_explanation = message.map(Into::into); + }); + } + first_text = text; + if first_text.is_some() { + break; + } } - } - } - Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => { - // Handle failure message tool use - match serde_json::from_value::(tool_use.input) { - Ok(input) => { - let _ = codegen.update(cx, |this, _cx| { - // Store the failure message as the tool description - this.model_explanation = Some(input.message.into()); - }); - finish_with_status(CodegenStatus::Done, cx); - return; + Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => { + *last_token_usage.lock() = token_usage; + } + Ok(LanguageModelCompletionEvent::Text(text)) => { + let mut lock = total_text.lock(); + lock.push_str(&text); + } + Ok(e) => { + log::warn!("Unexpected event: {:?}", e); + break; } Err(e) => { finish_with_status(CodegenStatus::Error(e.into()), cx); - return; + break; } } } - Ok(_tool_use) => { - // Unexpected tool. - finish_with_status(CodegenStatus::Done, cx); - return; - } - Err(e) => { - finish_with_status(CodegenStatus::Error(e.into()), cx); - return; - } } - }); - cx.notify(); + + let Some(first_text) = first_text else { + finish_with_status(CodegenStatus::Done, cx); + return; + }; + + let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded(); + + cx.spawn({ + let codegen = codegen.clone(); + async move |cx| { + while let Some(message) = message_rx.next().await { + let _ = codegen.update(cx, |this, _cx| { + this.model_explanation = message; + }); + } + } + }) + .detach(); + + let move_last_token_usage = last_token_usage.clone(); + + let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain( + completion_events.filter_map(move |e| { + let tool_to_text_and_message = tool_to_text_and_message.clone(); + let last_token_usage = move_last_token_usage.clone(); + let total_text = total_text.clone(); + let mut message_tx = message_tx.clone(); + async move { + match e { + Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) + if matches!( + tool_use.name.as_ref(), + "rewrite_section" | "failure_message" + ) => + { + let is_complete = tool_use.is_input_complete; + let (text, message) = tool_to_text_and_message(tool_use); + if is_complete { + // Again only send the message when complete to not get a bouncing UI element. + let _ = message_tx.send(message.map(Into::into)).await; + } + text.map(Ok) + } + Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => { + *last_token_usage.lock() = token_usage; + None + } + Ok(LanguageModelCompletionEvent::Text(text)) => { + let mut lock = total_text.lock(); + lock.push_str(&text); + None + } + Ok(LanguageModelCompletionEvent::Stop(_reason)) => None, + e => { + log::error!("UNEXPECTED EVENT {:?}", e); + None + } + } + } + }), + )); + + let language_model_text_stream = LanguageModelTextStream { + message_id: message_id, + stream: text_stream, + last_token_usage, + }; + + let Some(task) = codegen + .update(cx, move |codegen, cx| { + codegen.handle_stream( + telemetry_id, + provider_id, + api_key, + async { Ok(language_model_text_stream) }, + cx, + ) + }) + .ok() + else { + return; + }; + + task.await; + }) } } @@ -1679,7 +1811,7 @@ mod tests { ) -> mpsc::UnboundedSender { let (chunks_tx, chunks_rx) = mpsc::unbounded(); codegen.update(cx, |codegen, cx| { - codegen.handle_stream( + codegen.generation = codegen.handle_stream( String::new(), String::new(), None, diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 48da85d38554da8227d76d3cbe290e29ef4fc531..ad0f58c162ca720e619e83ca9a3eb65a4be9fe2b 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -1455,60 +1455,8 @@ impl InlineAssistant { let old_snapshot = codegen.snapshot(cx); let old_buffer = codegen.old_buffer(cx); let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone(); - // let model_explanation = codegen.model_explanation(cx); editor.update(cx, |editor, cx| { - // Update tool description block - // if let Some(description) = model_explanation { - // if let Some(block_id) = decorations.model_explanation { - // editor.remove_blocks(HashSet::from_iter([block_id]), None, cx); - // let new_block_id = editor.insert_blocks( - // [BlockProperties { - // style: BlockStyle::Flex, - // placement: BlockPlacement::Below(assist.range.end), - // height: Some(1), - // render: Arc::new({ - // let description = description.clone(); - // move |cx| { - // div() - // .w_full() - // .py_1() - // .px_2() - // .bg(cx.theme().colors().editor_background) - // .border_y_1() - // .border_color(cx.theme().status().info_border) - // .child( - // Label::new(description.clone()) - // .color(Color::Muted) - // .size(LabelSize::Small), - // ) - // .into_any_element() - // } - // }), - // priority: 0, - // }], - // None, - // cx, - // ); - // decorations.model_explanation = new_block_id.into_iter().next(); - // } - // } else if let Some(block_id) = decorations.model_explanation { - // // Hide the block if there's no description - // editor.remove_blocks(HashSet::from_iter([block_id]), None, cx); - // let new_block_id = editor.insert_blocks( - // [BlockProperties { - // style: BlockStyle::Flex, - // placement: BlockPlacement::Below(assist.range.end), - // height: Some(0), - // render: Arc::new(|_cx| div().into_any_element()), - // priority: 0, - // }], - // None, - // cx, - // ); - // decorations.model_explanation = new_block_id.into_iter().next(); - // } - let old_blocks = mem::take(&mut decorations.removed_line_block_ids); editor.remove_blocks(old_blocks, None, cx); diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 09b293b122624274b7484026f35d1bcc8e265ece..e976b7f5dc36905d2a32b4cdc04869f3267705fe 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -429,10 +429,24 @@ impl Model { let mut headers = vec![]; match self { + Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_5 + | Self::ClaudeSonnet4 + | Self::ClaudeSonnet4_5 + | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking + | Self::ClaudeOpus4_5Thinking + | Self::ClaudeSonnet4Thinking + | Self::ClaudeSonnet4_5Thinking => { + // Fine-grained tool streaming for newer models + headers.push("fine-grained-tool-streaming-2025-05-14".to_string()); + } Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => { // Try beta token-efficient tool use (supported in Claude 3.7 Sonnet only) // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use headers.push("token-efficient-tools-2025-02-19".to_string()); + headers.push("fine-grained-tool-streaming-2025-05-14".to_string()); } Self::Custom { extra_beta_headers, .. diff --git a/crates/feature_flags/src/flags.rs b/crates/feature_flags/src/flags.rs index 61d9a34e38de546c79a2dbb5f889e2fddad38480..566d5604149567702e8739d2f3ac9fdc6f5f0de8 100644 --- a/crates/feature_flags/src/flags.rs +++ b/crates/feature_flags/src/flags.rs @@ -12,10 +12,10 @@ impl FeatureFlag for PanicFeatureFlag { const NAME: &'static str = "panic"; } -pub struct InlineAssistantV2FeatureFlag; +pub struct InlineAssistantUseToolFeatureFlag; -impl FeatureFlag for InlineAssistantV2FeatureFlag { - const NAME: &'static str = "inline-assistant-v2"; +impl FeatureFlag for InlineAssistantUseToolFeatureFlag { + const NAME: &'static str = "inline-assistant-use-tool"; fn enabled_for_staff() -> bool { false diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index e158bb256be42291549c2379ae7ec19402166543..09d44b5b408324936af00a2a5e4f1deb4f351434 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -612,6 +612,11 @@ pub trait LanguageModel: Send + Sync { false } + /// Returns whether this model or provider supports streaming tool calls; + fn supports_streaming_tools(&self) -> bool { + false + } + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { LanguageModelToolSchemaFormat::JsonSchema } @@ -766,6 +771,21 @@ pub trait LanguageModelExt: LanguageModel { } impl LanguageModelExt for dyn LanguageModel {} +impl std::fmt::Debug for dyn LanguageModel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("") + .field("id", &self.id()) + .field("name", &self.name()) + .field("provider_id", &self.provider_id()) + .field("provider_name", &self.provider_name()) + .field("upstream_provider_name", &self.upstream_provider_name()) + .field("upstream_provider_id", &self.upstream_provider_id()) + .field("upstream_provider_id", &self.upstream_provider_id()) + .field("supports_streaming_tools", &self.supports_streaming_tools()) + .finish() + } +} + /// An error that occurred when trying to authenticate the language model provider. #[derive(Debug, Error)] pub enum AuthenticateError { diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index f9e1e60cf648d3a67cec425ebd1f09ad7b564665..25ba7615dc23e2561648e173588be6d93c28e295 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -350,6 +350,10 @@ impl LanguageModel for AnthropicModel { true } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index a19a427dbacb32883b1877888ec04899a2b8d427..508a77d38abcf2143170382e945ab6ce31f3a623 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -602,6 +602,10 @@ impl LanguageModel for CloudLanguageModel { self.model.supports_images } + fn supports_streaming_tools(&self) -> bool { + self.model.supports_streaming_tools + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index d6a172218a8eb3d4538363e6202a7e721d2b7bc1..847e45742db17fe194d002c26a67380390b68f06 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -286,7 +286,7 @@ impl PromptBuilder { Ok(()) } - pub fn generate_inline_transformation_prompt_v2( + pub fn generate_inline_transformation_prompt_tools( &self, language_name: Option<&LanguageName>, buffer: BufferSnapshot, diff --git a/crates/settings/src/settings_content/agent.rs b/crates/settings/src/settings_content/agent.rs index 2ea9f0cd5788f3312061ec8ffef2a728403463ac..fccc3e09fceb8e05ad3494101a4d23d95257358e 100644 --- a/crates/settings/src/settings_content/agent.rs +++ b/crates/settings/src/settings_content/agent.rs @@ -36,7 +36,13 @@ pub struct AgentSettingsContent { pub default_model: Option, /// Model to use for the inline assistant. Defaults to default_model when not specified. pub inline_assistant_model: Option, - /// Model to use for generating git commit messages. Defaults to default_model when not specified. + /// Model to use for the inline assistant when streaming tools are enabled. + /// + /// Default: true + pub inline_assistant_use_streaming_tools: Option, + /// Model to use for generating git commit messages. + /// + /// Default: true pub commit_message_model: Option, /// Model to use for generating thread summaries. Defaults to default_model when not specified. pub thread_summary_model: Option, @@ -129,6 +135,9 @@ impl AgentSettingsContent { model, }); } + pub fn set_inline_assistant_use_streaming_tools(&mut self, use_tools: bool) { + self.inline_assistant_use_streaming_tools = Some(use_tools); + } pub fn set_commit_message_model(&mut self, provider: String, model: String) { self.commit_message_model = Some(LanguageModelSelection {