From 68ea661711ed7f13657ac5296654f732de74ea2b Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 30 Aug 2024 14:05:55 -0400 Subject: [PATCH] assistant: Add foundation for receiving tool uses from Anthropic models (#17170) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR updates the Assistant with support for receiving tool uses from Anthropic models and capturing them as text in the context editor. This is just laying the foundation for tool use. We don't yet fulfill the tool uses, or define any tools for the model to use. Here's an example of what it looks like using the example `get_weather` tool from the Anthropic docs: Screenshot 2024-08-30 at 1 51 13 PM Release Notes: - N/A --- crates/anthropic/src/anthropic.rs | 106 ++++++++++++++---- crates/assistant/src/context.rs | 6 +- crates/assistant/src/inline_assistant.rs | 1 + crates/assistant/src/prompt_library.rs | 1 + .../src/terminal_inline_assistant.rs | 1 + .../language_model/src/provider/anthropic.rs | 2 +- crates/language_model/src/provider/cloud.rs | 4 +- crates/language_model/src/request.rs | 18 ++- 8 files changed, 114 insertions(+), 25 deletions(-) diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 4c7e6495e0e2fdb442d485ea2918b3a825dce147..3e2f065e9594fa13b7f62113a5b9ea03bff24abe 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -330,26 +330,94 @@ pub async fn stream_completion_with_rate_limit_info( } } -pub fn extract_text_from_events( - response: impl Stream>, +pub fn extract_content_from_events( + events: Pin>>>, ) -> impl Stream> { - response.filter_map(|response| async move { - match response { - Ok(response) => match response { - Event::ContentBlockStart { content_block, .. } => match content_block { - ResponseContent::Text { text, .. } => Some(Ok(text)), - _ => None, - }, - Event::ContentBlockDelta { delta, .. 
} => match delta { - ContentDelta::TextDelta { text } => Some(Ok(text)), - _ => None, - }, - Event::Error { error } => Some(Err(AnthropicError::ApiError(error))), - _ => None, - }, - Err(error) => Some(Err(error)), - } - }) + struct State { + events: Pin>>>, + current_tool_use_index: Option, + } + + const INDENT: &str = " "; + const NEWLINE: char = '\n'; + + futures::stream::unfold( + State { + events, + current_tool_use_index: None, + }, + |mut state| async move { + while let Some(event) = state.events.next().await { + match event { + Ok(event) => match event { + Event::ContentBlockStart { + index, + content_block, + } => match content_block { + ResponseContent::Text { text } => { + return Some((Ok(text), state)); + } + ResponseContent::ToolUse { id, name, .. } => { + state.current_tool_use_index = Some(index); + + let mut text = String::new(); + text.push(NEWLINE); + + text.push_str(""); + text.push(NEWLINE); + + text.push_str(INDENT); + text.push_str(""); + text.push_str(&id); + text.push_str(""); + text.push(NEWLINE); + + text.push_str(INDENT); + text.push_str(""); + text.push_str(&name); + text.push_str(""); + text.push(NEWLINE); + + text.push_str(INDENT); + text.push_str(""); + + return Some((Ok(text), state)); + } + }, + Event::ContentBlockDelta { index, delta } => match delta { + ContentDelta::TextDelta { text } => { + return Some((Ok(text), state)); + } + ContentDelta::InputJsonDelta { partial_json } => { + if Some(index) == state.current_tool_use_index { + return Some((Ok(partial_json), state)); + } + } + }, + Event::ContentBlockStop { index } => { + if Some(index) == state.current_tool_use_index.take() { + let mut text = String::new(); + text.push_str(""); + text.push(NEWLINE); + text.push_str(""); + + return Some((Ok(text), state)); + } + } + Event::Error { error } => { + return Some((Err(AnthropicError::ApiError(error)), state)); + } + _ => {} + }, + Err(err) => { + return Some((Err(err), state)); + } + } + } + + None + }, + ) } pub async fn 
extract_tool_args_from_events( diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 4c8c40e7040d4f98c0aba5624d744283dc4c8046..f0cd01c4eb33be871a6a73bdca5f7bd5330ccc00 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -2048,7 +2048,8 @@ impl Context { LanguageModelRequest { messages: request_messages, - stop: vec![], + tools: Vec::new(), + stop: Vec::new(), temperature: 1.0, } } @@ -2398,7 +2399,8 @@ impl Context { })); let request = LanguageModelRequest { messages: messages.collect(), - stop: vec![], + tools: Vec::new(), + stop: Vec::new(), temperature: 1.0, }; diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index deea00f78d44a4c7cdd3f957cde151c67fdda203..871cc5b0250a35a3b0e8f1c19a6b94e1fc9cad20 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -2413,6 +2413,7 @@ impl Codegen { Ok(LanguageModelRequest { messages, + tools: Vec::new(), stop: vec!["|END|>".to_string()], temperature, }) diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index e1de546d15d53ef7a4bb0e6cab47caec9d0746ee..c99a7c15214d24e9760bf22d787535e3300094d9 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -794,6 +794,7 @@ impl PromptLibrary { content: vec![body.to_string().into()], cache: false, }], + tools: Vec::new(), stop: Vec::new(), temperature: 1., }, diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs index 05621cce1ec192d03339cf2258c8202649fd9d15..426d565cc1fa5f288b96b8540bb0d839be181570 100644 --- a/crates/assistant/src/terminal_inline_assistant.rs +++ b/crates/assistant/src/terminal_inline_assistant.rs @@ -282,6 +282,7 @@ impl TerminalInlineAssistant { Ok(LanguageModelRequest { messages, + tools: Vec::new(), stop: Vec::new(), temperature: 1.0, }) diff --git 
a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 6be9dc7413fb137602aaf8e0dac7fa7c1bb77c74..e4bb94a738ed1a742477270005bb022dbd9bc392 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -370,7 +370,7 @@ impl LanguageModel for AnthropicModel { let request = self.stream_completion(request, cx); let future = self.request_limiter.stream(async move { let response = request.await.map_err(|err| anyhow!(err))?; - Ok(anthropic::extract_text_from_events(response)) + Ok(anthropic::extract_content_from_events(response)) }); async move { Ok(future diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 11f91c457af86d67dae1155c39402f0b3568bab0..d4166fdfb5bba8c01f53661b80ff4714f396525f 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -515,9 +515,9 @@ impl LanguageModel for CloudLanguageModel { }, ) .await?; - Ok(anthropic::extract_text_from_events( + Ok(anthropic::extract_content_from_events(Box::pin( response_lines(response).map_err(AnthropicError::Other), - )) + ))) }); async move { Ok(future diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index bd9804ea0ab7b7952b14ada40da002e9574dd6c3..f03167beecdc374560103f85cf6cff326be7b204 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -221,9 +221,17 @@ impl LanguageModelRequestMessage { } } +#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelRequestTool { + pub name: String, + pub description: String, + pub input_schema: serde_json::Value, +} + #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] pub struct LanguageModelRequest { pub messages: Vec, + pub tools: Vec, pub stop: Vec, pub temperature: f32, } @@ -355,7 +363,15 @@ impl LanguageModelRequest { messages: 
new_messages, max_tokens: max_output_tokens, system: Some(system_message), - tools: Vec::new(), + tools: self + .tools + .into_iter() + .map(|tool| anthropic::Tool { + name: tool.name, + description: tool.description, + input_schema: tool.input_schema, + }) + .collect(), tool_choice: None, metadata: None, stop_sequences: Vec::new(),