Detailed changes
@@ -5,8 +5,8 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
-use std::str::FromStr;
use std::time::Duration;
+use std::{pin::Pin, str::FromStr};
use strum::{EnumIter, EnumString};
use thiserror::Error;
@@ -241,6 +241,50 @@ pub fn extract_text_from_events(
})
}
+pub async fn extract_tool_args_from_events(
+ tool_name: String,
+ mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
+) -> Result<impl Send + Stream<Item = Result<String>>> {
+ let mut tool_use_index = None;
+ while let Some(event) = events.next().await {
+ if let Event::ContentBlockStart {
+ index,
+ content_block,
+ } = event?
+ {
+ if let Content::ToolUse { name, .. } = content_block {
+ if name == tool_name {
+ tool_use_index = Some(index);
+ break;
+ }
+ }
+ }
+ }
+
+ let Some(tool_use_index) = tool_use_index else {
+ return Err(anyhow!("tool not used"));
+ };
+
+ Ok(events.filter_map(move |event| {
+ let result = match event {
+ Err(error) => Some(Err(error)),
+ Ok(Event::ContentBlockDelta { index, delta }) => match delta {
+ ContentDelta::TextDelta { .. } => None,
+ ContentDelta::InputJsonDelta { partial_json } => {
+ if index == tool_use_index {
+ Some(Ok(partial_json))
+ } else {
+ None
+ }
+ }
+ },
+ _ => None,
+ };
+
+ async move { result }
+ }))
+}
+
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
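Note: `extract_tool_args_from_events` drains the event stream until it sees the `ContentBlockStart` announcing the requested tool, then yields that content block's raw `input_json` fragments as they arrive. A minimal consumption sketch (the tool name is illustrative; the typed `use_tool` wrapper further down does the same collect-then-parse):

    use futures::StreamExt;

    // Accumulate the partial-JSON chunks, then parse once the stream ends.
    let mut chunks = extract_tool_args_from_events("my_tool".into(), events)
        .await?
        .boxed();
    let mut args = String::new();
    while let Some(chunk) = chunks.next().await {
        args.push_str(&chunk?);
    }
    let input: serde_json::Value = serde_json::from_str(&args)?;
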
@@ -1,6 +1,6 @@
use crate::{
- prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion,
- InlineAssistId, InlineAssistant, MessageId, MessageStatus,
+ prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InlineAssistId,
+ InlineAssistant, MessageId, MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
@@ -3342,7 +3342,7 @@ mod tests {
model
.as_fake()
- .respond_to_last_tool_use(Ok(serde_json::to_value(tool::WorkflowStepResolution {
+ .respond_to_last_tool_use(tool::WorkflowStepResolution {
step_title: "Title".into(),
suggestions: vec![tool::WorkflowSuggestion {
path: "/root/hello.rs".into(),
@@ -3352,8 +3352,7 @@ mod tests {
description: "Extract a greeting function".into(),
},
}],
- })
- .unwrap()));
+ });
// Wait for tool use to be processed.
cx.run_until_parked();
@@ -4084,44 +4083,4 @@ mod tool {
symbol: String,
},
}
-
- impl WorkflowSuggestionKind {
- pub fn symbol(&self) -> Option<&str> {
- match self {
- Self::Update { symbol, .. } => Some(symbol),
- Self::InsertSiblingBefore { symbol, .. } => Some(symbol),
- Self::InsertSiblingAfter { symbol, .. } => Some(symbol),
- Self::PrependChild { symbol, .. } => symbol.as_deref(),
- Self::AppendChild { symbol, .. } => symbol.as_deref(),
- Self::Delete { symbol } => Some(symbol),
- Self::Create { .. } => None,
- }
- }
-
- pub fn description(&self) -> Option<&str> {
- match self {
- Self::Update { description, .. } => Some(description),
- Self::Create { description } => Some(description),
- Self::InsertSiblingBefore { description, .. } => Some(description),
- Self::InsertSiblingAfter { description, .. } => Some(description),
- Self::PrependChild { description, .. } => Some(description),
- Self::AppendChild { description, .. } => Some(description),
- Self::Delete { .. } => None,
- }
- }
-
- pub fn initial_insertion(&self) -> Option<InitialInsertion> {
- match self {
- WorkflowSuggestionKind::InsertSiblingBefore { .. } => {
- Some(InitialInsertion::NewlineAfter)
- }
- WorkflowSuggestionKind::InsertSiblingAfter { .. } => {
- Some(InitialInsertion::NewlineBefore)
- }
- WorkflowSuggestionKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
- WorkflowSuggestionKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
- _ => None,
- }
- }
- }
}
@@ -1280,12 +1280,6 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
})
}
-#[derive(Copy, Clone, Debug, Eq, PartialEq)]
-pub enum InitialInsertion {
- NewlineBefore,
- NewlineAfter,
-}
-
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);
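Note: this removal pairs with the `WorkflowSuggestionKind` helper deletions above; `initial_insertion` was the last producer of `InitialInsertion` (the import hunk at the top of this section drops it from the assistant context module as well), so the enum goes too.
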
@@ -351,10 +351,13 @@ impl Asset for ImageAsset {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
if !response.status().is_success() {
+ let mut body = String::from_utf8_lossy(&body).into_owned();
+ let first_line = body.lines().next().unwrap_or("").trim_end();
+ body.truncate(first_line.len());
return Err(ImageCacheError::BadStatus {
uri,
status: response.status(),
- body: String::from_utf8_lossy(&body).into_owned(),
+ body,
});
}
body
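Note: failed image fetches can return whole HTML error pages; truncating the lossily-decoded body to its first line in place keeps the `BadStatus` message readable without allocating a second string.
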
@@ -8,7 +8,7 @@ pub mod settings;
use anyhow::Result;
use client::{Client, UserStore};
-use futures::{future::BoxFuture, stream::BoxStream};
+use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
};
@@ -76,7 +76,7 @@ pub trait LanguageModel: Send + Sync {
description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>>;
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
#[cfg(any(test, feature = "test-support"))]
fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
@@ -92,10 +92,11 @@ impl dyn LanguageModel {
) -> impl 'static + Future<Output = Result<T>> {
let schema = schemars::schema_for!(T);
let schema_json = serde_json::to_value(&schema).unwrap();
- let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
+ let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
async move {
- let response = request.await?;
- Ok(serde_json::from_value(response)?)
+ let stream = stream.await?;
+ let response = stream.try_collect::<String>().await?;
+ Ok(serde_json::from_str(&response)?)
}
}
}
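Note: `use_any_tool` now resolves to a stream of raw argument chunks instead of a single `serde_json::Value`; the typed `use_tool` wrapper just concatenates them with `try_collect::<String>()` and deserializes at the end. A caller that wants incremental output can consume the stream directly, along these lines (a sketch only):

    use futures::TryStreamExt as _;

    let stream = model
        .use_any_tool(request, T::name(), T::description(), schema_json, cx)
        .await?;
    // Forward each raw chunk as it arrives instead of waiting for the end.
    stream
        .try_for_each(|chunk| async move {
            print!("{chunk}");
            Ok(())
        })
        .await?;
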
@@ -7,7 +7,7 @@ use anthropic::AnthropicError;
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
-use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
use gpui::{
AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
View, WhiteSpace,
@@ -264,29 +264,6 @@ pub fn count_anthropic_tokens(
}
impl AnthropicModel {
- fn request_completion(
- &self,
- request: anthropic::Request,
- cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<anthropic::Response>> {
- let http_client = self.http_client.clone();
-
- let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
- (state.api_key.clone(), settings.api_url.clone())
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
- };
-
- async move {
- let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
- anthropic::complete(http_client.as_ref(), &api_url, &api_key, request)
- .await
- .context("failed to retrieve completion")
- }
- .boxed()
- }
-
fn stream_completion(
&self,
request: anthropic::Request,
@@ -381,7 +358,7 @@ impl LanguageModel for AnthropicModel {
tool_description: String,
input_schema: serde_json::Value,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let mut request = request.into_anthropic(self.model.tool_model_id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
@@ -392,25 +369,16 @@ impl LanguageModel for AnthropicModel {
input_schema,
}];
- let response = self.request_completion(request, cx);
+ let response = self.stream_completion(request, cx);
self.request_limiter
.run(async move {
let response = response.await?;
- response
- .content
- .into_iter()
- .find_map(|content| {
- if let anthropic::Content::ToolUse { name, input, .. } = content {
- if name == tool_name {
- Some(input)
- } else {
- None
- }
- } else {
- None
- }
- })
- .context("tool not used")
+ Ok(anthropic::extract_tool_args_from_events(
+ tool_name,
+ Box::pin(response.map_err(|e| anyhow!(e))),
+ )
+ .await?
+ .boxed())
})
.boxed()
}
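Note: the non-streaming `request_completion` path (and its `anthropic::complete` call) is gone entirely; tool use now rides the same `stream_completion` pipeline, with `map_err(|e| anyhow!(e))` adapting `AnthropicError` items into the `anyhow::Result`s that `extract_tool_args_from_events` expects.
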
@@ -5,18 +5,21 @@ use crate::{
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anthropic::AnthropicError;
-use anyhow::{anyhow, bail, Context as _, Result};
+use anyhow::{anyhow, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, ZedPro};
-use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
+use futures::{
+ future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
+ TryStreamExt as _,
+};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response};
use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
@@ -451,21 +454,9 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- let body = BufReader::new(response.into_body());
- let stream = futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: anthropic::Event = serde_json::from_str(&buffer)
- .context("failed to parse Anthropic event")?;
- Ok(Some((event, body)))
- }
- Err(err) => Err(AnthropicError::Other(err.into())),
- }
- });
-
- Ok(anthropic::extract_text_from_events(stream))
+ Ok(anthropic::extract_text_from_events(
+ response_lines(response).map_err(AnthropicError::Other),
+ ))
});
async move {
Ok(future
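Note: this and the next three hunks replace four copies of the same hand-rolled `try_unfold` line reader with the shared `response_lines` helper added near the end of this file; only the Anthropic call site needs the extra `map_err`, because its text extractor works in `AnthropicError` rather than `anyhow::Error`.
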
@@ -492,21 +483,7 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- let body = BufReader::new(response.into_body());
- let stream = futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: open_ai::ResponseStreamEvent =
- serde_json::from_str(&buffer)?;
- Ok(Some((event, body)))
- }
- Err(e) => Err(e.into()),
- }
- });
-
- Ok(open_ai::extract_text_from_events(stream))
+ Ok(open_ai::extract_text_from_events(response_lines(response)))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -527,21 +504,9 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- let body = BufReader::new(response.into_body());
- let stream = futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: google_ai::GenerateContentResponse =
- serde_json::from_str(&buffer)?;
- Ok(Some((event, body)))
- }
- Err(e) => Err(e.into()),
- }
- });
-
- Ok(google_ai::extract_text_from_events(stream))
+ Ok(google_ai::extract_text_from_events(response_lines(
+ response,
+ )))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -563,21 +528,7 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- let body = BufReader::new(response.into_body());
- let stream = futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: open_ai::ResponseStreamEvent =
- serde_json::from_str(&buffer)?;
- Ok(Some((event, body)))
- }
- Err(e) => Err(e.into()),
- }
- });
-
- Ok(open_ai::extract_text_from_events(stream))
+ Ok(open_ai::extract_text_from_events(response_lines(response)))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -591,10 +542,12 @@ impl LanguageModel for CloudLanguageModel {
tool_description: String,
input_schema: serde_json::Value,
_cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let client = self.client.clone();
+ let llm_api_token = self.llm_api_token.clone();
+
match &self.model {
CloudModel::Anthropic(model) => {
- let client = self.client.clone();
let mut request = request.into_anthropic(model.tool_model_id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
@@ -605,7 +558,6 @@ impl LanguageModel for CloudLanguageModel {
input_schema,
}];
- let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
@@ -621,70 +573,34 @@ impl LanguageModel for CloudLanguageModel {
)
.await?;
- let mut tool_use_index = None;
- let mut tool_input = String::new();
- let mut body = BufReader::new(response.into_body());
- let mut line = String::new();
- while body.read_line(&mut line).await? > 0 {
- let event: anthropic::Event = serde_json::from_str(&line)?;
- line.clear();
-
- match event {
- anthropic::Event::ContentBlockStart {
- content_block,
- index,
- } => {
- if let anthropic::Content::ToolUse { name, .. } = content_block
- {
- if name == tool_name {
- tool_use_index = Some(index);
- }
- }
- }
- anthropic::Event::ContentBlockDelta { index, delta } => match delta
- {
- anthropic::ContentDelta::TextDelta { .. } => {}
- anthropic::ContentDelta::InputJsonDelta { partial_json } => {
- if Some(index) == tool_use_index {
- tool_input.push_str(&partial_json);
- }
- }
- },
- anthropic::Event::ContentBlockStop { index } => {
- if Some(index) == tool_use_index {
- return Ok(serde_json::from_str(&tool_input)?);
- }
- }
- _ => {}
- }
- }
-
- if tool_use_index.is_some() {
- Err(anyhow!("tool content incomplete"))
- } else {
- Err(anyhow!("tool not used"))
- }
+ Ok(anthropic::extract_tool_args_from_events(
+ tool_name,
+ Box::pin(response_lines(response)),
+ )
+ .await?
+ .boxed())
})
.boxed()
}
CloudModel::OpenAi(model) => {
let mut request = request.into_open_ai(model.id().into());
- let client = self.client.clone();
- let mut function = open_ai::FunctionDefinition {
- name: tool_name.clone(),
- description: None,
- parameters: None,
- };
- let func = open_ai::ToolDefinition::Function {
- function: function.clone(),
- };
- request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
- // Fill in description and params separately, as they're not needed for tool_choice field.
- function.description = Some(tool_description);
- function.parameters = Some(input_schema);
- request.tools = vec![open_ai::ToolDefinition::Function { function }];
+ request.tool_choice = Some(open_ai::ToolChoice::Other(
+ open_ai::ToolDefinition::Function {
+ function: open_ai::FunctionDefinition {
+ name: tool_name.clone(),
+ description: None,
+ parameters: None,
+ },
+ },
+ ));
+ request.tools = vec![open_ai::ToolDefinition::Function {
+ function: open_ai::FunctionDefinition {
+ name: tool_name.clone(),
+ description: Some(tool_description),
+ parameters: Some(input_schema),
+ },
+ }];
- let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
@@ -700,41 +616,12 @@ impl LanguageModel for CloudLanguageModel {
)
.await?;
- let mut body = BufReader::new(response.into_body());
- let mut line = String::new();
- let mut load_state = None;
-
- while body.read_line(&mut line).await? > 0 {
- let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
- line.clear();
-
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
- }
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
- }
- }
- }
- }
- }
- }
-
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
+ Ok(open_ai::extract_tool_args_from_events(
+ tool_name,
+ Box::pin(response_lines(response)),
+ )
+ .await?
+ .boxed())
})
.boxed()
}
@@ -744,22 +631,23 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::Zed(model) => {
// All Zed models are OpenAI-based at the time of writing.
let mut request = request.into_open_ai(model.id().into());
- let client = self.client.clone();
- let mut function = open_ai::FunctionDefinition {
- name: tool_name.clone(),
- description: None,
- parameters: None,
- };
- let func = open_ai::ToolDefinition::Function {
- function: function.clone(),
- };
- request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
- // Fill in description and params separately, as they're not needed for tool_choice field.
- function.description = Some(tool_description);
- function.parameters = Some(input_schema);
- request.tools = vec![open_ai::ToolDefinition::Function { function }];
+ request.tool_choice = Some(open_ai::ToolChoice::Other(
+ open_ai::ToolDefinition::Function {
+ function: open_ai::FunctionDefinition {
+ name: tool_name.clone(),
+ description: None,
+ parameters: None,
+ },
+ },
+ ));
+ request.tools = vec![open_ai::ToolDefinition::Function {
+ function: open_ai::FunctionDefinition {
+ name: tool_name.clone(),
+ description: Some(tool_description),
+ parameters: Some(input_schema),
+ },
+ }];
- let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
@@ -775,40 +663,12 @@ impl LanguageModel for CloudLanguageModel {
)
.await?;
- let mut body = BufReader::new(response.into_body());
- let mut line = String::new();
- let mut load_state = None;
-
- while body.read_line(&mut line).await? > 0 {
- let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
- line.clear();
-
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
- }
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
- }
- }
- }
- }
- }
- }
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
+ Ok(open_ai::extract_tool_args_from_events(
+ tool_name,
+ Box::pin(response_lines(response)),
+ )
+ .await?
+ .boxed())
})
.boxed()
}
@@ -816,6 +676,25 @@ impl LanguageModel for CloudLanguageModel {
}
}
+fn response_lines<T: DeserializeOwned>(
+ response: Response<AsyncBody>,
+) -> impl Stream<Item = Result<T>> {
+ futures::stream::try_unfold(
+ (String::new(), BufReader::new(response.into_body())),
+ move |(mut line, mut body)| async {
+ match body.read_line(&mut line).await {
+ Ok(0) => Ok(None),
+ Ok(_) => {
+ let event: T = serde_json::from_str(&line)?;
+ line.clear();
+ Ok(Some((event, (line, body))))
+ }
+ Err(e) => Err(e.into()),
+ }
+ },
+ )
+}
+
impl LlmApiToken {
async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
let lock = self.0.upgradable_read().await;
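Note: `response_lines` reads the streamed body one newline-delimited JSON document at a time, deserializing each line into whatever `T` the call site infers, and reuses a single `String` buffer by threading it through the unfold state. A call-site sketch mirroring the Anthropic hunk above:

    // T is inferred as anthropic::Event from what the extractor consumes.
    let events = response_lines::<anthropic::Event>(response);
    let text_stream = anthropic::extract_text_from_events(
        events.map_err(AnthropicError::Other),
    );
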
@@ -252,7 +252,7 @@ impl LanguageModel for CopilotChatLanguageModel {
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}
@@ -3,16 +3,11 @@ use crate::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest,
};
-use anyhow::Context as _;
-use futures::{
- channel::{mpsc, oneshot},
- future::BoxFuture,
- stream::BoxStream,
- FutureExt, StreamExt,
-};
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
use http_client::Result;
use parking_lot::Mutex;
+use serde::Serialize;
use std::sync::Arc;
use ui::WindowContext;
@@ -90,7 +85,7 @@ pub struct ToolUseRequest {
#[derive(Default)]
pub struct FakeLanguageModel {
current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
- current_tool_use_txs: Mutex<Vec<(ToolUseRequest, oneshot::Sender<Result<serde_json::Value>>)>>,
+ current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
}
impl FakeLanguageModel {
@@ -130,25 +125,11 @@ impl FakeLanguageModel {
self.end_completion_stream(self.pending_completions().last().unwrap());
}
- pub fn respond_to_tool_use(
- &self,
- tool_call: &ToolUseRequest,
- response: Result<serde_json::Value>,
- ) {
- let mut current_tool_call_txs = self.current_tool_use_txs.lock();
- if let Some(index) = current_tool_call_txs
- .iter()
- .position(|(call, _)| call == tool_call)
- {
- let (_, tx) = current_tool_call_txs.remove(index);
- tx.send(response).unwrap();
- }
- }
-
- pub fn respond_to_last_tool_use(&self, response: Result<serde_json::Value>) {
+ pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
+ let response = serde_json::to_string(&response).unwrap();
let mut current_tool_call_txs = self.current_tool_use_txs.lock();
let (_, tx) = current_tool_call_txs.pop().unwrap();
- tx.send(response).unwrap();
+ tx.unbounded_send(response).unwrap();
}
}
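Note: `respond_to_last_tool_use` now takes any `Serialize` value and sends its JSON down the new unbounded channel as a single chunk, which is what lets the test hunk above pass `tool::WorkflowStepResolution { .. }` without the `serde_json::to_value(..).unwrap()` ceremony. For example (with a hypothetical output type):

    #[derive(serde::Serialize)]
    struct Output { answer: u32 } // hypothetical tool output

    model.as_fake().respond_to_last_tool_use(Output { answer: 42 });
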
@@ -202,8 +183,8 @@ impl LanguageModel for FakeLanguageModel {
description: String,
schema: serde_json::Value,
_cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
- let (tx, rx) = oneshot::channel();
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let (tx, rx) = mpsc::unbounded();
let tool_call = ToolUseRequest {
request,
name,
@@ -211,7 +192,7 @@ impl LanguageModel for FakeLanguageModel {
schema,
};
self.current_tool_use_txs.lock().push((tool_call, tx));
- async move { rx.await.context("FakeLanguageModel was dropped")? }.boxed()
+ async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
fn as_fake(&self) -> &Self {
@@ -302,7 +302,7 @@ impl LanguageModel for GoogleLanguageModel {
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}
@@ -6,7 +6,6 @@ use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
ChatResponseDelta, OllamaToolCall,
};
-use serde_json::Value;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
@@ -311,7 +310,7 @@ impl LanguageModel for OllamaLanguageModel {
tool_description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
use ollama::{OllamaFunctionTool, OllamaTool};
let function = OllamaFunctionTool {
name: tool_name.clone(),
@@ -324,23 +323,19 @@ impl LanguageModel for OllamaLanguageModel {
self.request_limiter
.run(async move {
let response = response.await?;
- let ChatMessage::Assistant {
- tool_calls,
- content,
- } = response.message
- else {
+ let ChatMessage::Assistant { tool_calls, .. } = response.message else {
bail!("message does not have an assistant role");
};
if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
for call in tool_calls {
let OllamaToolCall::Function(function) = call;
if function.name == tool_name {
- return Ok(function.arguments);
+ return Ok(futures::stream::once(async move {
+ Ok(function.arguments.to_string())
+ })
+ .boxed());
}
}
- } else if let Ok(args) = serde_json::from_str::<Value>(&content) {
- // Parse content as arguments.
- return Ok(args);
} else {
bail!("assistant message does not have any tool calls");
};
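Note: two behavioral changes here. Ollama reports tool calls in one non-streaming response, so the single `arguments` payload is adapted to the new interface with a one-item `futures::stream::once(..)`; and the old fallback that parsed the assistant message's `content` as arguments is dropped, so a response without tool calls now always fails with "assistant message does not have any tool calls".
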
@@ -1,4 +1,4 @@
-use anyhow::{anyhow, bail, Result};
+use anyhow::{anyhow, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
@@ -243,6 +243,7 @@ impl OpenAiLanguageModel {
async move { Ok(future.await?.boxed()) }.boxed()
}
}
+
impl LanguageModel for OpenAiLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -293,55 +294,32 @@ impl LanguageModel for OpenAiLanguageModel {
tool_description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let mut request = request.into_open_ai(self.model.id().into());
- let mut function = FunctionDefinition {
- name: tool_name.clone(),
- description: None,
- parameters: None,
- };
- let func = ToolDefinition::Function {
- function: function.clone(),
- };
- request.tool_choice = Some(ToolChoice::Other(func.clone()));
- // Fill in description and params separately, as they're not needed for tool_choice field.
- function.description = Some(tool_description);
- function.parameters = Some(schema);
- request.tools = vec![ToolDefinition::Function { function }];
+ request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
+ function: FunctionDefinition {
+ name: tool_name.clone(),
+ description: None,
+ parameters: None,
+ },
+ }));
+ request.tools = vec![ToolDefinition::Function {
+ function: FunctionDefinition {
+ name: tool_name.clone(),
+ description: Some(tool_description),
+ parameters: Some(schema),
+ },
+ }];
+
let response = self.stream_completion(request, cx);
self.request_limiter
.run(async move {
- let mut response = response.await?;
-
- // Call arguments are gonna be streamed in over multiple chunks.
- let mut load_state = None;
- while let Some(Ok(part)) = response.next().await {
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
- }
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
- }
- }
- }
- }
- }
- }
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
+ let response = response.await?;
+ Ok(
+ open_ai::extract_tool_args_from_events(tool_name, Box::pin(response))
+ .await?
+ .boxed(),
+ )
})
.boxed()
}
@@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use serde_json::Value;
+use serde_json::{value::RawValue, Value};
use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@@ -92,7 +92,7 @@ impl Model {
}
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
Assistant {
@@ -107,16 +107,16 @@ pub enum ChatMessage {
},
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
Function(OllamaFunctionCall),
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
pub name: String,
- pub arguments: Value,
+ pub arguments: Box<RawValue>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
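Note: `arguments` becomes a `Box<RawValue>`, which keeps the argument JSON exactly as Ollama sent it, so the provider's `function.arguments.to_string()` forwards the original text with no parse-and-reserialize round trip; the equality derives come off these types along with the parsed `Value`. A quick illustration:

    use serde_json::value::RawValue;

    // RawValue preserves the JSON text it was deserialized from.
    let raw: Box<RawValue> =
        serde_json::from_str(r#"{"path":"/root/hello.rs"}"#).unwrap();
    assert_eq!(raw.to_string(), r#"{"path":"/root/hello.rs"}"#);
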
@@ -6,7 +6,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::Value;
-use std::{convert::TryFrom, future::Future, time::Duration};
+use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
use strum::EnumIter;
pub use supported_countries::*;
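Note: the hunk below adds the OpenAI-side twin of the Anthropic helper. The `first_chunk` bookkeeping covers one wrinkle: the delta that names the function can already carry the opening slice of the arguments, so that slice is stashed during the search loop and prepended to the next matching chunk instead of being dropped.
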
@@ -384,6 +384,57 @@ pub fn embed<'a>(
}
}
+pub async fn extract_tool_args_from_events(
+ tool_name: String,
+ mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+) -> Result<impl Send + Stream<Item = Result<String>>> {
+ let mut tool_use_index = None;
+ let mut first_chunk = None;
+ while let Some(event) = events.next().await {
+ let call = event?.choices.into_iter().find_map(|choice| {
+ choice.delta.tool_calls?.into_iter().find_map(|call| {
+ if call.function.as_ref()?.name.as_deref()? == tool_name {
+ Some(call)
+ } else {
+ None
+ }
+ })
+ });
+ if let Some(call) = call {
+ tool_use_index = Some(call.index);
+ first_chunk = call.function.and_then(|func| func.arguments);
+ break;
+ }
+ }
+
+ let Some(tool_use_index) = tool_use_index else {
+ return Err(anyhow!("tool not used"));
+ };
+
+ Ok(events.filter_map(move |event| {
+ let result = match event {
+ Err(error) => Some(Err(error)),
+ Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
+ choice.delta.tool_calls?.into_iter().find_map(|call| {
+ if call.index == tool_use_index {
+ let func = call.function?;
+ let mut arguments = func.arguments?;
+ if let Some(mut first_chunk) = first_chunk.take() {
+ first_chunk.push_str(&arguments);
+ arguments = first_chunk
+ }
+ Some(Ok(arguments))
+ } else {
+ None
+ }
+ })
+ }),
+ };
+
+ async move { result }
+ }))
+}
+
pub fn extract_text_from_events(
response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {