From a23e381096c623951212608119fc497101e281f1 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 10 Sep 2024 15:25:57 -0400 Subject: [PATCH] assistant: Pass up tool results in LLM request messages (#17656) This PR makes it so we pass up the tool results in the `tool_results` field in the request message to the LLM. This required reworking how we track non-text content in the context editor. We also removed serialization of images in context history, as we were never deserializing it, and thus it was unneeded. Release Notes: - N/A --------- Co-authored-by: Antonio --- crates/assistant/src/assistant_panel.rs | 39 ++- crates/assistant/src/context.rs | 381 +++++++++++------------- crates/paths/src/paths.rs | 6 - 3 files changed, 215 insertions(+), 211 deletions(-) diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 82888b498a3c8019e51771075e0b53367ed0eb82..22843d41cd620736a9e62b11091613f16b2f369a 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -11,7 +11,7 @@ use crate::{ }, slash_command_picker, terminal_inline_assistant::TerminalInlineAssistant, - Assist, CacheStatus, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, + Assist, CacheStatus, ConfirmCommand, Content, Context, ContextEvent, ContextId, ContextStore, ContextStoreEvent, CycleMessageRole, DeployHistory, DeployPromptLibrary, InlineAssistId, InlineAssistant, InsertDraggedFiles, InsertIntoEditor, Message, MessageId, MessageMetadata, MessageStatus, ModelPickerDelegate, ModelSelector, NewContext, PendingSlashCommand, @@ -46,6 +46,7 @@ use indexed_docs::IndexedDocsStore; use language::{ language_settings::SoftWrap, Capability, LanguageRegistry, LspAdapterDelegate, Point, ToOffset, }; +use language_model::LanguageModelToolUse; use language_model::{ provider::cloud::PROVIDER_ID, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, Role, @@ -1995,6 +1996,20 @@ impl ContextEditor { let buffer_row = MultiBufferRow(start.to_point(&buffer).row); buffer_rows_to_fold.insert(buffer_row); + self.context.update(cx, |context, cx| { + context.insert_content( + Content::ToolUse { + range: tool_use.source_range.clone(), + tool_use: LanguageModelToolUse { + id: tool_use.id.to_string(), + name: tool_use.name.clone(), + input: tool_use.input.clone(), + }, + }, + cx, + ); + }); + Crease::new( start..end, placeholder, @@ -3538,7 +3553,7 @@ impl ContextEditor { let image_id = image.id(); context.insert_image(image, cx); for image_position in image_positions.iter() { - context.insert_image_anchor(image_id, image_position.text_anchor, cx); + context.insert_image_content(image_id, image_position.text_anchor, cx); } } }); @@ -3553,11 +3568,23 @@ impl ContextEditor { let new_blocks = self .context .read(cx) - .images(cx) - .filter_map(|image| { + .contents(cx) + .filter_map(|content| { + if let Content::Image { + anchor, + render_image, + .. + } = content + { + Some((anchor, render_image)) + } else { + None + } + }) + .filter_map(|(anchor, render_image)| { const MAX_HEIGHT_IN_LINES: u32 = 8; - let anchor = buffer.anchor_in_excerpt(excerpt_id, image.anchor).unwrap(); - let image = image.render_image.clone(); + let anchor = buffer.anchor_in_excerpt(excerpt_id, anchor).unwrap(); + let image = render_image.clone(); anchor.is_valid(&buffer).then(|| BlockProperties { position: anchor, height: MAX_HEIGHT_IN_LINES, diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 1bf846369b762750bfada251ac7f3df15c1d84f5..e43ec203e9675278e29360118c94ca32fe0cffed 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -17,7 +17,6 @@ use feature_flags::{FeatureFlag, FeatureFlagAppExt}; use fs::{Fs, RemoveOptions}; use futures::{ future::{self, Shared}, - stream::FuturesUnordered, FutureExt, StreamExt, }; use gpui::{ @@ -29,10 +28,11 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P use language_model::{ LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelRequestTool, MessageContent, Role, StopReason, + LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, + StopReason, }; use open_ai::Model as OpenAiModel; -use paths::{context_images_dir, contexts_dir}; +use paths::contexts_dir; use project::Project; use serde::{Deserialize, Serialize}; use smallvec::SmallVec; @@ -377,23 +377,8 @@ impl MessageMetadata { } } -#[derive(Clone, Debug)] -pub struct MessageImage { - image_id: u64, - image: Shared>>, -} - -impl PartialEq for MessageImage { - fn eq(&self, other: &Self) -> bool { - self.image_id == other.image_id - } -} - -impl Eq for MessageImage {} - #[derive(Clone, Debug)] pub struct Message { - pub image_offsets: SmallVec<[(usize, MessageImage); 1]>, pub offset_range: Range, pub index_range: Range, pub anchor_range: Range, @@ -403,62 +388,45 @@ pub struct Message { pub cache: Option, } -impl Message { - fn to_request_message(&self, buffer: &Buffer) -> Option { - let mut content = Vec::new(); - - let mut range_start = self.offset_range.start; - for (image_offset, message_image) in self.image_offsets.iter() { - if *image_offset != range_start { - if let Some(text) = Self::collect_text_content(buffer, range_start..*image_offset) { - content.push(text); - } - } - - if let Some(image) = message_image.image.clone().now_or_never().flatten() { - content.push(language_model::MessageContent::Image(image)); - } - - range_start = *image_offset; - } - - if range_start != self.offset_range.end { - if let Some(text) = - Self::collect_text_content(buffer, range_start..self.offset_range.end) - { - content.push(text); - } - } +#[derive(Debug, Clone)] +pub enum Content { + Image { + anchor: language::Anchor, + image_id: u64, + render_image: Arc, + image: Shared>>, + }, + ToolUse { + range: Range, + tool_use: LanguageModelToolUse, + }, + ToolResult { + range: Range, + tool_use_id: Arc, + }, +} - if content.is_empty() { - return None; +impl Content { + fn range(&self) -> Range { + match self { + Self::Image { anchor, .. } => *anchor..*anchor, + Self::ToolUse { range, .. } | Self::ToolResult { range, .. } => range.clone(), } - - Some(LanguageModelRequestMessage { - role: self.role, - content, - cache: self.cache.as_ref().map_or(false, |cache| cache.is_anchor), - }) } - fn collect_text_content(buffer: &Buffer, range: Range) -> Option { - let text: String = buffer.text_for_range(range.clone()).collect(); - if text.trim().is_empty() { - None + fn cmp(&self, other: &Self, buffer: &BufferSnapshot) -> Ordering { + let self_range = self.range(); + let other_range = other.range(); + if self_range.end.cmp(&other_range.start, buffer).is_lt() { + Ordering::Less + } else if self_range.start.cmp(&other_range.end, buffer).is_gt() { + Ordering::Greater } else { - Some(MessageContent::Text(text)) + Ordering::Equal } } } -#[derive(Clone, Debug)] -pub struct ImageAnchor { - pub anchor: language::Anchor, - pub image_id: u64, - pub render_image: Arc, - pub image: Shared>>, -} - struct PendingCompletion { id: usize, assistant_message_id: MessageId, @@ -501,7 +469,7 @@ pub struct Context { pending_tool_uses_by_id: HashMap, PendingToolUse>, message_anchors: Vec, images: HashMap, Shared>>)>, - image_anchors: Vec, + contents: Vec, messages_metadata: HashMap, summary: Option, pending_summary: Task>, @@ -595,7 +563,7 @@ impl Context { pending_ops: Vec::new(), operations: Vec::new(), message_anchors: Default::default(), - image_anchors: Default::default(), + contents: Default::default(), images: Default::default(), messages_metadata: Default::default(), pending_slash_commands: Vec::new(), @@ -659,11 +627,6 @@ impl Context { id: message.id, start: message.offset_range.start, metadata: self.messages_metadata[&message.id].clone(), - image_offsets: message - .image_offsets - .iter() - .map(|image_offset| (image_offset.0, image_offset.1.image_id)) - .collect(), }) .collect(), summary: self @@ -1957,6 +1920,14 @@ impl Context { output_range }); + this.insert_content( + Content::ToolResult { + range: anchor_range.clone(), + tool_use_id: tool_use_id.clone(), + }, + cx, + ); + cx.emit(ContextEvent::ToolFinished { tool_use_id, output_range: anchor_range, @@ -2038,6 +2009,7 @@ impl Context { let stream_completion = async { let request_start = Instant::now(); let mut events = stream.await?; + let mut stop_reason = StopReason::EndTurn; while let Some(event) = events.next().await { if response_latency.is_none() { @@ -2050,7 +2022,7 @@ impl Context { .message_anchors .iter() .position(|message| message.id == assistant_message_id)?; - let event_to_emit = this.buffer.update(cx, |buffer, cx| { + this.buffer.update(cx, |buffer, cx| { let message_old_end_offset = this.message_anchors[message_ix + 1..] .iter() .find(|message| message.start.is_valid(buffer)) @@ -2059,13 +2031,9 @@ impl Context { }); match event { - LanguageModelCompletionEvent::Stop(reason) => match reason { - StopReason::ToolUse => { - return Some(ContextEvent::UsePendingTools); - } - StopReason::EndTurn => {} - StopReason::MaxTokens => {} - }, + LanguageModelCompletionEvent::Stop(reason) => { + stop_reason = reason; + } LanguageModelCompletionEvent::Text(chunk) => { buffer.edit( [( @@ -2116,14 +2084,9 @@ impl Context { ); } } - - None }); cx.emit(ContextEvent::StreamedCompletion); - if let Some(event) = event_to_emit { - cx.emit(event); - } Some(()) })?; @@ -2136,13 +2099,14 @@ impl Context { this.update_cache_status_for_completion(cx); })?; - anyhow::Ok(()) + anyhow::Ok(stop_reason) }; let result = stream_completion.await; this.update(&mut cx, |this, cx| { let error_message = result + .as_ref() .err() .map(|error| error.to_string().trim().to_string()); @@ -2170,6 +2134,16 @@ impl Context { error_message, ); } + + if let Ok(stop_reason) = result { + match stop_reason { + StopReason::ToolUse => { + cx.emit(ContextEvent::UsePendingTools); + } + StopReason::EndTurn => {} + StopReason::MaxTokens => {} + } + } }) .ok(); } @@ -2186,18 +2160,94 @@ impl Context { pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest { let buffer = self.buffer.read(cx); - let request_messages = self - .messages(cx) - .filter(|message| message.status == MessageStatus::Done) - .filter_map(|message| message.to_request_message(&buffer)) - .collect(); - LanguageModelRequest { - messages: request_messages, + let mut contents = self.contents(cx).peekable(); + + fn collect_text_content(buffer: &Buffer, range: Range) -> Option { + let text: String = buffer.text_for_range(range.clone()).collect(); + if text.trim().is_empty() { + None + } else { + Some(text) + } + } + + let mut completion_request = LanguageModelRequest { + messages: Vec::new(), tools: Vec::new(), stop: Vec::new(), temperature: 1.0, + }; + for message in self.messages(cx) { + if message.status != MessageStatus::Done { + continue; + } + + let mut offset = message.offset_range.start; + let mut request_message = LanguageModelRequestMessage { + role: message.role, + content: Vec::new(), + cache: message + .cache + .as_ref() + .map_or(false, |cache| cache.is_anchor), + }; + + while let Some(content) = contents.peek() { + if content + .range() + .end + .cmp(&message.anchor_range.end, buffer) + .is_lt() + { + let content = contents.next().unwrap(); + let range = content.range().to_offset(buffer); + request_message.content.extend( + collect_text_content(buffer, offset..range.start).map(MessageContent::Text), + ); + + match content { + Content::Image { image, .. } => { + if let Some(image) = image.clone().now_or_never().flatten() { + request_message + .content + .push(language_model::MessageContent::Image(image)); + } + } + Content::ToolUse { tool_use, .. } => { + request_message + .content + .push(language_model::MessageContent::ToolUse(tool_use.clone())); + } + Content::ToolResult { tool_use_id, .. } => { + request_message.content.push( + language_model::MessageContent::ToolResult( + LanguageModelToolResult { + tool_use_id: tool_use_id.to_string(), + is_error: false, + content: collect_text_content(buffer, range.clone()) + .unwrap_or_default(), + }, + ), + ); + } + } + + offset = range.end; + } else { + break; + } + } + + request_message.content.extend( + collect_text_content(buffer, offset..message.offset_range.end) + .map(MessageContent::Text), + ); + + completion_request.messages.push(request_message); } + + completion_request } pub fn cancel_last_assist(&mut self, cx: &mut ModelContext) -> bool { @@ -2335,42 +2385,50 @@ impl Context { Some(()) } - pub fn insert_image_anchor( + pub fn insert_image_content( &mut self, image_id: u64, anchor: language::Anchor, cx: &mut ModelContext, - ) -> bool { - cx.emit(ContextEvent::MessagesEdited); - - let buffer = self.buffer.read(cx); - let insertion_ix = match self - .image_anchors - .binary_search_by(|existing_anchor| anchor.cmp(&existing_anchor.anchor, buffer)) - { - Ok(ix) => ix, - Err(ix) => ix, - }; - + ) { if let Some((render_image, image)) = self.images.get(&image_id) { - self.image_anchors.insert( - insertion_ix, - ImageAnchor { + self.insert_content( + Content::Image { anchor, image_id, image: image.clone(), render_image: render_image.clone(), }, + cx, ); - - true - } else { - false } } - pub fn images<'a>(&'a self, _cx: &'a AppContext) -> impl 'a + Iterator { - self.image_anchors.iter().cloned() + pub fn insert_content(&mut self, content: Content, cx: &mut ModelContext) { + let buffer = self.buffer.read(cx); + let insertion_ix = match self + .contents + .binary_search_by(|probe| probe.cmp(&content, buffer)) + { + Ok(ix) => { + self.contents.remove(ix); + ix + } + Err(ix) => ix, + }; + self.contents.insert(insertion_ix, content); + cx.emit(ContextEvent::MessagesEdited); + } + + pub fn contents<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator { + let buffer = self.buffer.read(cx); + self.contents + .iter() + .filter(|content| { + let range = content.range(); + range.start.is_valid(buffer) && range.end.is_valid(buffer) + }) + .cloned() } pub fn split_message( @@ -2533,22 +2591,14 @@ impl Context { return; } - let messages = self - .messages(cx) - .filter_map(|message| message.to_request_message(self.buffer.read(cx))) - .chain(Some(LanguageModelRequestMessage { - role: Role::User, - content: vec![ - "Summarize the context into a short title without punctuation.".into(), - ], - cache: false, - })); - let request = LanguageModelRequest { - messages: messages.collect(), - tools: Vec::new(), - stop: Vec::new(), - temperature: 1.0, - }; + let mut request = self.to_completion_request(cx); + request.messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![ + "Summarize the context into a short title without punctuation.".into(), + ], + cache: false, + }); self.pending_summary = cx.spawn(|this, mut cx| { async move { @@ -2648,10 +2698,8 @@ impl Context { cx: &'a AppContext, ) -> impl 'a + Iterator { let buffer = self.buffer.read(cx); - let messages = message_anchors.enumerate(); - let images = self.image_anchors.iter(); - Self::messages_from_iters(buffer, &self.messages_metadata, messages, images) + Self::messages_from_iters(buffer, &self.messages_metadata, message_anchors.enumerate()) } pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator { @@ -2662,10 +2710,8 @@ impl Context { buffer: &'a Buffer, metadata: &'a HashMap, messages: impl Iterator + 'a, - images: impl Iterator + 'a, ) -> impl 'a + Iterator { let mut messages = messages.peekable(); - let mut images = images.peekable(); iter::from_fn(move || { if let Some((start_ix, message_anchor)) = messages.next() { @@ -2686,22 +2732,6 @@ impl Context { let message_end_anchor = message_end.unwrap_or(language::Anchor::MAX); let message_end = message_end_anchor.to_offset(buffer); - let mut image_offsets = SmallVec::new(); - while let Some(image_anchor) = images.peek() { - if image_anchor.anchor.cmp(&message_end_anchor, buffer).is_lt() { - image_offsets.push(( - image_anchor.anchor.to_offset(buffer), - MessageImage { - image_id: image_anchor.image_id, - image: image_anchor.image.clone(), - }, - )); - images.next(); - } else { - break; - } - } - return Some(Message { index_range: start_ix..end_ix, offset_range: message_start..message_end, @@ -2710,7 +2740,6 @@ impl Context { role: metadata.role, status: metadata.status.clone(), cache: metadata.cache.clone(), - image_offsets, }); } None @@ -2748,9 +2777,6 @@ impl Context { })?; if let Some(summary) = summary { - this.read_with(&cx, |this, cx| this.serialize_images(fs.clone(), cx))? - .await; - let context = this.read_with(&cx, |this, cx| this.serialize(cx))?; let mut discriminant = 1; let mut new_path; @@ -2790,45 +2816,6 @@ impl Context { }); } - pub fn serialize_images(&self, fs: Arc, cx: &AppContext) -> Task<()> { - let mut images_to_save = self - .images - .iter() - .map(|(id, (_, llm_image))| { - let fs = fs.clone(); - let llm_image = llm_image.clone(); - let id = *id; - async move { - if let Some(llm_image) = llm_image.await { - let path: PathBuf = - context_images_dir().join(&format!("{}.png.base64", id)); - if fs - .metadata(path.as_path()) - .await - .log_err() - .flatten() - .is_none() - { - fs.atomic_write(path, llm_image.source.to_string()) - .await - .log_err(); - } - } - } - }) - .collect::>(); - cx.background_executor().spawn(async move { - if fs - .create_dir(context_images_dir().as_ref()) - .await - .log_err() - .is_some() - { - while let Some(_) = images_to_save.next().await {} - } - }) - } - pub(crate) fn custom_summary(&mut self, custom_summary: String, cx: &mut ModelContext) { let timestamp = self.next_timestamp(); let summary = self.summary.get_or_insert(ContextSummary::default()); @@ -2914,9 +2901,6 @@ pub struct SavedMessage { pub id: MessageId, pub start: usize, pub metadata: MessageMetadata, - #[serde(default)] - // This is defaulted for backwards compatibility with JSON files created before August 2024. We didn't always have this field. - pub image_offsets: Vec<(usize, u64)>, } #[derive(Serialize, Deserialize)] @@ -3102,7 +3086,6 @@ impl SavedContextV0_3_0 { timestamp, cache: None, }, - image_offsets: Vec::new(), }) }) .collect(), diff --git a/crates/paths/src/paths.rs b/crates/paths/src/paths.rs index 4d6a3b6d92eb544afd01a4a57d21ac6ca7301799..b80bef5f2d623e54ff5c4626cf0ac6e99f5f8aee 100644 --- a/crates/paths/src/paths.rs +++ b/crates/paths/src/paths.rs @@ -170,12 +170,6 @@ pub fn contexts_dir() -> &'static PathBuf { }) } -/// Returns the path within the contexts directory where images from contexts are stored. -pub fn context_images_dir() -> &'static PathBuf { - static CONTEXT_IMAGES_DIR: OnceLock = OnceLock::new(); - CONTEXT_IMAGES_DIR.get_or_init(|| contexts_dir().join("images")) -} - /// Returns the path to the contexts directory. /// /// This is where the prompts for use with the Assistant are stored.