Marshall Bowers created this pull request:
This PR makes it so we propagate the `stop_reason` from Anthropic up to
the Assistant so that we can take action based on it.
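As a minimal sketch of what this enables downstream (using simplified stand-ins for the new types; the real handling lives in the Assistant's `Context`, shown in the `context.rs` hunk below):

```rust
use futures::{Stream, StreamExt};

// Simplified stand-ins for the types this PR adds to `language_model`.
#[derive(Debug)]
enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
}

enum LanguageModelCompletionEvent {
    Stop(StopReason),
    Text(String),
    ToolUse { name: String },
}

// Drain a completion stream, reacting to why the model stopped.
async fn drain(
    mut events: impl Stream<Item = anyhow::Result<LanguageModelCompletionEvent>> + Unpin,
) -> anyhow::Result<()> {
    while let Some(event) = events.next().await {
        match event? {
            LanguageModelCompletionEvent::Text(chunk) => print!("{chunk}"),
            LanguageModelCompletionEvent::ToolUse { name } => {
                // With `StopReason::ToolUse` now visible, the caller knows
                // the turn ended because the model wants this tool run.
                eprintln!("model requested tool: {name}");
            }
            LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) => {
                eprintln!("response truncated at the token limit");
            }
            LanguageModelCompletionEvent::Stop(reason) => {
                eprintln!("generation stopped: {reason:?}");
            }
        }
    }
    Ok(())
}
```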
The `extract_content_from_events` function was moved from the `anthropic` crate to the `anthropic` provider module in `language_model` (and renamed to `map_to_language_model_completion_events`), since it is more useful when it can name the `LanguageModelCompletionEvent` type; otherwise we'd need an additional layer of plumbing.
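For reference, the move keeps the streaming shape and only changes the item type. Signatures as they appear in the diffs below (bodies elided):

```rust
// Before, in crates/anthropic: yields raw Anthropic content blocks.
pub fn extract_content_from_events(
    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<ResponseContent, AnthropicError>>;

// After, in crates/language_model: yields completion events, including the
// new `Stop(StopReason)` variant mapped from `Event::MessageDelta`.
pub fn map_to_language_model_completion_events(
    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>>;
```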
Release Notes:
- N/A
 Cargo.lock                                      |   1 -
 crates/anthropic/Cargo.toml                     |   1 -
 crates/anthropic/src/anthropic.rs               |  91 +----------
 crates/assistant/src/context.rs                 |   5 +
 crates/language_model/src/language_model.rs     |  10 +
 crates/language_model/src/provider/anthropic.rs | 150 +++++++++++++++---
 crates/language_model/src/provider/cloud.rs     |  29 +---
 7 files changed, 143 insertions(+), 144 deletions(-)
Cargo.lock
@@ -243,7 +243,6 @@ version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
- "collections",
"futures 0.3.30",
"http_client",
"isahc",
crates/anthropic/Cargo.toml
@@ -18,7 +18,6 @@ path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
chrono.workspace = true
-collections.workspace = true
futures.workspace = true
http_client.workspace = true
isahc.workspace = true
crates/anthropic/src/anthropic.rs
@@ -5,7 +5,6 @@ use std::{pin::Pin, str::FromStr};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
-use collections::HashMap;
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
@@ -13,7 +12,7 @@ use isahc::http::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
use thiserror::Error;
-use util::{maybe, ResultExt as _};
+use util::ResultExt as _;
pub use supported_countries::*;
@@ -332,94 +331,6 @@ pub async fn stream_completion_with_rate_limit_info(
}
}
-pub fn extract_content_from_events(
- events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
- struct RawToolUse {
- id: String,
- name: String,
- input_json: String,
- }
-
- struct State {
- events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
- tool_uses_by_index: HashMap<usize, RawToolUse>,
- }
-
- futures::stream::unfold(
- State {
- events,
- tool_uses_by_index: HashMap::default(),
- },
- |mut state| async move {
- while let Some(event) = state.events.next().await {
- match event {
- Ok(event) => match event {
- Event::ContentBlockStart {
- index,
- content_block,
- } => match content_block {
- ResponseContent::Text { text } => {
- return Some((Some(Ok(ResponseContent::Text { text })), state));
- }
- ResponseContent::ToolUse { id, name, .. } => {
- state.tool_uses_by_index.insert(
- index,
- RawToolUse {
- id,
- name,
- input_json: String::new(),
- },
- );
-
- return Some((None, state));
- }
- },
- Event::ContentBlockDelta { index, delta } => match delta {
- ContentDelta::TextDelta { text } => {
- return Some((Some(Ok(ResponseContent::Text { text })), state));
- }
- ContentDelta::InputJsonDelta { partial_json } => {
- if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
- tool_use.input_json.push_str(&partial_json);
- return Some((None, state));
- }
- }
- },
- Event::ContentBlockStop { index } => {
- if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
- return Some((
- Some(maybe!({
- Ok(ResponseContent::ToolUse {
- id: tool_use.id,
- name: tool_use.name,
- input: serde_json::Value::from_str(
- &tool_use.input_json,
- )
- .map_err(|err| anyhow!(err))?,
- })
- })),
- state,
- ));
- }
- }
- Event::Error { error } => {
- return Some((Some(Err(AnthropicError::ApiError(error))), state));
- }
- _ => {}
- },
- Err(err) => {
- return Some((Some(Err(err)), state));
- }
- }
- }
-
- None
- },
- )
- .filter_map(|event| async move { event })
-}
-
pub async fn extract_tool_args_from_events(
tool_name: String,
mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
crates/assistant/src/context.rs
@@ -1999,6 +1999,11 @@ impl Context {
});
match event {
+ LanguageModelCompletionEvent::Stop(reason) => match reason {
+ language_model::StopReason::ToolUse => {}
+ language_model::StopReason::EndTurn => {}
+ language_model::StopReason::MaxTokens => {}
+ },
LanguageModelCompletionEvent::Text(chunk) => {
buffer.edit(
[(
crates/language_model/src/language_model.rs
@@ -55,10 +55,19 @@ pub struct LanguageModelCacheConfiguration {
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
+ Stop(StopReason),
Text(String),
ToolUse(LanguageModelToolUse),
}
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum StopReason {
+ EndTurn,
+ MaxTokens,
+ ToolUse,
+}
+
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
pub id: String,
@@ -112,6 +121,7 @@ pub trait LanguageModel: Send + Sync {
.filter_map(|result| async move {
match result {
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
+ Ok(LanguageModelCompletionEvent::Stop(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
Err(err) => Some(Err(err)),
}
crates/language_model/src/provider/anthropic.rs
@@ -3,11 +3,12 @@ use crate::{
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
-use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
-use anthropic::AnthropicError;
+use crate::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
+use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
use anyhow::{anyhow, Context as _, Result};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
use gpui::{
AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
@@ -17,11 +18,13 @@ use http_client::HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
-use util::ResultExt;
+use util::{maybe, ResultExt};
const PROVIDER_ID: &str = "anthropic";
const PROVIDER_NAME: &str = "Anthropic";
@@ -371,30 +374,9 @@ impl LanguageModel for AnthropicModel {
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request.await.map_err(|err| anyhow!(err))?;
- Ok(anthropic::extract_content_from_events(response))
+ Ok(map_to_language_model_completion_events(response))
});
- async move {
- Ok(future
- .await?
- .map(|result| {
- result
- .map(|content| match content {
- anthropic::ResponseContent::Text { text } => {
- LanguageModelCompletionEvent::Text(text)
- }
- anthropic::ResponseContent::ToolUse { id, name, input } => {
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
- id,
- name,
- input,
- })
- }
- })
- .map_err(|err| anyhow!(err))
- })
- .boxed())
- }
- .boxed()
+ async move { Ok(future.await?.boxed()) }.boxed()
}
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
@@ -443,6 +425,120 @@ impl LanguageModel for AnthropicModel {
}
}
+pub fn map_to_language_model_completion_events(
+ events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+ struct RawToolUse {
+ id: String,
+ name: String,
+ input_json: String,
+ }
+
+ struct State {
+ events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+ tool_uses_by_index: HashMap<usize, RawToolUse>,
+ }
+
+ futures::stream::unfold(
+ State {
+ events,
+ tool_uses_by_index: HashMap::default(),
+ },
+ |mut state| async move {
+ while let Some(event) = state.events.next().await {
+ match event {
+ Ok(event) => match event {
+ Event::ContentBlockStart {
+ index,
+ content_block,
+ } => match content_block {
+ ResponseContent::Text { text } => {
+ return Some((
+ Some(Ok(LanguageModelCompletionEvent::Text(text))),
+ state,
+ ));
+ }
+ ResponseContent::ToolUse { id, name, .. } => {
+ state.tool_uses_by_index.insert(
+ index,
+ RawToolUse {
+ id,
+ name,
+ input_json: String::new(),
+ },
+ );
+
+ return Some((None, state));
+ }
+ },
+ Event::ContentBlockDelta { index, delta } => match delta {
+ ContentDelta::TextDelta { text } => {
+ return Some((
+ Some(Ok(LanguageModelCompletionEvent::Text(text))),
+ state,
+ ));
+ }
+ ContentDelta::InputJsonDelta { partial_json } => {
+ if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
+ tool_use.input_json.push_str(&partial_json);
+ return Some((None, state));
+ }
+ }
+ },
+ Event::ContentBlockStop { index } => {
+ if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
+ return Some((
+ Some(maybe!({
+ Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_use.id,
+ name: tool_use.name,
+ input: serde_json::Value::from_str(
+ &tool_use.input_json,
+ )
+ .map_err(|err| anyhow!(err))?,
+ },
+ ))
+ })),
+ state,
+ ));
+ }
+ }
+ Event::MessageDelta { delta, .. } => {
+ if let Some(stop_reason) = delta.stop_reason.as_deref() {
+ let stop_reason = match stop_reason {
+ "end_turn" => StopReason::EndTurn,
+ "max_tokens" => StopReason::MaxTokens,
+ "tool_use" => StopReason::ToolUse,
+ _ => StopReason::EndTurn,
+ };
+
+ return Some((
+ Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
+ state,
+ ));
+ }
+ }
+ Event::Error { error } => {
+ return Some((
+ Some(Err(anyhow!(AnthropicError::ApiError(error)))),
+ state,
+ ));
+ }
+ _ => {}
+ },
+ Err(err) => {
+ return Some((Some(Err(anyhow!(err))), state));
+ }
+ }
+ }
+
+ None
+ },
+ )
+ .filter_map(|event| async move { event })
+}
+
struct ConfigurationView {
api_key_editor: View<Editor>,
state: gpui::Model<State>,
crates/language_model/src/provider/cloud.rs
@@ -1,4 +1,5 @@
use super::open_ai::count_open_ai_tokens;
+use crate::provider::anthropic::map_to_language_model_completion_events;
use crate::{
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
@@ -33,10 +34,7 @@ use std::{
use strum::IntoEnumIterator;
use ui::{prelude::*, TintColor};
-use crate::{
- LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
- LanguageModelToolUse,
-};
+use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
use super::anthropic::count_anthropic_tokens;
@@ -518,30 +516,11 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- Ok(anthropic::extract_content_from_events(Box::pin(
+ Ok(map_to_language_model_completion_events(Box::pin(
response_lines(response).map_err(AnthropicError::Other),
)))
});
- async move {
- Ok(future
- .await?
- .map(|result| {
- result
- .map(|content| match content {
- anthropic::ResponseContent::Text { text } => {
- LanguageModelCompletionEvent::Text(text)
- }
- anthropic::ResponseContent::ToolUse { id, name, input } => {
- LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse { id, name, input },
- )
- }
- })
- .map_err(|err| anyhow!(err))
- })
- .boxed())
- }
- .boxed()
+ async move { Ok(future.await?.boxed()) }.boxed()
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();