assistant: Propagate LLM stop reason upwards (#17358)

Created by Marshall Bowers

This PR propagates the `stop_reason` from Anthropic up to the
Assistant so that we can take action based on it.

The `extract_content_from_events` function was moved from the `anthropic`
crate into the `anthropic` provider module of `language_model` (where it is
now `map_to_language_model_completion_events`), since it is more useful when
it can name the `LanguageModelCompletionEvent` type directly; otherwise we'd
need an additional layer of plumbing.
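
The new `LanguageModelCompletionEvent::Stop(StopReason)` variant lets downstream consumers branch on how the model finished. A minimal sketch of such a consumer (the `handle_event` function and its handler bodies are hypothetical; in this PR the real handling lives in `Context` in `crates/assistant/src/context.rs`):

```rust
use language_model::{LanguageModelCompletionEvent, StopReason};

fn handle_event(event: LanguageModelCompletionEvent) {
    match event {
        // Stream text into the output, as before this PR.
        LanguageModelCompletionEvent::Text(chunk) => print!("{chunk}"),
        // The model requested a tool invocation.
        LanguageModelCompletionEvent::ToolUse(tool_use) => {
            println!("tool use requested: {}", tool_use.name);
        }
        // New in this PR: the stop reason reaches the consumer.
        LanguageModelCompletionEvent::Stop(StopReason::ToolUse) => {
            // The model stopped so we can run tools and continue the turn.
        }
        LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) => {
            // The response was truncated at the token limit.
        }
        LanguageModelCompletionEvent::Stop(StopReason::EndTurn) => {
            // The model finished its turn normally.
        }
    }
}
```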

Release Notes:

- N/A

Change summary

Cargo.lock                                      |   1 -
crates/anthropic/Cargo.toml                     |   1 -
crates/anthropic/src/anthropic.rs               |  91 -----------
crates/assistant/src/context.rs                 |   5 +
crates/language_model/src/language_model.rs     |  10 +
crates/language_model/src/provider/anthropic.rs | 150 +++++++++++++++---
crates/language_model/src/provider/cloud.rs     |  29 ---
7 files changed, 143 insertions(+), 144 deletions(-)

Detailed changes

Cargo.lock

@@ -243,7 +243,6 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "chrono",
- "collections",
  "futures 0.3.30",
  "http_client",
  "isahc",

crates/anthropic/Cargo.toml

@@ -18,7 +18,6 @@ path = "src/anthropic.rs"
 [dependencies]
 anyhow.workspace = true
 chrono.workspace = true
-collections.workspace = true
 futures.workspace = true
 http_client.workspace = true
 isahc.workspace = true

crates/anthropic/src/anthropic.rs

@@ -5,7 +5,6 @@ use std::{pin::Pin, str::FromStr};
 
 use anyhow::{anyhow, Context, Result};
 use chrono::{DateTime, Utc};
-use collections::HashMap;
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
@@ -13,7 +12,7 @@ use isahc::http::{HeaderMap, HeaderValue};
 use serde::{Deserialize, Serialize};
 use strum::{EnumIter, EnumString};
 use thiserror::Error;
-use util::{maybe, ResultExt as _};
+use util::ResultExt as _;
 
 pub use supported_countries::*;
 
@@ -332,94 +331,6 @@ pub async fn stream_completion_with_rate_limit_info(
     }
 }
 
-pub fn extract_content_from_events(
-    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
-    struct RawToolUse {
-        id: String,
-        name: String,
-        input_json: String,
-    }
-
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-        tool_uses_by_index: HashMap<usize, RawToolUse>,
-    }
-
-    futures::stream::unfold(
-        State {
-            events,
-            tool_uses_by_index: HashMap::default(),
-        },
-        |mut state| async move {
-            while let Some(event) = state.events.next().await {
-                match event {
-                    Ok(event) => match event {
-                        Event::ContentBlockStart {
-                            index,
-                            content_block,
-                        } => match content_block {
-                            ResponseContent::Text { text } => {
-                                return Some((Some(Ok(ResponseContent::Text { text })), state));
-                            }
-                            ResponseContent::ToolUse { id, name, .. } => {
-                                state.tool_uses_by_index.insert(
-                                    index,
-                                    RawToolUse {
-                                        id,
-                                        name,
-                                        input_json: String::new(),
-                                    },
-                                );
-
-                                return Some((None, state));
-                            }
-                        },
-                        Event::ContentBlockDelta { index, delta } => match delta {
-                            ContentDelta::TextDelta { text } => {
-                                return Some((Some(Ok(ResponseContent::Text { text })), state));
-                            }
-                            ContentDelta::InputJsonDelta { partial_json } => {
-                                if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
-                                    tool_use.input_json.push_str(&partial_json);
-                                    return Some((None, state));
-                                }
-                            }
-                        },
-                        Event::ContentBlockStop { index } => {
-                            if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
-                                return Some((
-                                    Some(maybe!({
-                                        Ok(ResponseContent::ToolUse {
-                                            id: tool_use.id,
-                                            name: tool_use.name,
-                                            input: serde_json::Value::from_str(
-                                                &tool_use.input_json,
-                                            )
-                                            .map_err(|err| anyhow!(err))?,
-                                        })
-                                    })),
-                                    state,
-                                ));
-                            }
-                        }
-                        Event::Error { error } => {
-                            return Some((Some(Err(AnthropicError::ApiError(error))), state));
-                        }
-                        _ => {}
-                    },
-                    Err(err) => {
-                        return Some((Some(Err(err)), state));
-                    }
-                }
-            }
-
-            None
-        },
-    )
-    .filter_map(|event| async move { event })
-}
-
 pub async fn extract_tool_args_from_events(
     tool_name: String,
     mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,

crates/assistant/src/context.rs

@@ -1999,6 +1999,11 @@ impl Context {
                                     });
 
                                 match event {
+                                    LanguageModelCompletionEvent::Stop(reason) => match reason {
+                                        language_model::StopReason::ToolUse => {}
+                                        language_model::StopReason::EndTurn => {}
+                                        language_model::StopReason::MaxTokens => {}
+                                    },
                                     LanguageModelCompletionEvent::Text(chunk) => {
                                         buffer.edit(
                                             [(

crates/language_model/src/language_model.rs

@@ -55,10 +55,19 @@ pub struct LanguageModelCacheConfiguration {
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
+    Stop(StopReason),
     Text(String),
     ToolUse(LanguageModelToolUse),
 }
 
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum StopReason {
+    EndTurn,
+    MaxTokens,
+    ToolUse,
+}
+
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub struct LanguageModelToolUse {
     pub id: String,
@@ -112,6 +121,7 @@ pub trait LanguageModel: Send + Sync {
                 .filter_map(|result| async move {
                     match result {
                         Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
+                        Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                         Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                         Err(err) => Some(Err(err)),
                     }
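
Because `StopReason` derives `Serialize`/`Deserialize` with `#[serde(rename_all = "snake_case")]`, its variants round-trip as the same strings Anthropic sends on the wire ("end_turn", "max_tokens", "tool_use"). A quick sketch, assuming `language_model` and `serde_json` are available as dependencies:

```rust
use language_model::StopReason;

fn main() {
    // "tool_use" on the wire deserializes to StopReason::ToolUse.
    let reason: StopReason = serde_json::from_str("\"tool_use\"").unwrap();
    assert_eq!(reason, StopReason::ToolUse);

    // Serialization produces the same snake_case names.
    let json = serde_json::to_string(&StopReason::MaxTokens).unwrap();
    assert_eq!(json, "\"max_tokens\"");
}
```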

crates/language_model/src/provider/anthropic.rs

@@ -3,11 +3,12 @@ use crate::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
-use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
-use anthropic::AnthropicError;
+use crate::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
+use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
 use anyhow::{anyhow, Context as _, Result};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
 use gpui::{
     AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
@@ -17,11 +18,13 @@ use http_client::HttpClient;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
 use std::{sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{prelude::*, Icon, IconName, Tooltip};
-use util::ResultExt;
+use util::{maybe, ResultExt};
 
 const PROVIDER_ID: &str = "anthropic";
 const PROVIDER_NAME: &str = "Anthropic";
@@ -371,30 +374,9 @@ impl LanguageModel for AnthropicModel {
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request.await.map_err(|err| anyhow!(err))?;
-            Ok(anthropic::extract_content_from_events(response))
+            Ok(map_to_language_model_completion_events(response))
         });
-        async move {
-            Ok(future
-                .await?
-                .map(|result| {
-                    result
-                        .map(|content| match content {
-                            anthropic::ResponseContent::Text { text } => {
-                                LanguageModelCompletionEvent::Text(text)
-                            }
-                            anthropic::ResponseContent::ToolUse { id, name, input } => {
-                                LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
-                                    id,
-                                    name,
-                                    input,
-                                })
-                            }
-                        })
-                        .map_err(|err| anyhow!(err))
-                })
-                .boxed())
-        }
-        .boxed()
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
@@ -443,6 +425,120 @@ impl LanguageModel for AnthropicModel {
     }
 }
 
+pub fn map_to_language_model_completion_events(
+    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+    struct RawToolUse {
+        id: String,
+        name: String,
+        input_json: String,
+    }
+
+    struct State {
+        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+        tool_uses_by_index: HashMap<usize, RawToolUse>,
+    }
+
+    futures::stream::unfold(
+        State {
+            events,
+            tool_uses_by_index: HashMap::default(),
+        },
+        |mut state| async move {
+            while let Some(event) = state.events.next().await {
+                match event {
+                    Ok(event) => match event {
+                        Event::ContentBlockStart {
+                            index,
+                            content_block,
+                        } => match content_block {
+                            ResponseContent::Text { text } => {
+                                return Some((
+                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    state,
+                                ));
+                            }
+                            ResponseContent::ToolUse { id, name, .. } => {
+                                state.tool_uses_by_index.insert(
+                                    index,
+                                    RawToolUse {
+                                        id,
+                                        name,
+                                        input_json: String::new(),
+                                    },
+                                );
+
+                                return Some((None, state));
+                            }
+                        },
+                        Event::ContentBlockDelta { index, delta } => match delta {
+                            ContentDelta::TextDelta { text } => {
+                                return Some((
+                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    state,
+                                ));
+                            }
+                            ContentDelta::InputJsonDelta { partial_json } => {
+                                if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
+                                    tool_use.input_json.push_str(&partial_json);
+                                    return Some((None, state));
+                                }
+                            }
+                        },
+                        Event::ContentBlockStop { index } => {
+                            if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
+                                return Some((
+                                    Some(maybe!({
+                                        Ok(LanguageModelCompletionEvent::ToolUse(
+                                            LanguageModelToolUse {
+                                                id: tool_use.id,
+                                                name: tool_use.name,
+                                                input: serde_json::Value::from_str(
+                                                    &tool_use.input_json,
+                                                )
+                                                .map_err(|err| anyhow!(err))?,
+                                            },
+                                        ))
+                                    })),
+                                    state,
+                                ));
+                            }
+                        }
+                        Event::MessageDelta { delta, .. } => {
+                            if let Some(stop_reason) = delta.stop_reason.as_deref() {
+                                let stop_reason = match stop_reason {
+                                    "end_turn" => StopReason::EndTurn,
+                                    "max_tokens" => StopReason::MaxTokens,
+                                    "tool_use" => StopReason::ToolUse,
+                                    _ => StopReason::EndTurn,
+                                };
+
+                                return Some((
+                                    Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
+                                    state,
+                                ));
+                            }
+                        }
+                        Event::Error { error } => {
+                            return Some((
+                                Some(Err(anyhow!(AnthropicError::ApiError(error)))),
+                                state,
+                            ));
+                        }
+                        _ => {}
+                    },
+                    Err(err) => {
+                        return Some((Some(Err(anyhow!(err))), state));
+                    }
+                }
+            }
+
+            None
+        },
+    )
+    .filter_map(|event| async move { event })
+}
+
 struct ConfigurationView {
     api_key_editor: View<Editor>,
     state: gpui::Model<State>,

crates/language_model/src/provider/cloud.rs

@@ -1,4 +1,5 @@
 use super::open_ai::count_open_ai_tokens;
+use crate::provider::anthropic::map_to_language_model_completion_events;
 use crate::{
     settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
     LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
@@ -33,10 +34,7 @@ use std::{
 use strum::IntoEnumIterator;
 use ui::{prelude::*, TintColor};
 
-use crate::{
-    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
-    LanguageModelToolUse,
-};
+use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
 
 use super::anthropic::count_anthropic_tokens;
 
@@ -518,30 +516,11 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    Ok(anthropic::extract_content_from_events(Box::pin(
+                    Ok(map_to_language_model_completion_events(Box::pin(
                         response_lines(response).map_err(AnthropicError::Other),
                     )))
                 });
-                async move {
-                    Ok(future
-                        .await?
-                        .map(|result| {
-                            result
-                                .map(|content| match content {
-                                    anthropic::ResponseContent::Text { text } => {
-                                        LanguageModelCompletionEvent::Text(text)
-                                    }
-                                    anthropic::ResponseContent::ToolUse { id, name, input } => {
-                                        LanguageModelCompletionEvent::ToolUse(
-                                            LanguageModelToolUse { id, name, input },
-                                        )
-                                    }
-                                })
-                                .map_err(|err| anyhow!(err))
-                        })
-                        .boxed())
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();