diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 4c7e6495e0e2fdb442d485ea2918b3a825dce147..3e2f065e9594fa13b7f62113a5b9ea03bff24abe 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -330,26 +330,94 @@ pub async fn stream_completion_with_rate_limit_info( } } -pub fn extract_text_from_events( - response: impl Stream>, +pub fn extract_content_from_events( + events: Pin>>>, ) -> impl Stream> { - response.filter_map(|response| async move { - match response { - Ok(response) => match response { - Event::ContentBlockStart { content_block, .. } => match content_block { - ResponseContent::Text { text, .. } => Some(Ok(text)), - _ => None, - }, - Event::ContentBlockDelta { delta, .. } => match delta { - ContentDelta::TextDelta { text } => Some(Ok(text)), - _ => None, - }, - Event::Error { error } => Some(Err(AnthropicError::ApiError(error))), - _ => None, - }, - Err(error) => Some(Err(error)), - } - }) + struct State { + events: Pin>>>, + current_tool_use_index: Option, + } + + const INDENT: &str = " "; + const NEWLINE: char = '\n'; + + futures::stream::unfold( + State { + events, + current_tool_use_index: None, + }, + |mut state| async move { + while let Some(event) = state.events.next().await { + match event { + Ok(event) => match event { + Event::ContentBlockStart { + index, + content_block, + } => match content_block { + ResponseContent::Text { text } => { + return Some((Ok(text), state)); + } + ResponseContent::ToolUse { id, name, .. } => { + state.current_tool_use_index = Some(index); + + let mut text = String::new(); + text.push(NEWLINE); + + text.push_str(""); + text.push(NEWLINE); + + text.push_str(INDENT); + text.push_str(""); + text.push_str(&id); + text.push_str(""); + text.push(NEWLINE); + + text.push_str(INDENT); + text.push_str(""); + text.push_str(&name); + text.push_str(""); + text.push(NEWLINE); + + text.push_str(INDENT); + text.push_str(""); + + return Some((Ok(text), state)); + } + }, + Event::ContentBlockDelta { index, delta } => match delta { + ContentDelta::TextDelta { text } => { + return Some((Ok(text), state)); + } + ContentDelta::InputJsonDelta { partial_json } => { + if Some(index) == state.current_tool_use_index { + return Some((Ok(partial_json), state)); + } + } + }, + Event::ContentBlockStop { index } => { + if Some(index) == state.current_tool_use_index.take() { + let mut text = String::new(); + text.push_str(""); + text.push(NEWLINE); + text.push_str(""); + + return Some((Ok(text), state)); + } + } + Event::Error { error } => { + return Some((Err(AnthropicError::ApiError(error)), state)); + } + _ => {} + }, + Err(err) => { + return Some((Err(err), state)); + } + } + } + + None + }, + ) } pub async fn extract_tool_args_from_events( diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 4c8c40e7040d4f98c0aba5624d744283dc4c8046..f0cd01c4eb33be871a6a73bdca5f7bd5330ccc00 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -2048,7 +2048,8 @@ impl Context { LanguageModelRequest { messages: request_messages, - stop: vec![], + tools: Vec::new(), + stop: Vec::new(), temperature: 1.0, } } @@ -2398,7 +2399,8 @@ impl Context { })); let request = LanguageModelRequest { messages: messages.collect(), - stop: vec![], + tools: Vec::new(), + stop: Vec::new(), temperature: 1.0, }; diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index deea00f78d44a4c7cdd3f957cde151c67fdda203..871cc5b0250a35a3b0e8f1c19a6b94e1fc9cad20 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -2413,6 +2413,7 @@ impl Codegen { Ok(LanguageModelRequest { messages, + tools: Vec::new(), stop: vec!["|END|>".to_string()], temperature, }) diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index e1de546d15d53ef7a4bb0e6cab47caec9d0746ee..c99a7c15214d24e9760bf22d787535e3300094d9 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -794,6 +794,7 @@ impl PromptLibrary { content: vec![body.to_string().into()], cache: false, }], + tools: Vec::new(), stop: Vec::new(), temperature: 1., }, diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs index 05621cce1ec192d03339cf2258c8202649fd9d15..426d565cc1fa5f288b96b8540bb0d839be181570 100644 --- a/crates/assistant/src/terminal_inline_assistant.rs +++ b/crates/assistant/src/terminal_inline_assistant.rs @@ -282,6 +282,7 @@ impl TerminalInlineAssistant { Ok(LanguageModelRequest { messages, + tools: Vec::new(), stop: Vec::new(), temperature: 1.0, }) diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 6be9dc7413fb137602aaf8e0dac7fa7c1bb77c74..e4bb94a738ed1a742477270005bb022dbd9bc392 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -370,7 +370,7 @@ impl LanguageModel for AnthropicModel { let request = self.stream_completion(request, cx); let future = self.request_limiter.stream(async move { let response = request.await.map_err(|err| anyhow!(err))?; - Ok(anthropic::extract_text_from_events(response)) + Ok(anthropic::extract_content_from_events(response)) }); async move { Ok(future diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 11f91c457af86d67dae1155c39402f0b3568bab0..d4166fdfb5bba8c01f53661b80ff4714f396525f 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -515,9 +515,9 @@ impl LanguageModel for CloudLanguageModel { }, ) .await?; - Ok(anthropic::extract_text_from_events( + Ok(anthropic::extract_content_from_events(Box::pin( response_lines(response).map_err(AnthropicError::Other), - )) + ))) }); async move { Ok(future diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index bd9804ea0ab7b7952b14ada40da002e9574dd6c3..f03167beecdc374560103f85cf6cff326be7b204 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -221,9 +221,17 @@ impl LanguageModelRequestMessage { } } +#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelRequestTool { + pub name: String, + pub description: String, + pub input_schema: serde_json::Value, +} + #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] pub struct LanguageModelRequest { pub messages: Vec, + pub tools: Vec, pub stop: Vec, pub temperature: f32, } @@ -355,7 +363,15 @@ impl LanguageModelRequest { messages: new_messages, max_tokens: max_output_tokens, system: Some(system_message), - tools: Vec::new(), + tools: self + .tools + .into_iter() + .map(|tool| anthropic::Tool { + name: tool.name, + description: tool.description, + input_schema: tool.input_schema, + }) + .collect(), tool_choice: None, metadata: None, stop_sequences: Vec::new(),