From 30b21333366a61ad75d67030c391367181b15b9b Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 4 Sep 2024 13:29:01 -0400 Subject: [PATCH] language_model: Add tool results to message content (#17363) This PR updates the message content for an LLM request to allow it contain tool results. Release Notes: - N/A --- crates/anthropic/src/anthropic.rs | 8 ++ .../language_model/src/provider/anthropic.rs | 7 +- crates/language_model/src/request.rs | 96 ++++++++++++------- 3 files changed, 73 insertions(+), 38 deletions(-) diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 6ac10ff79384467a419b80362eb35827910ce869..f343f47660ac0d1a7b93f05e1798aba556e320b4 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -423,6 +423,14 @@ pub enum RequestContent { #[serde(skip_serializing_if = "Option::is_none")] cache_control: Option, }, + #[serde(rename = "tool_result")] + ToolResult { + tool_use_id: String, + is_error: bool, + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 62b049c9ea09c6981d38af1b832116c27eaa2341..f0554970b3e4226ef373abb89a9bea4d10b8eb45 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -261,12 +261,15 @@ pub fn count_anthropic_tokens( for content in message.content { match content { - MessageContent::Text(string) => { - string_contents.push_str(&string); + MessageContent::Text(text) => { + string_contents.push_str(&text); } MessageContent::Image(image) => { tokens_from_images += image.estimate_tokens(); } + MessageContent::ToolResult(tool_result) => { + string_contents.push_str(&tool_result.content); + } } } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index f03167beecdc374560103f85cf6cff326be7b204..5822bc1dad614dba47699a70c860282da01fb412 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -8,14 +8,24 @@ use serde::{Deserialize, Serialize}; use ui::{px, SharedString}; use util::ResultExt; -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct LanguageModelImage { - // A base64 encoded PNG image + /// A base64-encoded PNG image. pub source: SharedString, size: Size, } -const ANTHROPIC_SIZE_LIMT: f32 = 1568.0; // Anthropic wants uploaded images to be smaller than this in both dimensions +impl std::fmt::Debug for LanguageModelImage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LanguageModelImage") + .field("source", &format!("<{} bytes>", self.source.len())) + .field("size", &self.size) + .finish() + } +} + +/// Anthropic wants uploaded images to be smaller than this in both dimensions. +const ANTHROPIC_SIZE_LIMT: f32 = 1568.; impl LanguageModelImage { pub fn from_image(data: Image, cx: &mut AppContext) -> Task> { @@ -67,7 +77,7 @@ impl LanguageModelImage { } } - // SAFETY: The base64 encoder should not produce non-UTF8 + // SAFETY: The base64 encoder should not produce non-UTF8. let source = unsafe { String::from_utf8_unchecked(base64_image) }; Some(LanguageModelImage { @@ -77,7 +87,7 @@ impl LanguageModelImage { }) } - /// Resolves image into an LLM-ready format (base64) + /// Resolves image into an LLM-ready format (base64). pub fn from_render_image(data: &RenderImage) -> Option { let image_size = data.size(0); @@ -130,7 +140,7 @@ impl LanguageModelImage { base64_encoder.write_all(png.as_slice()).log_err()?; } - // SAFETY: The base64 encoder should not produce non-UTF8 + // SAFETY: The base64 encoder should not produce non-UTF8. let source = unsafe { String::from_utf8_unchecked(base64_image) }; Some(LanguageModelImage { @@ -144,35 +154,32 @@ impl LanguageModelImage { let height = self.size.height.0.unsigned_abs() as usize; // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs - // Note that are a lot of conditions on anthropic's API, and OpenAI doesn't use this, - // so this method is more of a rough guess + // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this, + // so this method is more of a rough guess. (width * height) / 750 } } -#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +pub struct LanguageModelToolResult { + pub tool_use_id: String, + pub is_error: bool, + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] pub enum MessageContent { Text(String), Image(LanguageModelImage), -} - -impl std::fmt::Debug for MessageContent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MessageContent::Text(t) => f.debug_struct("MessageContent").field("text", t).finish(), - MessageContent::Image(i) => f - .debug_struct("MessageContent") - .field("image", &i.source.len()) - .finish(), - } - } + ToolResult(LanguageModelToolResult), } impl MessageContent { pub fn as_string(&self) -> &str { match self { - MessageContent::Text(s) => s.as_str(), + MessageContent::Text(text) => text.as_str(), MessageContent::Image(_) => "", + MessageContent::ToolResult(tool_result) => tool_result.content.as_str(), } } } @@ -200,8 +207,9 @@ impl LanguageModelRequestMessage { pub fn string_contents(&self) -> String { let mut string_buffer = String::new(); for string in self.content.iter().filter_map(|content| match content { - MessageContent::Text(s) => Some(s), + MessageContent::Text(text) => Some(text), MessageContent::Image(_) => None, + MessageContent::ToolResult(tool_result) => Some(&tool_result.content), }) { string_buffer.push_str(string.as_str()) } @@ -214,8 +222,11 @@ impl LanguageModelRequestMessage { .content .get(0) .map(|content| match content { - MessageContent::Text(s) => s.trim().is_empty(), + MessageContent::Text(text) => text.trim().is_empty(), MessageContent::Image(_) => true, + MessageContent::ToolResult(tool_result) => { + tool_result.content.trim().is_empty() + } }) .unwrap_or(false) } @@ -316,21 +327,34 @@ impl LanguageModelRequest { .content .into_iter() .filter_map(|content| match content { - MessageContent::Text(t) if !t.is_empty() => { - Some(anthropic::RequestContent::Text { - text: t, + MessageContent::Text(text) => { + if !text.is_empty() { + Some(anthropic::RequestContent::Text { + text, + cache_control, + }) + } else { + None + } + } + MessageContent::Image(image) => { + Some(anthropic::RequestContent::Image { + source: anthropic::ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + cache_control, + }) + } + MessageContent::ToolResult(tool_result) => { + Some(anthropic::RequestContent::ToolResult { + tool_use_id: tool_result.tool_use_id, + is_error: tool_result.is_error, + content: tool_result.content, cache_control, }) } - MessageContent::Image(i) => Some(anthropic::RequestContent::Image { - source: anthropic::ImageSource { - source_type: "base64".to_string(), - media_type: "image/png".to_string(), - data: i.source.to_string(), - }, - cache_control, - }), - _ => None, }) .collect(); let anthropic_role = match message.role {