Detailed changes
@@ -330,26 +330,94 @@ pub async fn stream_completion_with_rate_limit_info(
}
}
-pub fn extract_text_from_events(
- response: impl Stream<Item = Result<Event, AnthropicError>>,
+pub fn extract_content_from_events(
+ events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<String, AnthropicError>> {
- response.filter_map(|response| async move {
- match response {
- Ok(response) => match response {
- Event::ContentBlockStart { content_block, .. } => match content_block {
- ResponseContent::Text { text, .. } => Some(Ok(text)),
- _ => None,
- },
- Event::ContentBlockDelta { delta, .. } => match delta {
- ContentDelta::TextDelta { text } => Some(Ok(text)),
- _ => None,
- },
- Event::Error { error } => Some(Err(AnthropicError::ApiError(error))),
- _ => None,
- },
- Err(error) => Some(Err(error)),
- }
- })
+ struct State {
+ events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+ current_tool_use_index: Option<usize>,
+ }
+
+ const INDENT: &str = " ";
+ const NEWLINE: char = '\n';
+
+ futures::stream::unfold(
+ State {
+ events,
+ current_tool_use_index: None,
+ },
+ |mut state| async move {
+ while let Some(event) = state.events.next().await {
+ match event {
+ Ok(event) => match event {
+ Event::ContentBlockStart {
+ index,
+ content_block,
+ } => match content_block {
+ ResponseContent::Text { text } => {
+ return Some((Ok(text), state));
+ }
+ ResponseContent::ToolUse { id, name, .. } => {
+ state.current_tool_use_index = Some(index);
+
+ let mut text = String::new();
+ text.push(NEWLINE);
+
+ text.push_str("<tool_use>");
+ text.push(NEWLINE);
+
+ text.push_str(INDENT);
+ text.push_str("<id>");
+ text.push_str(&id);
+ text.push_str("</id>");
+ text.push(NEWLINE);
+
+ text.push_str(INDENT);
+ text.push_str("<name>");
+ text.push_str(&name);
+ text.push_str("</name>");
+ text.push(NEWLINE);
+
+ text.push_str(INDENT);
+ text.push_str("<input>");
+
+ return Some((Ok(text), state));
+ }
+ },
+ Event::ContentBlockDelta { index, delta } => match delta {
+ ContentDelta::TextDelta { text } => {
+ return Some((Ok(text), state));
+ }
+ ContentDelta::InputJsonDelta { partial_json } => {
+ if Some(index) == state.current_tool_use_index {
+ return Some((Ok(partial_json), state));
+ }
+ }
+ },
+ Event::ContentBlockStop { index } => {
+ if Some(index) == state.current_tool_use_index.take() {
+ let mut text = String::new();
+ text.push_str("</input>");
+ text.push(NEWLINE);
+ text.push_str("</tool_use>");
+
+ return Some((Ok(text), state));
+ }
+ }
+ Event::Error { error } => {
+ return Some((Err(AnthropicError::ApiError(error)), state));
+ }
+ _ => {}
+ },
+ Err(err) => {
+ return Some((Err(err), state));
+ }
+ }
+ }
+
+ None
+ },
+ )
}
pub async fn extract_tool_args_from_events(
@@ -2048,7 +2048,8 @@ impl Context {
LanguageModelRequest {
messages: request_messages,
- stop: vec![],
+ tools: Vec::new(),
+ stop: Vec::new(),
temperature: 1.0,
}
}
@@ -2398,7 +2399,8 @@ impl Context {
}));
let request = LanguageModelRequest {
messages: messages.collect(),
- stop: vec![],
+ tools: Vec::new(),
+ stop: Vec::new(),
temperature: 1.0,
};
@@ -2413,6 +2413,7 @@ impl Codegen {
Ok(LanguageModelRequest {
messages,
+ tools: Vec::new(),
stop: vec!["|END|>".to_string()],
temperature,
})
@@ -794,6 +794,7 @@ impl PromptLibrary {
content: vec![body.to_string().into()],
cache: false,
}],
+ tools: Vec::new(),
stop: Vec::new(),
temperature: 1.,
},
@@ -282,6 +282,7 @@ impl TerminalInlineAssistant {
Ok(LanguageModelRequest {
messages,
+ tools: Vec::new(),
stop: Vec::new(),
temperature: 1.0,
})
@@ -370,7 +370,7 @@ impl LanguageModel for AnthropicModel {
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request.await.map_err(|err| anyhow!(err))?;
- Ok(anthropic::extract_text_from_events(response))
+ Ok(anthropic::extract_content_from_events(response))
});
async move {
Ok(future
@@ -515,9 +515,9 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- Ok(anthropic::extract_text_from_events(
+ Ok(anthropic::extract_content_from_events(Box::pin(
response_lines(response).map_err(AnthropicError::Other),
- ))
+ )))
});
async move {
Ok(future
@@ -221,9 +221,17 @@ impl LanguageModelRequestMessage {
}
}
+#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
+pub struct LanguageModelRequestTool {
+ pub name: String,
+ pub description: String,
+ pub input_schema: serde_json::Value,
+}
+
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
pub messages: Vec<LanguageModelRequestMessage>,
+ pub tools: Vec<LanguageModelRequestTool>,
pub stop: Vec<String>,
pub temperature: f32,
}
@@ -355,7 +363,15 @@ impl LanguageModelRequest {
messages: new_messages,
max_tokens: max_output_tokens,
system: Some(system_message),
- tools: Vec::new(),
+ tools: self
+ .tools
+ .into_iter()
+ .map(|tool| anthropic::Tool {
+ name: tool.name,
+ description: tool.description,
+ input_schema: tool.input_schema,
+ })
+ .collect(),
tool_choice: None,
metadata: None,
stop_sequences: Vec::new(),