Marshall Bowers created this pull request:
This PR adjusts the approach we use to encode tool uses in the completion
response, using a structured format rather than simply injecting them into
the response stream as text.
In #17170 we encoded the tool uses as XML and inserted them into the response
as text, which then required re-parsing the tool uses back out of the buffer
in order to use them.
The approach taken in this PR is to make `stream_completion` return a
stream of `LanguageModelCompletionEvent`s. Each of these events can be
either text or a tool use.
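To make the new shape concrete, here is a minimal, self-contained sketch of consuming the event stream. The enum mirrors the definition added in `crates/language_model/src/language_model.rs` below; the hard-coded events (and the `toolu_123` id and `now` tool) are hypothetical stand-ins for a real provider stream:

```rust
use futures::{stream, StreamExt};
use serde_json::json;

// Mirrors the enum added in crates/language_model/src/language_model.rs.
#[derive(Debug)]
enum LanguageModelCompletionEvent {
    Text(String),
    ToolUse(LanguageModelToolUse),
}

#[derive(Debug)]
struct LanguageModelToolUse {
    id: String,
    name: String,
    input: serde_json::Value,
}

fn main() {
    futures::executor::block_on(async {
        // Hypothetical stand-in for the stream returned by `stream_completion`.
        let mut events = stream::iter(vec![
            LanguageModelCompletionEvent::Text("Hello ".into()),
            LanguageModelCompletionEvent::Text("world".into()),
            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
                id: "toolu_123".into(),
                name: "now".into(),
                input: json!({ "timezone": "UTC" }),
            }),
        ]);

        while let Some(event) = events.next().await {
            match event {
                LanguageModelCompletionEvent::Text(chunk) => print!("{chunk}"),
                LanguageModelCompletionEvent::ToolUse(tool_use) => {
                    println!(
                        "\ntool use {}: {} with input {}",
                        tool_use.id, tool_use.name, tool_use.input
                    );
                }
            }
        }
    });
}
```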
A new `stream_completion_text` method has been added to `LanguageModel`
for scenarios where we only care about textual content (currently,
everywhere that isn't the Assistant context editor).
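Semantically, `stream_completion_text` is just a filter over the event stream. A rough sketch of what the default trait method does (the full version appears in the diff below), reusing the types from the sketch above:

```rust
use anyhow::Result;
use futures::{Stream, StreamExt};

// Keep only text chunks and errors; drop tool uses. This mirrors the
// default `stream_completion_text` method added to the `LanguageModel`
// trait below.
fn text_only(
    events: impl Stream<Item = Result<LanguageModelCompletionEvent>>,
) -> impl Stream<Item = Result<String>> {
    events.filter_map(|result| async move {
        match result {
            Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
            Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
            Err(err) => Some(Err(err)),
        }
    })
}
```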
Release Notes:
- N/A
Cargo.lock                                          |  1 +
crates/anthropic/Cargo.toml                         |  1 +
crates/anthropic/src/anthropic.rs                   | 92 ++++++++--------
crates/assistant/src/context.rs                     | 48 ++++++-
crates/assistant/src/inline_assistant.rs            |  2 +-
crates/assistant/src/terminal_inline_assistant.rs   |  2 +-
crates/language_model/src/language_model.rs         | 41 ++++++
crates/language_model/src/provider/anthropic.rs     | 20 +++
crates/language_model/src/provider/cloud.rs         | 46 ++++++-
crates/language_model/src/provider/copilot_chat.rs  | 12 +
crates/language_model/src/provider/fake.rs          | 15 +
crates/language_model/src/provider/google.rs        | 14 ++
crates/language_model/src/provider/ollama.rs        | 11 +
crates/language_model/src/provider/open_ai.rs       | 13 +
14 files changed, 235 insertions(+), 83 deletions(-)
Cargo.lock
@@ -243,6 +243,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
+ "collections",
"futures 0.3.30",
"http_client",
"isahc",
crates/anthropic/Cargo.toml
@@ -18,6 +18,7 @@ path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
chrono.workspace = true
+collections.workspace = true
futures.workspace = true
http_client.workspace = true
isahc.workspace = true
crates/anthropic/src/anthropic.rs
@@ -1,17 +1,19 @@
mod supported_countries;
+use std::time::Duration;
+use std::{pin::Pin, str::FromStr};
+
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
+use collections::HashMap;
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use isahc::http::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
-use std::time::Duration;
-use std::{pin::Pin, str::FromStr};
use strum::{EnumIter, EnumString};
use thiserror::Error;
-use util::ResultExt as _;
+use util::{maybe, ResultExt as _};
pub use supported_countries::*;
@@ -332,19 +334,22 @@ pub async fn stream_completion_with_rate_limit_info(
pub fn extract_content_from_events(
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<String, AnthropicError>> {
+) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
+ struct RawToolUse {
+ id: String,
+ name: String,
+ input_json: String,
+ }
+
struct State {
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
- current_tool_use_index: Option<usize>,
+ tool_uses_by_index: HashMap<usize, RawToolUse>,
}
- const INDENT: &str = " ";
- const NEWLINE: char = '\n';
-
futures::stream::unfold(
State {
events,
- current_tool_use_index: None,
+ tool_uses_by_index: HashMap::default(),
},
|mut state| async move {
while let Some(event) = state.events.next().await {
@@ -355,62 +360,56 @@ pub fn extract_content_from_events(
content_block,
} => match content_block {
ResponseContent::Text { text } => {
- return Some((Ok(text), state));
+ return Some((Some(Ok(ResponseContent::Text { text })), state));
}
ResponseContent::ToolUse { id, name, .. } => {
- state.current_tool_use_index = Some(index);
-
- let mut text = String::new();
- text.push(NEWLINE);
-
- text.push_str("<tool_use>");
- text.push(NEWLINE);
-
- text.push_str(INDENT);
- text.push_str("<id>");
- text.push_str(&id);
- text.push_str("</id>");
- text.push(NEWLINE);
-
- text.push_str(INDENT);
- text.push_str("<name>");
- text.push_str(&name);
- text.push_str("</name>");
- text.push(NEWLINE);
-
- text.push_str(INDENT);
- text.push_str("<input>");
-
- return Some((Ok(text), state));
+ state.tool_uses_by_index.insert(
+ index,
+ RawToolUse {
+ id,
+ name,
+ input_json: String::new(),
+ },
+ );
+
+ return Some((None, state));
}
},
Event::ContentBlockDelta { index, delta } => match delta {
ContentDelta::TextDelta { text } => {
- return Some((Ok(text), state));
+ return Some((Some(Ok(ResponseContent::Text { text })), state));
}
ContentDelta::InputJsonDelta { partial_json } => {
- if Some(index) == state.current_tool_use_index {
- return Some((Ok(partial_json), state));
+ if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
+ tool_use.input_json.push_str(&partial_json);
+ return Some((None, state));
}
}
},
Event::ContentBlockStop { index } => {
- if Some(index) == state.current_tool_use_index.take() {
- let mut text = String::new();
- text.push_str("</input>");
- text.push(NEWLINE);
- text.push_str("</tool_use>");
-
- return Some((Ok(text), state));
+ if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
+ return Some((
+ Some(maybe!({
+ Ok(ResponseContent::ToolUse {
+ id: tool_use.id,
+ name: tool_use.name,
+ input: serde_json::Value::from_str(
+ &tool_use.input_json,
+ )
+ .map_err(|err| anyhow!(err))?,
+ })
+ })),
+ state,
+ ));
}
}
Event::Error { error } => {
- return Some((Err(AnthropicError::ApiError(error)), state));
+ return Some((Some(Err(AnthropicError::ApiError(error))), state));
}
_ => {}
},
Err(err) => {
- return Some((Err(err), state));
+ return Some((Some(Err(err)), state));
}
}
}
@@ -418,6 +417,7 @@ pub fn extract_content_from_events(
None
},
)
+ .filter_map(|event| async move { event })
}
pub async fn extract_tool_args_from_events(
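A note on the pattern above: the `unfold` state machine now yields `Option<Result<ResponseContent, _>>` so it can consume events without emitting anything (tool-use starts and `InputJsonDelta` chunks are buffered into `tool_uses_by_index`), and the trailing `filter_map` flattens the `None` placeholders back out. A self-contained toy version of that same shape, with integers standing in for Anthropic events:

```rust
use futures::{stream, StreamExt};

fn main() {
    futures::executor::block_on(async {
        // Toy input: integers stand in for Anthropic events.
        let events = stream::iter(vec![1, 2, 3, 4]);

        // `unfold` drives the inner stream and yields Option<i32>:
        // `None` means "consumed an event, but nothing to emit yet",
        // like a buffered InputJsonDelta chunk in the real code.
        let out = stream::unfold(events, |mut events| async move {
            match events.next().await {
                Some(n) if n % 2 == 0 => Some((Some(n), events)),
                Some(_) => Some((None, events)),
                None => None,
            }
        })
        // Flatten the `None` placeholders away, as the real code does
        // with `.filter_map(|event| async move { event })`.
        .filter_map(|item| async move { item });

        assert_eq!(out.collect::<Vec<_>>().await, vec![2, 4]);
    });
}
```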
crates/assistant/src/context.rs
@@ -25,8 +25,9 @@ use gpui::{
use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
use language_model::{
- LanguageModel, LanguageModelCacheConfiguration, LanguageModelImage, LanguageModelRegistry,
- LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
+ LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
+ LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+ MessageContent, Role,
};
use open_ai::Model as OpenAiModel;
use paths::{context_images_dir, contexts_dir};
@@ -1950,13 +1951,13 @@ impl Context {
let mut response_latency = None;
let stream_completion = async {
let request_start = Instant::now();
- let mut chunks = stream.await?;
+ let mut events = stream.await?;
- while let Some(chunk) = chunks.next().await {
+ while let Some(event) = events.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
- let chunk = chunk?;
+ let event = event?;
this.update(&mut cx, |this, cx| {
let message_ix = this
@@ -1970,11 +1971,36 @@ impl Context {
.map_or(buffer.len(), |message| {
message.start.to_offset(buffer).saturating_sub(1)
});
- buffer.edit(
- [(message_old_end_offset..message_old_end_offset, chunk)],
- None,
- cx,
- );
+
+ match event {
+ LanguageModelCompletionEvent::Text(chunk) => {
+ buffer.edit(
+ [(
+ message_old_end_offset..message_old_end_offset,
+ chunk,
+ )],
+ None,
+ cx,
+ );
+ }
+ LanguageModelCompletionEvent::ToolUse(tool_use) => {
+ let mut text = String::new();
+ text.push('\n');
+ text.push_str(
+ &serde_json::to_string_pretty(&tool_use)
+ .expect("failed to serialize tool use to JSON"),
+ );
+
+ buffer.edit(
+ [(
+ message_old_end_offset..message_old_end_offset,
+ text,
+ )],
+ None,
+ cx,
+ );
+ }
+ }
});
cx.emit(ContextEvent::StreamedCompletion);
@@ -2406,7 +2432,7 @@ impl Context {
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
- let stream = model.stream_completion(request, &cx);
+ let stream = model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
let mut replaced = !replace_old;
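For reference, when the context editor receives a `ToolUse` event it inserts the tool use into the buffer as pretty-printed JSON (see the `ToolUse` arm above). Given the `Serialize` derive on `LanguageModelToolUse`, the inserted text looks like this sketch's output (id/name/input values are hypothetical):

```rust
use serde::Serialize;
use serde_json::json;

// Mirrors the struct added in crates/language_model/src/language_model.rs.
#[derive(Serialize)]
struct LanguageModelToolUse {
    id: String,
    name: String,
    input: serde_json::Value,
}

fn main() {
    let tool_use = LanguageModelToolUse {
        id: "toolu_123".into(),
        name: "now".into(),
        input: json!({ "timezone": "UTC" }),
    };

    // Prints the same text the context editor inserts:
    // {
    //   "id": "toolu_123",
    //   "name": "now",
    //   "input": {
    //     "timezone": "UTC"
    //   }
    // }
    println!("{}", serde_json::to_string_pretty(&tool_use).unwrap());
}
```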
crates/assistant/src/inline_assistant.rs
@@ -2344,7 +2344,7 @@ impl Codegen {
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
let chunks =
- cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
+ cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await });
async move { Ok(chunks.await?.boxed()) }.boxed_local()
};
self.handle_stream(telemetry_id, edit_range, chunks, cx);
crates/assistant/src/terminal_inline_assistant.rs
@@ -1010,7 +1010,7 @@ impl Codegen {
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
self.generation = cx.spawn(|this, mut cx| async move {
let model_telemetry_id = model.telemetry_id();
- let response = model.stream_completion(prompt, &cx).await;
+ let response = model.stream_completion_text(prompt, &cx).await;
let generate = async {
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
crates/language_model/src/language_model.rs
@@ -8,7 +8,8 @@ pub mod settings;
use anyhow::Result;
use client::{Client, UserStore};
-use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
+use futures::FutureExt;
+use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
};
@@ -51,6 +52,20 @@ pub struct LanguageModelCacheConfiguration {
pub min_total_token: usize,
}
+/// A completion event from a language model.
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+pub enum LanguageModelCompletionEvent {
+ Text(String),
+ ToolUse(LanguageModelToolUse),
+}
+
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+pub struct LanguageModelToolUse {
+ pub id: String,
+ pub name: String,
+ pub input: serde_json::Value,
+}
+
pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
@@ -82,7 +97,29 @@ pub trait LanguageModel: Send + Sync {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
+
+ fn stream_completion_text(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let events = self.stream_completion(request, cx);
+
+ async move {
+ Ok(events
+ .await?
+ .filter_map(|result| async move {
+ match result {
+ Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
+ Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
+ Err(err) => Some(Err(err)),
+ }
+ })
+ .boxed())
+ }
+ .boxed()
+ }
fn use_any_tool(
&self,
crates/language_model/src/provider/anthropic.rs
@@ -3,6 +3,7 @@ use crate::{
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
+use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
use anthropic::AnthropicError;
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
@@ -364,7 +365,7 @@ impl LanguageModel for AnthropicModel {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let request =
request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
let request = self.stream_completion(request, cx);
@@ -375,7 +376,22 @@ impl LanguageModel for AnthropicModel {
async move {
Ok(future
.await?
- .map(|result| result.map_err(|err| anyhow!(err)))
+ .map(|result| {
+ result
+ .map(|content| match content {
+ anthropic::ResponseContent::Text { text } => {
+ LanguageModelCompletionEvent::Text(text)
+ }
+ anthropic::ResponseContent::ToolUse { id, name, input } => {
+ LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+ id,
+ name,
+ input,
+ })
+ }
+ })
+ .map_err(|err| anyhow!(err))
+ })
.boxed())
}
.boxed()
crates/language_model/src/provider/cloud.rs
@@ -33,7 +33,10 @@ use std::{
use strum::IntoEnumIterator;
use ui::{prelude::*, TintColor};
-use crate::{LanguageModelAvailability, LanguageModelProvider};
+use crate::{
+ LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
+ LanguageModelToolUse,
+};
use super::anthropic::count_anthropic_tokens;
@@ -496,7 +499,7 @@ impl LanguageModel for CloudLanguageModel {
&self,
request: LanguageModelRequest,
_cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
match &self.model {
CloudModel::Anthropic(model) => {
let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
@@ -522,7 +525,20 @@ impl LanguageModel for CloudLanguageModel {
async move {
Ok(future
.await?
- .map(|result| result.map_err(|err| anyhow!(err)))
+ .map(|result| {
+ result
+ .map(|content| match content {
+ anthropic::ResponseContent::Text { text } => {
+ LanguageModelCompletionEvent::Text(text)
+ }
+ anthropic::ResponseContent::ToolUse { id, name, input } => {
+ LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse { id, name, input },
+ )
+ }
+ })
+ .map_err(|err| anyhow!(err))
+ })
.boxed())
}
.boxed()
@@ -546,7 +562,13 @@ impl LanguageModel for CloudLanguageModel {
.await?;
Ok(open_ai::extract_text_from_events(response_lines(response)))
});
- async move { Ok(future.await?.boxed()) }.boxed()
+ async move {
+ Ok(future
+ .await?
+ .map(|result| result.map(LanguageModelCompletionEvent::Text))
+ .boxed())
+ }
+ .boxed()
}
CloudModel::Google(model) => {
let client = self.client.clone();
@@ -569,7 +591,13 @@ impl LanguageModel for CloudLanguageModel {
response,
)))
});
- async move { Ok(future.await?.boxed()) }.boxed()
+ async move {
+ Ok(future
+ .await?
+ .map(|result| result.map(LanguageModelCompletionEvent::Text))
+ .boxed())
+ }
+ .boxed()
}
CloudModel::Zed(model) => {
let client = self.client.clone();
@@ -591,7 +619,13 @@ impl LanguageModel for CloudLanguageModel {
.await?;
Ok(open_ai::extract_text_from_events(response_lines(response)))
});
- async move { Ok(future.await?.boxed()) }.boxed()
+ async move {
+ Ok(future
+ .await?
+ .map(|result| result.map(LanguageModelCompletionEvent::Text))
+ .boxed())
+ }
+ .boxed()
}
}
}
crates/language_model/src/provider/copilot_chat.rs
@@ -24,11 +24,11 @@ use ui::{
};
use crate::settings::AllLanguageModelSettings;
-use crate::LanguageModelProviderState;
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
};
+use crate::{LanguageModelCompletionEvent, LanguageModelProviderState};
use super::open_ai::count_open_ai_tokens;
@@ -192,7 +192,7 @@ impl LanguageModel for CopilotChatLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
if let Some(message) = request.messages.last() {
if message.contents_empty() {
const EMPTY_PROMPT_MSG: &str =
@@ -243,7 +243,13 @@ impl LanguageModel for CopilotChatLanguageModel {
}).await
});
- async move { Ok(future.await?.boxed()) }.boxed()
+ async move {
+ Ok(future
+ .await?
+ .map(|result| result.map(LanguageModelCompletionEvent::Text))
+ .boxed())
+ }
+ .boxed()
}
fn use_any_tool(
crates/language_model/src/provider/fake.rs
@@ -1,7 +1,7 @@ use crate::{
use crate::{
- LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
- LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
- LanguageModelRequest,
+ LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
+ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest,
};
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
@@ -170,10 +170,15 @@ impl LanguageModel for FakeLanguageModel {
&self,
request: LanguageModelRequest,
_: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
- async move { Ok(rx.map(Ok).boxed()) }.boxed()
+ async move {
+ Ok(rx
+ .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
+ .boxed())
+ }
+ .boxed()
}
fn use_any_tool(
crates/language_model/src/provider/google.rs
@@ -17,6 +17,7 @@ use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
use util::ResultExt;
+use crate::LanguageModelCompletionEvent;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
@@ -281,7 +282,10 @@ impl LanguageModel for GoogleLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
+ ) -> BoxFuture<
+ 'static,
+ Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
+ > {
let request = request.into_google(self.model.id().to_string());
let http_client = self.http_client.clone();
@@ -299,7 +303,13 @@ impl LanguageModel for GoogleLanguageModel {
let events = response.await?;
Ok(google_ai::extract_text_from_events(events).boxed())
});
- async move { Ok(future.await?.boxed()) }.boxed()
+ async move {
+ Ok(future
+ .await?
+ .map(|result| result.map(LanguageModelCompletionEvent::Text))
+ .boxed())
+ }
+ .boxed()
}
fn use_any_tool(
crates/language_model/src/provider/ollama.rs
@@ -13,6 +13,7 @@ use std::{collections::BTreeMap, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;
+use crate::LanguageModelCompletionEvent;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
@@ -302,7 +303,7 @@ impl LanguageModel for OllamaLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
@@ -335,7 +336,13 @@ impl LanguageModel for OllamaLanguageModel {
Ok(stream)
});
- async move { Ok(future.await?.boxed()) }.boxed()
+ async move {
+ Ok(future
+ .await?
+ .map(|result| result.map(LanguageModelCompletionEvent::Text))
+ .boxed())
+ }
+ .boxed()
}
fn use_any_tool(
crates/language_model/src/provider/open_ai.rs
@@ -19,6 +19,7 @@ use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
use util::ResultExt;
+use crate::LanguageModelCompletionEvent;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
@@ -293,10 +294,18 @@ impl LanguageModel for OpenAiLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
+ ) -> BoxFuture<
+ 'static,
+ Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
+ > {
let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
let completions = self.stream_completion(request, cx);
- async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
+ async move {
+ Ok(open_ai::extract_text_from_events(completions.await?)
+ .map(|result| result.map(LanguageModelCompletionEvent::Text))
+ .boxed())
+ }
+ .boxed()
}
fn use_any_tool(
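One observation on the diff as a whole: every text-only provider (OpenAI, Google, Ollama, Copilot Chat, and the non-Anthropic cloud models) adapts its existing text stream with the same `.map(|result| result.map(LanguageModelCompletionEvent::Text))` incantation. If more providers land, a small helper in the `language_model` crate could factor that out; a hypothetical sketch (`into_completion_events` is not part of this PR):

```rust
use anyhow::Result;
use futures::{stream::BoxStream, StreamExt};

use crate::LanguageModelCompletionEvent;

/// Hypothetical helper (not in this PR): adapt a plain text-chunk
/// stream into the new completion-event stream shape.
pub fn into_completion_events(
    text: BoxStream<'static, Result<String>>,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent>> {
    text.map(|result| result.map(LanguageModelCompletionEvent::Text))
        .boxed()
}
```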