From 497356b2ba37128cb1af89395a4dcc12b6421bc3 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 4 Sep 2024 19:29:11 -0400 Subject: [PATCH] language_model: Add tool uses to message content (#17381) This PR updates the message content for an LLM request to allow it to contain tool uses. We need to send the tool uses back to the model in order for it to recognize the subsequent tool results. Release Notes: - N/A --- crates/language_model/src/language_model.rs | 2 +- crates/language_model/src/provider/anthropic.rs | 3 +++ crates/language_model/src/request.rs | 14 ++++++++++++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index d24a5f90011432add1460d5b36fa738ab89bfff7..171f5fa819b0a7487273bafedab64d95bf8586e0 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -68,7 +68,7 @@ pub enum StopReason { ToolUse, } -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] pub struct LanguageModelToolUse { pub id: String, pub name: String, diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index f0554970b3e4226ef373abb89a9bea4d10b8eb45..5d4d4c45489ab146e954f169ff244ab158e995c3 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -267,6 +267,9 @@ pub fn count_anthropic_tokens( MessageContent::Image(image) => { tokens_from_images += image.estimate_tokens(); } + MessageContent::ToolUse(_tool_use) => { + // TODO: Estimate token usage from tool uses. 
+ } MessageContent::ToolResult(tool_result) => { string_contents.push_str(&tool_result.content); } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 00e0c38d21b7365946f2c23f055f0c8f6d1db102..64ce33a21f0320089921502a873e0c1192f6cb7a 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -1,6 +1,7 @@ use std::io::{Cursor, Write}; use crate::role::Role; +use crate::LanguageModelToolUse; use base64::write::EncoderWriter; use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task}; use image::{codecs::png::PngEncoder, imageops::resize, DynamicImage, ImageDecoder}; @@ -171,6 +172,7 @@ pub struct LanguageModelToolResult { pub enum MessageContent { Text(String), Image(LanguageModelImage), + ToolUse(LanguageModelToolUse), ToolResult(LanguageModelToolResult), } @@ -198,8 +200,8 @@ impl LanguageModelRequestMessage { let mut string_buffer = String::new(); for string in self.content.iter().filter_map(|content| match content { MessageContent::Text(text) => Some(text), - MessageContent::Image(_) => None, MessageContent::ToolResult(tool_result) => Some(&tool_result.content), + MessageContent::ToolUse(_) | MessageContent::Image(_) => None, }) { string_buffer.push_str(string.as_str()) } @@ -213,10 +215,10 @@ impl LanguageModelRequestMessage { .get(0) .map(|content| match content { MessageContent::Text(text) => text.trim().is_empty(), - MessageContent::Image(_) => true, MessageContent::ToolResult(tool_result) => { tool_result.content.trim().is_empty() } + MessageContent::ToolUse(_) | MessageContent::Image(_) => true, }) .unwrap_or(false) } @@ -337,6 +339,14 @@ impl LanguageModelRequest { cache_control, }) } + MessageContent::ToolUse(tool_use) => { + Some(anthropic::RequestContent::ToolUse { + id: tool_use.id, + name: tool_use.name, + input: tool_use.input, + cache_control, + }) + } MessageContent::ToolResult(tool_result) => { 
Some(anthropic::RequestContent::ToolResult { tool_use_id: tool_result.tool_use_id,