@@ -1,23 +1,26 @@
use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
-use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag};
+use feature_flags::{FeatureFlagAppExt as _, InlineAssistantUseToolFeatureFlag};
use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _,
channel::mpsc,
future::{LocalBoxFuture, Shared},
join,
+ stream::BoxStream,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
- LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
- LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
- report_assistant_event,
+ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+ LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
+ LanguageModelToolUse, Role, TokenUsage, report_assistant_event,
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
@@ -25,6 +28,7 @@ use prompt_store::PromptBuilder;
use rope::Rope;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
+use settings::Settings as _;
use smol::future::FutureExt;
use std::{
cmp,
@@ -46,6 +50,7 @@ pub struct FailureMessageInput {
/// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
///
/// The message may use markdown formatting if you wish.
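+    // `#[serde(default)]` keeps deserialization of partially-streamed tool input
+    // from failing before this field has arrived.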
+ #[serde(default)]
pub message: String,
}
@@ -56,9 +61,11 @@ pub struct RewriteSectionInput {
///
/// The description may use markdown formatting if you wish.
/// This is optional - if the edit is simple or obvious, you should leave it empty.
+ #[serde(default)]
pub description: String,
/// The text to replace the section with.
+ #[serde(default)]
pub replacement_text: String,
}
@@ -379,6 +386,12 @@ impl CodegenAlternative {
&self.last_equal_ranges
}
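+    /// Whether to drive this generation through streaming tool calls: the model
+    /// must support them, the feature flag must be enabled, and the agent
+    /// setting must opt in.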
+ fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
+ model.supports_streaming_tools()
+ && cx.has_flag::<InlineAssistantUseToolFeatureFlag>()
+ && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
+ }
+
pub fn start(
&mut self,
user_prompt: String,
@@ -398,11 +411,17 @@ impl CodegenAlternative {
let telemetry_id = model.telemetry_id();
let provider_id = model.provider_id();
- if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
+ if Self::use_streaming_tools(model.as_ref(), cx) {
let request = self.build_request(&model, user_prompt, context_task, cx)?;
- let tool_use =
- cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
- self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
+ let completion_events =
+ cx.spawn(async move |_, cx| model.stream_completion(request.await, cx).await);
+ self.generation = self.handle_completion(
+ telemetry_id,
+ provider_id.to_string(),
+ api_key,
+ completion_events,
+ cx,
+ );
} else {
let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
if user_prompt.trim().to_lowercase() == "delete" {
@@ -414,13 +433,14 @@ impl CodegenAlternative {
})
.boxed_local()
};
- self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
+ self.generation =
+ self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
}
Ok(())
}
- fn build_request_v2(
+ fn build_request_tools(
&self,
model: &Arc<dyn LanguageModel>,
user_prompt: String,
@@ -456,7 +476,7 @@ impl CodegenAlternative {
let system_prompt = self
.builder
- .generate_inline_transformation_prompt_v2(
+ .generate_inline_transformation_prompt_tools(
language_name,
buffer,
range.start.0..range.end.0,
@@ -466,6 +486,9 @@ impl CodegenAlternative {
let temperature = AgentSettings::temperature_for_model(model, cx);
let tool_input_format = model.tool_input_format();
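+        // Force a tool call when the provider supports it, so the model answers
+        // through one of the registered tools rather than with plain text.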
+ let tool_choice = model
+ .supports_tool_choice(LanguageModelToolChoice::Any)
+ .then_some(LanguageModelToolChoice::Any);
Ok(cx.spawn(async move |_cx| {
let mut messages = vec![LanguageModelRequestMessage {
@@ -508,7 +531,7 @@ impl CodegenAlternative {
intent: Some(CompletionIntent::InlineAssist),
mode: None,
tools,
- tool_choice: None,
+ tool_choice,
stop: Vec::new(),
temperature,
messages,
@@ -524,8 +547,8 @@ impl CodegenAlternative {
context_task: Shared<Task<Option<LoadedContext>>>,
cx: &mut App,
) -> Result<Task<LanguageModelRequest>> {
- if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
- return self.build_request_v2(model, user_prompt, context_task, cx);
+ if Self::use_streaming_tools(model.as_ref(), cx) {
+ return self.build_request_tools(model, user_prompt, context_task, cx);
}
let buffer = self.buffer.read(cx).snapshot(cx);
@@ -603,7 +626,7 @@ impl CodegenAlternative {
model_api_key: Option<String>,
stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
cx: &mut Context<Self>,
- ) {
+ ) -> Task<()> {
let start_time = Instant::now();
// Make a new snapshot and re-resolve anchor in case the document was modified.
@@ -659,7 +682,8 @@ impl CodegenAlternative {
let completion = Arc::new(Mutex::new(String::new()));
let completion_clone = completion.clone();
- self.generation = cx.spawn(async move |codegen, cx| {
+ cx.notify();
+ cx.spawn(async move |codegen, cx| {
let stream = stream.await;
let token_usage = stream
@@ -685,6 +709,7 @@ impl CodegenAlternative {
stream?.stream.map_err(|error| error.into()),
);
futures::pin_mut!(chunks);
+
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
@@ -876,8 +901,7 @@ impl CodegenAlternative {
cx.notify();
})
.ok();
- });
- cx.notify();
+ })
}
pub fn current_completion(&self) -> Option<String> {
@@ -1060,21 +1084,29 @@ impl CodegenAlternative {
})
}
- fn handle_tool_use(
+ fn handle_completion(
&mut self,
- _telemetry_id: String,
- _provider_id: String,
- _api_key: Option<String>,
- tool_use: impl 'static
- + Future<
- Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
+ telemetry_id: String,
+ provider_id: String,
+ api_key: Option<String>,
+ completion_stream: Task<
+ Result<
+ BoxStream<
+ 'static,
+ Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+ >,
+ LanguageModelCompletionError,
+ >,
>,
cx: &mut Context<Self>,
- ) {
+ ) -> Task<()> {
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
- self.generation = cx.spawn(async move |codegen, cx| {
+ cx.notify();
+        // The caller stores this task in `self.generation`, so stop requests are
+        // still respected while we're pre-processing completion events.
+ cx.spawn(async move |codegen, cx| {
let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
let _ = codegen.update(cx, |this, cx| {
this.status = status;
@@ -1083,76 +1115,176 @@ impl CodegenAlternative {
});
};
- let tool_use = tool_use.await;
-
- match tool_use {
- Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
- // Parse the input JSON into RewriteSectionInput
- match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
- Ok(input) => {
- // Store the description if non-empty
- let description = if !input.description.trim().is_empty() {
- Some(input.description.clone())
- } else {
- None
+ let mut completion_events = match completion_stream.await {
+ Ok(events) => events,
+ Err(err) => {
+ finish_with_status(CodegenStatus::Error(err.into()), cx);
+ return;
+ }
+ };
+
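+            // Streamed tool-use events carry the accumulated input JSON so far, so
+            // this helper parses each event and returns only the new text delta plus
+            // the optional user-facing explanation. The offset is in bytes (matching
+            // `String::len`), and slicing is safe because each boundary is the end of
+            // a previously-complete string.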
+            let bytes_read_so_far = Arc::new(Mutex::new(0usize));
+            let tool_to_text_and_message =
+                move |tool_use: LanguageModelToolUse| -> (Option<String>, Option<String>) {
+                    let mut bytes_read_so_far = bytes_read_so_far.lock();
+                    match tool_use.name.as_ref() {
+                        "rewrite_section" => {
+                            let Ok(mut input) =
+                                serde_json::from_value::<RewriteSectionInput>(tool_use.input)
+                            else {
+                                return (None, None);
                            };
+                            let value = input.replacement_text[*bytes_read_so_far..].to_string();
+                            *bytes_read_so_far = input.replacement_text.len();
+ (Some(value), Some(std::mem::take(&mut input.description)))
+ }
+ "failure_message" => {
+ let Ok(mut input) =
+ serde_json::from_value::<FailureMessageInput>(tool_use.input)
+ else {
+ return (None, None);
+ };
+ (None, Some(std::mem::take(&mut input.message)))
+ }
+ _ => (None, None),
+ }
+ };
- // Apply the replacement text to the buffer and compute diff
- let batch_diff_task = codegen
- .update(cx, |this, cx| {
- this.model_explanation = description.map(Into::into);
- let range = this.range.clone();
- this.apply_edits(
- std::iter::once((range, input.replacement_text)),
- cx,
- );
- this.reapply_batch_diff(cx)
- })
- .ok();
-
- // Wait for the diff computation to complete
- if let Some(diff_task) = batch_diff_task {
- diff_task.await;
- }
+ let mut message_id = None;
+ let mut first_text = None;
+ let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
+ let total_text = Arc::new(Mutex::new(String::new()));
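+
+            // Drain events until the first tool-produced text chunk arrives; message
+            // start, usage updates, and stray plain text are bookkeeping only here.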
- finish_with_status(CodegenStatus::Done, cx);
- return;
+ loop {
+                if let Some(event) = completion_events.next().await {
+                    match event {
+ Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
+ message_id = Some(id);
}
- Err(e) => {
- finish_with_status(CodegenStatus::Error(e.into()), cx);
- return;
+ Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
+ if matches!(
+ tool_use.name.as_ref(),
+ "rewrite_section" | "failure_message"
+ ) =>
+ {
+ let is_complete = tool_use.is_input_complete;
+ let (text, message) = tool_to_text_and_message(tool_use);
+ // Only update the model explanation if the tool use is complete.
+ // Otherwise the UI element bounces around as it's updated.
+ if is_complete {
+ let _ = codegen.update(cx, |this, _cx| {
+ this.model_explanation = message.map(Into::into);
+ });
+ }
+ first_text = text;
+ if first_text.is_some() {
+ break;
+ }
}
- }
- }
- Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
- // Handle failure message tool use
- match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
- Ok(input) => {
- let _ = codegen.update(cx, |this, _cx| {
- // Store the failure message as the tool description
- this.model_explanation = Some(input.message.into());
- });
- finish_with_status(CodegenStatus::Done, cx);
- return;
+ Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
+ *last_token_usage.lock() = token_usage;
+ }
+ Ok(LanguageModelCompletionEvent::Text(text)) => {
+ let mut lock = total_text.lock();
+ lock.push_str(&text);
+ }
+                        Ok(LanguageModelCompletionEvent::Stop(_)) => break,
+                        Ok(event) => {
+                            log::warn!("Unexpected completion event: {:?}", event);
+                            break;
}
Err(e) => {
finish_with_status(CodegenStatus::Error(e.into()), cx);
- return;
+                            return;
}
}
-                }
+                } else {
+                    // The stream ended without a stop event; bail out instead of
+                    // polling an exhausted stream forever.
+                    break;
+                }
- Ok(_tool_use) => {
- // Unexpected tool.
- finish_with_status(CodegenStatus::Done, cx);
- return;
- }
- Err(e) => {
- finish_with_status(CodegenStatus::Error(e.into()), cx);
- return;
- }
}
- });
- cx.notify();
+
+ let Some(first_text) = first_text else {
+ finish_with_status(CodegenStatus::Done, cx);
+ return;
+ };
+
+ let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded();
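+            // Forward streamed explanation updates to the entity from a side task,
+            // since the text stream below is consumed by `handle_stream`.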
+
+ cx.spawn({
+ let codegen = codegen.clone();
+ async move |cx| {
+ while let Some(message) = message_rx.next().await {
+ let _ = codegen.update(cx, |this, _cx| {
+ this.model_explanation = message;
+ });
+ }
+ }
+ })
+ .detach();
+
+ let move_last_token_usage = last_token_usage.clone();
+
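+            // Stitch the first chunk consumed above back onto the front of the
+            // remaining tool-use deltas so downstream sees one contiguous text stream.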
+ let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
+ completion_events.filter_map(move |e| {
+ let tool_to_text_and_message = tool_to_text_and_message.clone();
+ let last_token_usage = move_last_token_usage.clone();
+ let total_text = total_text.clone();
+ let mut message_tx = message_tx.clone();
+ async move {
+ match e {
+ Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
+ if matches!(
+ tool_use.name.as_ref(),
+ "rewrite_section" | "failure_message"
+ ) =>
+ {
+ let is_complete = tool_use.is_input_complete;
+ let (text, message) = tool_to_text_and_message(tool_use);
+ if is_complete {
+                                    // As above, only send the message once the tool
+                                    // use is complete so the UI element doesn't bounce around.
+ let _ = message_tx.send(message.map(Into::into)).await;
+ }
+ text.map(Ok)
+ }
+ Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
+ *last_token_usage.lock() = token_usage;
+ None
+ }
+ Ok(LanguageModelCompletionEvent::Text(text)) => {
+ let mut lock = total_text.lock();
+ lock.push_str(&text);
+ None
+ }
+ Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
+                            Err(err) => Some(Err(err)),
+                            Ok(event) => {
+                                log::warn!("Unexpected completion event: {:?}", event);
+                                None
+                            }
+ }
+ }
+ }),
+ ));
+
+ let language_model_text_stream = LanguageModelTextStream {
+                message_id,
+ stream: text_stream,
+ last_token_usage,
+ };
+
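+            // Reuse the plain-text pipeline: `handle_stream` performs the streaming
+            // diff against the selected range and applies the edits.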
+ let Some(task) = codegen
+ .update(cx, move |codegen, cx| {
+ codegen.handle_stream(
+ telemetry_id,
+ provider_id,
+ api_key,
+ async { Ok(language_model_text_stream) },
+ cx,
+ )
+ })
+ .ok()
+ else {
+ return;
+ };
+
+ task.await;
+ })
}
}
@@ -1679,7 +1811,7 @@ mod tests {
) -> mpsc::UnboundedSender<String> {
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
- codegen.handle_stream(
+ codegen.generation = codegen.handle_stream(
String::new(),
String::new(),
None,