assistant: Pass up tool results in LLM request messages (#17656)

Created by Marshall Bowers and Antonio

This PR makes it so we pass up the tool results in the `tool_results`
field in the request message to the LLM.
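
To illustrate the shape of the change, here is a minimal, runnable sketch of a request message carrying a tool result. The types are local stand-ins modeled on the `LanguageModelToolResult`, `MessageContent::ToolResult`, and `LanguageModelRequestMessage` shapes visible in the diff below; the exact field sets and the string-based role are simplifying assumptions, not the crate's real API.

```rust
// Stand-in types mirroring the `language_model` types referenced in the diff.
// Field sets are assumptions for illustration only.
#[derive(Debug, Clone)]
struct LanguageModelToolResult {
    tool_use_id: String,
    is_error: bool,
    content: String,
}

#[derive(Debug, Clone)]
enum MessageContent {
    Text(String),
    ToolResult(LanguageModelToolResult),
}

#[derive(Debug, Clone)]
struct LanguageModelRequestMessage {
    role: &'static str, // the real type uses a `Role` enum
    content: Vec<MessageContent>,
    cache: bool,
}

fn main() {
    // When a tool finishes, its output is pushed into the message content
    // alongside the surrounding text, keyed by the originating tool-use id.
    let mut message = LanguageModelRequestMessage {
        role: "user",
        content: vec![MessageContent::Text("Here is the tool output:".into())],
        cache: false,
    };
    message
        .content
        .push(MessageContent::ToolResult(LanguageModelToolResult {
            tool_use_id: "toolu_123".into(),
            is_error: false,
            content: "42 files matched".into(),
        }));
    println!("{message:?}");
}
```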

This required reworking how we track non-text content in the context
editor.
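
As a rough sketch of that rework (not the editor's actual implementation): non-text content is kept in one ordered list of `Content` entries keyed by buffer position, and inserting at an overlapping position replaces the existing entry, mirroring `insert_content` in the diff below. Plain `usize` offsets stand in for `language::Anchor` here.

```rust
use std::cmp::Ordering;
use std::ops::Range;

// Simplified stand-in for the `Content` enum added in this PR; the real enum
// also carries an `Image` variant and anchor-based ranges.
#[derive(Debug, Clone)]
enum Content {
    ToolUse { range: Range<usize>, tool_use_id: String },
    ToolResult { range: Range<usize>, tool_use_id: String },
}

impl Content {
    fn range(&self) -> Range<usize> {
        match self {
            Self::ToolUse { range, .. } | Self::ToolResult { range, .. } => range.clone(),
        }
    }

    // Order by position; overlapping ranges compare equal, so inserting at the
    // same spot replaces the existing entry.
    fn cmp_by_position(&self, other: &Self) -> Ordering {
        let (a, b) = (self.range(), other.range());
        if a.end < b.start {
            Ordering::Less
        } else if a.start > b.end {
            Ordering::Greater
        } else {
            Ordering::Equal
        }
    }
}

fn insert_content(contents: &mut Vec<Content>, content: Content) {
    match contents.binary_search_by(|probe| probe.cmp_by_position(&content)) {
        Ok(ix) => contents[ix] = content, // replace the overlapping entry
        Err(ix) => contents.insert(ix, content),
    }
}

fn main() {
    let mut contents = Vec::new();
    insert_content(
        &mut contents,
        Content::ToolUse { range: 10..20, tool_use_id: "toolu_1".into() },
    );
    insert_content(
        &mut contents,
        Content::ToolResult { range: 30..42, tool_use_id: "toolu_1".into() },
    );
    println!("{contents:#?}");
}
```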

We also removed serialization of images in context history, as we never
deserialized them, so it was unneeded.

Release Notes:

- N/A

---------

Co-authored-by: Antonio <antonio@zed.dev>

Change summary

crates/assistant/src/assistant_panel.rs |  39 ++
crates/assistant/src/context.rs         | 381 ++++++++++++--------------
crates/paths/src/paths.rs               |   6 ------
3 files changed, 215 insertions(+), 211 deletions(-)

Detailed changes

crates/assistant/src/assistant_panel.rs

@@ -11,7 +11,7 @@ use crate::{
     },
     slash_command_picker,
     terminal_inline_assistant::TerminalInlineAssistant,
-    Assist, CacheStatus, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore,
+    Assist, CacheStatus, ConfirmCommand, Content, Context, ContextEvent, ContextId, ContextStore,
     ContextStoreEvent, CycleMessageRole, DeployHistory, DeployPromptLibrary, InlineAssistId,
     InlineAssistant, InsertDraggedFiles, InsertIntoEditor, Message, MessageId, MessageMetadata,
     MessageStatus, ModelPickerDelegate, ModelSelector, NewContext, PendingSlashCommand,
@@ -46,6 +46,7 @@ use indexed_docs::IndexedDocsStore;
 use language::{
     language_settings::SoftWrap, Capability, LanguageRegistry, LspAdapterDelegate, Point, ToOffset,
 };
+use language_model::LanguageModelToolUse;
 use language_model::{
     provider::cloud::PROVIDER_ID, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelRegistry, Role,
@@ -1995,6 +1996,20 @@ impl ContextEditor {
                             let buffer_row = MultiBufferRow(start.to_point(&buffer).row);
                             buffer_rows_to_fold.insert(buffer_row);
 
+                            self.context.update(cx, |context, cx| {
+                                context.insert_content(
+                                    Content::ToolUse {
+                                        range: tool_use.source_range.clone(),
+                                        tool_use: LanguageModelToolUse {
+                                            id: tool_use.id.to_string(),
+                                            name: tool_use.name.clone(),
+                                            input: tool_use.input.clone(),
+                                        },
+                                    },
+                                    cx,
+                                );
+                            });
+
                             Crease::new(
                                 start..end,
                                 placeholder,
@@ -3538,7 +3553,7 @@ impl ContextEditor {
                     let image_id = image.id();
                     context.insert_image(image, cx);
                     for image_position in image_positions.iter() {
-                        context.insert_image_anchor(image_id, image_position.text_anchor, cx);
+                        context.insert_image_content(image_id, image_position.text_anchor, cx);
                     }
                 }
             });
@@ -3553,11 +3568,23 @@ impl ContextEditor {
             let new_blocks = self
                 .context
                 .read(cx)
-                .images(cx)
-                .filter_map(|image| {
+                .contents(cx)
+                .filter_map(|content| {
+                    if let Content::Image {
+                        anchor,
+                        render_image,
+                        ..
+                    } = content
+                    {
+                        Some((anchor, render_image))
+                    } else {
+                        None
+                    }
+                })
+                .filter_map(|(anchor, render_image)| {
                     const MAX_HEIGHT_IN_LINES: u32 = 8;
-                    let anchor = buffer.anchor_in_excerpt(excerpt_id, image.anchor).unwrap();
-                    let image = image.render_image.clone();
+                    let anchor = buffer.anchor_in_excerpt(excerpt_id, anchor).unwrap();
+                    let image = render_image.clone();
                     anchor.is_valid(&buffer).then(|| BlockProperties {
                         position: anchor,
                         height: MAX_HEIGHT_IN_LINES,

crates/assistant/src/context.rs

@@ -17,7 +17,6 @@ use feature_flags::{FeatureFlag, FeatureFlagAppExt};
 use fs::{Fs, RemoveOptions};
 use futures::{
     future::{self, Shared},
-    stream::FuturesUnordered,
     FutureExt, StreamExt,
 };
 use gpui::{
@@ -29,10 +28,11 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
 use language_model::{
     LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
     LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelRequestTool, MessageContent, Role, StopReason,
+    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
+    StopReason,
 };
 use open_ai::Model as OpenAiModel;
-use paths::{context_images_dir, contexts_dir};
+use paths::contexts_dir;
 use project::Project;
 use serde::{Deserialize, Serialize};
 use smallvec::SmallVec;
@@ -377,23 +377,8 @@ impl MessageMetadata {
     }
 }
 
-#[derive(Clone, Debug)]
-pub struct MessageImage {
-    image_id: u64,
-    image: Shared<Task<Option<LanguageModelImage>>>,
-}
-
-impl PartialEq for MessageImage {
-    fn eq(&self, other: &Self) -> bool {
-        self.image_id == other.image_id
-    }
-}
-
-impl Eq for MessageImage {}
-
 #[derive(Clone, Debug)]
 pub struct Message {
-    pub image_offsets: SmallVec<[(usize, MessageImage); 1]>,
     pub offset_range: Range<usize>,
     pub index_range: Range<usize>,
     pub anchor_range: Range<language::Anchor>,
@@ -403,62 +388,45 @@ pub struct Message {
     pub cache: Option<MessageCacheMetadata>,
 }
 
-impl Message {
-    fn to_request_message(&self, buffer: &Buffer) -> Option<LanguageModelRequestMessage> {
-        let mut content = Vec::new();
-
-        let mut range_start = self.offset_range.start;
-        for (image_offset, message_image) in self.image_offsets.iter() {
-            if *image_offset != range_start {
-                if let Some(text) = Self::collect_text_content(buffer, range_start..*image_offset) {
-                    content.push(text);
-                }
-            }
-
-            if let Some(image) = message_image.image.clone().now_or_never().flatten() {
-                content.push(language_model::MessageContent::Image(image));
-            }
-
-            range_start = *image_offset;
-        }
-
-        if range_start != self.offset_range.end {
-            if let Some(text) =
-                Self::collect_text_content(buffer, range_start..self.offset_range.end)
-            {
-                content.push(text);
-            }
-        }
+#[derive(Debug, Clone)]
+pub enum Content {
+    Image {
+        anchor: language::Anchor,
+        image_id: u64,
+        render_image: Arc<RenderImage>,
+        image: Shared<Task<Option<LanguageModelImage>>>,
+    },
+    ToolUse {
+        range: Range<language::Anchor>,
+        tool_use: LanguageModelToolUse,
+    },
+    ToolResult {
+        range: Range<language::Anchor>,
+        tool_use_id: Arc<str>,
+    },
+}
 
-        if content.is_empty() {
-            return None;
+impl Content {
+    fn range(&self) -> Range<language::Anchor> {
+        match self {
+            Self::Image { anchor, .. } => *anchor..*anchor,
+            Self::ToolUse { range, .. } | Self::ToolResult { range, .. } => range.clone(),
         }
-
-        Some(LanguageModelRequestMessage {
-            role: self.role,
-            content,
-            cache: self.cache.as_ref().map_or(false, |cache| cache.is_anchor),
-        })
     }
 
-    fn collect_text_content(buffer: &Buffer, range: Range<usize>) -> Option<MessageContent> {
-        let text: String = buffer.text_for_range(range.clone()).collect();
-        if text.trim().is_empty() {
-            None
+    fn cmp(&self, other: &Self, buffer: &BufferSnapshot) -> Ordering {
+        let self_range = self.range();
+        let other_range = other.range();
+        if self_range.end.cmp(&other_range.start, buffer).is_lt() {
+            Ordering::Less
+        } else if self_range.start.cmp(&other_range.end, buffer).is_gt() {
+            Ordering::Greater
         } else {
-            Some(MessageContent::Text(text))
+            Ordering::Equal
         }
     }
 }
 
-#[derive(Clone, Debug)]
-pub struct ImageAnchor {
-    pub anchor: language::Anchor,
-    pub image_id: u64,
-    pub render_image: Arc<RenderImage>,
-    pub image: Shared<Task<Option<LanguageModelImage>>>,
-}
-
 struct PendingCompletion {
     id: usize,
     assistant_message_id: MessageId,
@@ -501,7 +469,7 @@ pub struct Context {
     pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
     message_anchors: Vec<MessageAnchor>,
     images: HashMap<u64, (Arc<RenderImage>, Shared<Task<Option<LanguageModelImage>>>)>,
-    image_anchors: Vec<ImageAnchor>,
+    contents: Vec<Content>,
     messages_metadata: HashMap<MessageId, MessageMetadata>,
     summary: Option<ContextSummary>,
     pending_summary: Task<Option<()>>,
@@ -595,7 +563,7 @@ impl Context {
             pending_ops: Vec::new(),
             operations: Vec::new(),
             message_anchors: Default::default(),
-            image_anchors: Default::default(),
+            contents: Default::default(),
             images: Default::default(),
             messages_metadata: Default::default(),
             pending_slash_commands: Vec::new(),
@@ -659,11 +627,6 @@ impl Context {
                     id: message.id,
                     start: message.offset_range.start,
                     metadata: self.messages_metadata[&message.id].clone(),
-                    image_offsets: message
-                        .image_offsets
-                        .iter()
-                        .map(|image_offset| (image_offset.0, image_offset.1.image_id))
-                        .collect(),
                 })
                 .collect(),
             summary: self
@@ -1957,6 +1920,14 @@ impl Context {
                             output_range
                         });
 
+                        this.insert_content(
+                            Content::ToolResult {
+                                range: anchor_range.clone(),
+                                tool_use_id: tool_use_id.clone(),
+                            },
+                            cx,
+                        );
+
                         cx.emit(ContextEvent::ToolFinished {
                             tool_use_id,
                             output_range: anchor_range,
@@ -2038,6 +2009,7 @@ impl Context {
                 let stream_completion = async {
                     let request_start = Instant::now();
                     let mut events = stream.await?;
+                    let mut stop_reason = StopReason::EndTurn;
 
                     while let Some(event) = events.next().await {
                         if response_latency.is_none() {
@@ -2050,7 +2022,7 @@ impl Context {
                                 .message_anchors
                                 .iter()
                                 .position(|message| message.id == assistant_message_id)?;
-                            let event_to_emit = this.buffer.update(cx, |buffer, cx| {
+                            this.buffer.update(cx, |buffer, cx| {
                                 let message_old_end_offset = this.message_anchors[message_ix + 1..]
                                     .iter()
                                     .find(|message| message.start.is_valid(buffer))
@@ -2059,13 +2031,9 @@ impl Context {
                                     });
 
                                 match event {
-                                    LanguageModelCompletionEvent::Stop(reason) => match reason {
-                                        StopReason::ToolUse => {
-                                            return Some(ContextEvent::UsePendingTools);
-                                        }
-                                        StopReason::EndTurn => {}
-                                        StopReason::MaxTokens => {}
-                                    },
+                                    LanguageModelCompletionEvent::Stop(reason) => {
+                                        stop_reason = reason;
+                                    }
                                     LanguageModelCompletionEvent::Text(chunk) => {
                                         buffer.edit(
                                             [(
@@ -2116,14 +2084,9 @@ impl Context {
                                         );
                                     }
                                 }
-
-                                None
                             });
 
                             cx.emit(ContextEvent::StreamedCompletion);
-                            if let Some(event) = event_to_emit {
-                                cx.emit(event);
-                            }
 
                             Some(())
                         })?;
@@ -2136,13 +2099,14 @@ impl Context {
                         this.update_cache_status_for_completion(cx);
                     })?;
 
-                    anyhow::Ok(())
+                    anyhow::Ok(stop_reason)
                 };
 
                 let result = stream_completion.await;
 
                 this.update(&mut cx, |this, cx| {
                     let error_message = result
+                        .as_ref()
                         .err()
                         .map(|error| error.to_string().trim().to_string());
 
@@ -2170,6 +2134,16 @@ impl Context {
                             error_message,
                         );
                     }
+
+                    if let Ok(stop_reason) = result {
+                        match stop_reason {
+                            StopReason::ToolUse => {
+                                cx.emit(ContextEvent::UsePendingTools);
+                            }
+                            StopReason::EndTurn => {}
+                            StopReason::MaxTokens => {}
+                        }
+                    }
                 })
                 .ok();
             }
@@ -2186,18 +2160,94 @@ impl Context {
 
     pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest {
         let buffer = self.buffer.read(cx);
-        let request_messages = self
-            .messages(cx)
-            .filter(|message| message.status == MessageStatus::Done)
-            .filter_map(|message| message.to_request_message(&buffer))
-            .collect();
 
-        LanguageModelRequest {
-            messages: request_messages,
+        let mut contents = self.contents(cx).peekable();
+
+        fn collect_text_content(buffer: &Buffer, range: Range<usize>) -> Option<String> {
+            let text: String = buffer.text_for_range(range.clone()).collect();
+            if text.trim().is_empty() {
+                None
+            } else {
+                Some(text)
+            }
+        }
+
+        let mut completion_request = LanguageModelRequest {
+            messages: Vec::new(),
             tools: Vec::new(),
             stop: Vec::new(),
             temperature: 1.0,
+        };
+        for message in self.messages(cx) {
+            if message.status != MessageStatus::Done {
+                continue;
+            }
+
+            let mut offset = message.offset_range.start;
+            let mut request_message = LanguageModelRequestMessage {
+                role: message.role,
+                content: Vec::new(),
+                cache: message
+                    .cache
+                    .as_ref()
+                    .map_or(false, |cache| cache.is_anchor),
+            };
+
+            while let Some(content) = contents.peek() {
+                if content
+                    .range()
+                    .end
+                    .cmp(&message.anchor_range.end, buffer)
+                    .is_lt()
+                {
+                    let content = contents.next().unwrap();
+                    let range = content.range().to_offset(buffer);
+                    request_message.content.extend(
+                        collect_text_content(buffer, offset..range.start).map(MessageContent::Text),
+                    );
+
+                    match content {
+                        Content::Image { image, .. } => {
+                            if let Some(image) = image.clone().now_or_never().flatten() {
+                                request_message
+                                    .content
+                                    .push(language_model::MessageContent::Image(image));
+                            }
+                        }
+                        Content::ToolUse { tool_use, .. } => {
+                            request_message
+                                .content
+                                .push(language_model::MessageContent::ToolUse(tool_use.clone()));
+                        }
+                        Content::ToolResult { tool_use_id, .. } => {
+                            request_message.content.push(
+                                language_model::MessageContent::ToolResult(
+                                    LanguageModelToolResult {
+                                        tool_use_id: tool_use_id.to_string(),
+                                        is_error: false,
+                                        content: collect_text_content(buffer, range.clone())
+                                            .unwrap_or_default(),
+                                    },
+                                ),
+                            );
+                        }
+                    }
+
+                    offset = range.end;
+                } else {
+                    break;
+                }
+            }
+
+            request_message.content.extend(
+                collect_text_content(buffer, offset..message.offset_range.end)
+                    .map(MessageContent::Text),
+            );
+
+            completion_request.messages.push(request_message);
         }
+
+        completion_request
     }
 
     pub fn cancel_last_assist(&mut self, cx: &mut ModelContext<Self>) -> bool {
@@ -2335,42 +2385,50 @@ impl Context {
         Some(())
     }
 
-    pub fn insert_image_anchor(
+    pub fn insert_image_content(
         &mut self,
         image_id: u64,
         anchor: language::Anchor,
         cx: &mut ModelContext<Self>,
-    ) -> bool {
-        cx.emit(ContextEvent::MessagesEdited);
-
-        let buffer = self.buffer.read(cx);
-        let insertion_ix = match self
-            .image_anchors
-            .binary_search_by(|existing_anchor| anchor.cmp(&existing_anchor.anchor, buffer))
-        {
-            Ok(ix) => ix,
-            Err(ix) => ix,
-        };
-
+    ) {
         if let Some((render_image, image)) = self.images.get(&image_id) {
-            self.image_anchors.insert(
-                insertion_ix,
-                ImageAnchor {
+            self.insert_content(
+                Content::Image {
                     anchor,
                     image_id,
                     image: image.clone(),
                     render_image: render_image.clone(),
                 },
+                cx,
             );
-
-            true
-        } else {
-            false
         }
     }
 
-    pub fn images<'a>(&'a self, _cx: &'a AppContext) -> impl 'a + Iterator<Item = ImageAnchor> {
-        self.image_anchors.iter().cloned()
+    pub fn insert_content(&mut self, content: Content, cx: &mut ModelContext<Self>) {
+        let buffer = self.buffer.read(cx);
+        let insertion_ix = match self
+            .contents
+            .binary_search_by(|probe| probe.cmp(&content, buffer))
+        {
+            Ok(ix) => {
+                self.contents.remove(ix);
+                ix
+            }
+            Err(ix) => ix,
+        };
+        self.contents.insert(insertion_ix, content);
+        cx.emit(ContextEvent::MessagesEdited);
+    }
+
+    pub fn contents<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Content> {
+        let buffer = self.buffer.read(cx);
+        self.contents
+            .iter()
+            .filter(|content| {
+                let range = content.range();
+                range.start.is_valid(buffer) && range.end.is_valid(buffer)
+            })
+            .cloned()
     }
 
     pub fn split_message(
@@ -2533,22 +2591,14 @@ impl Context {
                 return;
             }
 
-            let messages = self
-                .messages(cx)
-                .filter_map(|message| message.to_request_message(self.buffer.read(cx)))
-                .chain(Some(LanguageModelRequestMessage {
-                    role: Role::User,
-                    content: vec![
-                        "Summarize the context into a short title without punctuation.".into(),
-                    ],
-                    cache: false,
-                }));
-            let request = LanguageModelRequest {
-                messages: messages.collect(),
-                tools: Vec::new(),
-                stop: Vec::new(),
-                temperature: 1.0,
-            };
+            let mut request = self.to_completion_request(cx);
+            request.messages.push(LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![
+                    "Summarize the context into a short title without punctuation.".into(),
+                ],
+                cache: false,
+            });
 
             self.pending_summary = cx.spawn(|this, mut cx| {
                 async move {
@@ -2648,10 +2698,8 @@ impl Context {
         cx: &'a AppContext,
     ) -> impl 'a + Iterator<Item = Message> {
         let buffer = self.buffer.read(cx);
-        let messages = message_anchors.enumerate();
-        let images = self.image_anchors.iter();
 
-        Self::messages_from_iters(buffer, &self.messages_metadata, messages, images)
+        Self::messages_from_iters(buffer, &self.messages_metadata, message_anchors.enumerate())
     }
 
     pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
@@ -2662,10 +2710,8 @@ impl Context {
         buffer: &'a Buffer,
         metadata: &'a HashMap<MessageId, MessageMetadata>,
         messages: impl Iterator<Item = (usize, &'a MessageAnchor)> + 'a,
-        images: impl Iterator<Item = &'a ImageAnchor> + 'a,
     ) -> impl 'a + Iterator<Item = Message> {
         let mut messages = messages.peekable();
-        let mut images = images.peekable();
 
         iter::from_fn(move || {
             if let Some((start_ix, message_anchor)) = messages.next() {
@@ -2686,22 +2732,6 @@ impl Context {
                 let message_end_anchor = message_end.unwrap_or(language::Anchor::MAX);
                 let message_end = message_end_anchor.to_offset(buffer);
 
-                let mut image_offsets = SmallVec::new();
-                while let Some(image_anchor) = images.peek() {
-                    if image_anchor.anchor.cmp(&message_end_anchor, buffer).is_lt() {
-                        image_offsets.push((
-                            image_anchor.anchor.to_offset(buffer),
-                            MessageImage {
-                                image_id: image_anchor.image_id,
-                                image: image_anchor.image.clone(),
-                            },
-                        ));
-                        images.next();
-                    } else {
-                        break;
-                    }
-                }
-
                 return Some(Message {
                     index_range: start_ix..end_ix,
                     offset_range: message_start..message_end,
@@ -2710,7 +2740,6 @@ impl Context {
                     role: metadata.role,
                     status: metadata.status.clone(),
                     cache: metadata.cache.clone(),
-                    image_offsets,
                 });
             }
             None
@@ -2748,9 +2777,6 @@ impl Context {
             })?;
 
             if let Some(summary) = summary {
-                this.read_with(&cx, |this, cx| this.serialize_images(fs.clone(), cx))?
-                    .await;
-
                 let context = this.read_with(&cx, |this, cx| this.serialize(cx))?;
                 let mut discriminant = 1;
                 let mut new_path;
@@ -2790,45 +2816,6 @@ impl Context {
         });
     }
 
-    pub fn serialize_images(&self, fs: Arc<dyn Fs>, cx: &AppContext) -> Task<()> {
-        let mut images_to_save = self
-            .images
-            .iter()
-            .map(|(id, (_, llm_image))| {
-                let fs = fs.clone();
-                let llm_image = llm_image.clone();
-                let id = *id;
-                async move {
-                    if let Some(llm_image) = llm_image.await {
-                        let path: PathBuf =
-                            context_images_dir().join(&format!("{}.png.base64", id));
-                        if fs
-                            .metadata(path.as_path())
-                            .await
-                            .log_err()
-                            .flatten()
-                            .is_none()
-                        {
-                            fs.atomic_write(path, llm_image.source.to_string())
-                                .await
-                                .log_err();
-                        }
-                    }
-                }
-            })
-            .collect::<FuturesUnordered<_>>();
-        cx.background_executor().spawn(async move {
-            if fs
-                .create_dir(context_images_dir().as_ref())
-                .await
-                .log_err()
-                .is_some()
-            {
-                while let Some(_) = images_to_save.next().await {}
-            }
-        })
-    }
-
     pub(crate) fn custom_summary(&mut self, custom_summary: String, cx: &mut ModelContext<Self>) {
         let timestamp = self.next_timestamp();
         let summary = self.summary.get_or_insert(ContextSummary::default());
@@ -2914,9 +2901,6 @@ pub struct SavedMessage {
     pub id: MessageId,
     pub start: usize,
     pub metadata: MessageMetadata,
-    #[serde(default)]
-    // This is defaulted for backwards compatibility with JSON files created before August 2024. We didn't always have this field.
-    pub image_offsets: Vec<(usize, u64)>,
 }
 
 #[derive(Serialize, Deserialize)]
@@ -3102,7 +3086,6 @@ impl SavedContextV0_3_0 {
                             timestamp,
                             cache: None,
                         },
-                        image_offsets: Vec::new(),
                     })
                 })
                 .collect(),

crates/paths/src/paths.rs

@@ -170,12 +170,6 @@ pub fn contexts_dir() -> &'static PathBuf {
     })
 }
 
-/// Returns the path within the contexts directory where images from contexts are stored.
-pub fn context_images_dir() -> &'static PathBuf {
-    static CONTEXT_IMAGES_DIR: OnceLock<PathBuf> = OnceLock::new();
-    CONTEXT_IMAGES_DIR.get_or_init(|| contexts_dir().join("images"))
-}
-
 /// Returns the path to the contexts directory.
 ///
 /// This is where the prompts for use with the Assistant are stored.