language_model: Add tool results to message content (#17363)

Marshall Bowers created

This PR updates the message content for an LLM request to allow it
contain tool results.

Release Notes:

- N/A

Change summary

crates/anthropic/src/anthropic.rs               |  8 +
crates/language_model/src/provider/anthropic.rs |  7 
crates/language_model/src/request.rs            | 96 +++++++++++-------
3 files changed, 73 insertions(+), 38 deletions(-)

Detailed changes

crates/anthropic/src/anthropic.rs 🔗

@@ -423,6 +423,14 @@ pub enum RequestContent {
         #[serde(skip_serializing_if = "Option::is_none")]
         cache_control: Option<CacheControl>,
     },
+    #[serde(rename = "tool_result")]
+    ToolResult {
+        tool_use_id: String,
+        is_error: bool,
+        content: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
+    },
 }
 
 #[derive(Debug, Serialize, Deserialize)]

crates/language_model/src/provider/anthropic.rs 🔗

@@ -261,12 +261,15 @@ pub fn count_anthropic_tokens(
 
                 for content in message.content {
                     match content {
-                        MessageContent::Text(string) => {
-                            string_contents.push_str(&string);
+                        MessageContent::Text(text) => {
+                            string_contents.push_str(&text);
                         }
                         MessageContent::Image(image) => {
                             tokens_from_images += image.estimate_tokens();
                         }
+                        MessageContent::ToolResult(tool_result) => {
+                            string_contents.push_str(&tool_result.content);
+                        }
                     }
                 }
 

crates/language_model/src/request.rs 🔗

@@ -8,14 +8,24 @@ use serde::{Deserialize, Serialize};
 use ui::{px, SharedString};
 use util::ResultExt;
 
-#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
+#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 pub struct LanguageModelImage {
-    // A base64 encoded PNG image
+    /// A base64-encoded PNG image.
     pub source: SharedString,
     size: Size<DevicePixels>,
 }
 
-const ANTHROPIC_SIZE_LIMT: f32 = 1568.0; // Anthropic wants uploaded images to be smaller than this in both dimensions
+impl std::fmt::Debug for LanguageModelImage {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("LanguageModelImage")
+            .field("source", &format!("<{} bytes>", self.source.len()))
+            .field("size", &self.size)
+            .finish()
+    }
+}
+
+/// Anthropic wants uploaded images to be smaller than this in both dimensions.
+const ANTHROPIC_SIZE_LIMT: f32 = 1568.;
 
 impl LanguageModelImage {
     pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
@@ -67,7 +77,7 @@ impl LanguageModelImage {
                 }
             }
 
-            // SAFETY: The base64 encoder should not produce non-UTF8
+            // SAFETY: The base64 encoder should not produce non-UTF8.
             let source = unsafe { String::from_utf8_unchecked(base64_image) };
 
             Some(LanguageModelImage {
@@ -77,7 +87,7 @@ impl LanguageModelImage {
         })
     }
 
-    /// Resolves image into an LLM-ready format (base64)
+    /// Resolves image into an LLM-ready format (base64).
     pub fn from_render_image(data: &RenderImage) -> Option<Self> {
         let image_size = data.size(0);
 
@@ -130,7 +140,7 @@ impl LanguageModelImage {
             base64_encoder.write_all(png.as_slice()).log_err()?;
         }
 
-        // SAFETY: The base64 encoder should not produce non-UTF8
+        // SAFETY: The base64 encoder should not produce non-UTF8.
         let source = unsafe { String::from_utf8_unchecked(base64_image) };
 
         Some(LanguageModelImage {
@@ -144,35 +154,32 @@ impl LanguageModelImage {
         let height = self.size.height.0.unsigned_abs() as usize;
 
         // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
-        // Note that are a lot of conditions on anthropic's API, and OpenAI doesn't use this,
-        // so this method is more of a rough guess
+        // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
+        // so this method is more of a rough guess.
         (width * height) / 750
     }
 }
 
-#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
+pub struct LanguageModelToolResult {
+    pub tool_use_id: String,
+    pub is_error: bool,
+    pub content: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
 pub enum MessageContent {
     Text(String),
     Image(LanguageModelImage),
-}
-
-impl std::fmt::Debug for MessageContent {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            MessageContent::Text(t) => f.debug_struct("MessageContent").field("text", t).finish(),
-            MessageContent::Image(i) => f
-                .debug_struct("MessageContent")
-                .field("image", &i.source.len())
-                .finish(),
-        }
-    }
+    ToolResult(LanguageModelToolResult),
 }
 
 impl MessageContent {
     pub fn as_string(&self) -> &str {
         match self {
-            MessageContent::Text(s) => s.as_str(),
+            MessageContent::Text(text) => text.as_str(),
             MessageContent::Image(_) => "",
+            MessageContent::ToolResult(tool_result) => tool_result.content.as_str(),
         }
     }
 }
@@ -200,8 +207,9 @@ impl LanguageModelRequestMessage {
     pub fn string_contents(&self) -> String {
         let mut string_buffer = String::new();
         for string in self.content.iter().filter_map(|content| match content {
-            MessageContent::Text(s) => Some(s),
+            MessageContent::Text(text) => Some(text),
             MessageContent::Image(_) => None,
+            MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
         }) {
             string_buffer.push_str(string.as_str())
         }
@@ -214,8 +222,11 @@ impl LanguageModelRequestMessage {
                 .content
                 .get(0)
                 .map(|content| match content {
-                    MessageContent::Text(s) => s.trim().is_empty(),
+                    MessageContent::Text(text) => text.trim().is_empty(),
                     MessageContent::Image(_) => true,
+                    MessageContent::ToolResult(tool_result) => {
+                        tool_result.content.trim().is_empty()
+                    }
                 })
                 .unwrap_or(false)
     }
@@ -316,21 +327,34 @@ impl LanguageModelRequest {
                         .content
                         .into_iter()
                         .filter_map(|content| match content {
-                            MessageContent::Text(t) if !t.is_empty() => {
-                                Some(anthropic::RequestContent::Text {
-                                    text: t,
+                            MessageContent::Text(text) => {
+                                if !text.is_empty() {
+                                    Some(anthropic::RequestContent::Text {
+                                        text,
+                                        cache_control,
+                                    })
+                                } else {
+                                    None
+                                }
+                            }
+                            MessageContent::Image(image) => {
+                                Some(anthropic::RequestContent::Image {
+                                    source: anthropic::ImageSource {
+                                        source_type: "base64".to_string(),
+                                        media_type: "image/png".to_string(),
+                                        data: image.source.to_string(),
+                                    },
+                                    cache_control,
+                                })
+                            }
+                            MessageContent::ToolResult(tool_result) => {
+                                Some(anthropic::RequestContent::ToolResult {
+                                    tool_use_id: tool_result.tool_use_id,
+                                    is_error: tool_result.is_error,
+                                    content: tool_result.content,
                                     cache_control,
                                 })
                             }
-                            MessageContent::Image(i) => Some(anthropic::RequestContent::Image {
-                                source: anthropic::ImageSource {
-                                    source_type: "base64".to_string(),
-                                    media_type: "image/png".to_string(),
-                                    data: i.source.to_string(),
-                                },
-                                cache_control,
-                            }),
-                            _ => None,
                         })
                         .collect();
                     let anthropic_role = match message.role {