Make LanguageModel::use_any_tool return a stream of chunks (#16262)

Max Brunsfeld created

This PR is a refactor to pave the way for allowing the user to view and
edit workflow step resolutions. I've made tool calls work more like
normal streaming completions for all providers. The `use_any_tool`
method returns a stream of strings (which contain chunks of JSON). I've
also done some minor cleanup of language model providers in general,
removing the duplication around handling streaming responses.

Release Notes:

- N/A

Change summary

crates/anthropic/src/anthropic.rs                  |  46 ++
crates/assistant/src/context.rs                    |  49 --
crates/assistant/src/inline_assistant.rs           |   6 
crates/gpui/src/elements/img.rs                    |   5 
crates/language_model/src/language_model.rs        |  11 
crates/language_model/src/provider/anthropic.rs    |  50 --
crates/language_model/src/provider/cloud.rs        | 295 ++++-----------
crates/language_model/src/provider/copilot_chat.rs |   2 
crates/language_model/src/provider/fake.rs         |  37 -
crates/language_model/src/provider/google.rs       |   2 
crates/language_model/src/provider/ollama.rs       |  17 
crates/language_model/src/provider/open_ai.rs      |  70 +--
crates/ollama/src/ollama.rs                        |  10 
crates/open_ai/src/open_ai.rs                      |  53 ++
14 files changed, 253 insertions(+), 400 deletions(-)

Detailed changes

crates/anthropic/src/anthropic.rs 🔗

@@ -5,8 +5,8 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
-use std::str::FromStr;
 use std::time::Duration;
+use std::{pin::Pin, str::FromStr};
 use strum::{EnumIter, EnumString};
 use thiserror::Error;
 
@@ -241,6 +241,50 @@ pub fn extract_text_from_events(
     })
 }
 
+pub async fn extract_tool_args_from_events(
+    tool_name: String,
+    mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
+) -> Result<impl Send + Stream<Item = Result<String>>> {
+    let mut tool_use_index = None;
+    while let Some(event) = events.next().await {
+        if let Event::ContentBlockStart {
+            index,
+            content_block,
+        } = event?
+        {
+            if let Content::ToolUse { name, .. } = content_block {
+                if name == tool_name {
+                    tool_use_index = Some(index);
+                    break;
+                }
+            }
+        }
+    }
+
+    let Some(tool_use_index) = tool_use_index else {
+        return Err(anyhow!("tool not used"));
+    };
+
+    Ok(events.filter_map(move |event| {
+        let result = match event {
+            Err(error) => Some(Err(error)),
+            Ok(Event::ContentBlockDelta { index, delta }) => match delta {
+                ContentDelta::TextDelta { .. } => None,
+                ContentDelta::InputJsonDelta { partial_json } => {
+                    if index == tool_use_index {
+                        Some(Ok(partial_json))
+                    } else {
+                        None
+                    }
+                }
+            },
+            _ => None,
+        };
+
+        async move { result }
+    }))
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub struct Message {
     pub role: Role,

crates/assistant/src/context.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{
-    prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion,
-    InlineAssistId, InlineAssistant, MessageId, MessageStatus,
+    prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InlineAssistId,
+    InlineAssistant, MessageId, MessageStatus,
 };
 use anyhow::{anyhow, Context as _, Result};
 use assistant_slash_command::{
@@ -3342,7 +3342,7 @@ mod tests {
 
         model
             .as_fake()
-            .respond_to_last_tool_use(Ok(serde_json::to_value(tool::WorkflowStepResolution {
+            .respond_to_last_tool_use(tool::WorkflowStepResolution {
                 step_title: "Title".into(),
                 suggestions: vec![tool::WorkflowSuggestion {
                     path: "/root/hello.rs".into(),
@@ -3352,8 +3352,7 @@ mod tests {
                         description: "Extract a greeting function".into(),
                     },
                 }],
-            })
-            .unwrap()));
+            });
 
         // Wait for tool use to be processed.
         cx.run_until_parked();
@@ -4084,44 +4083,4 @@ mod tool {
             symbol: String,
         },
     }
-
-    impl WorkflowSuggestionKind {
-        pub fn symbol(&self) -> Option<&str> {
-            match self {
-                Self::Update { symbol, .. } => Some(symbol),
-                Self::InsertSiblingBefore { symbol, .. } => Some(symbol),
-                Self::InsertSiblingAfter { symbol, .. } => Some(symbol),
-                Self::PrependChild { symbol, .. } => symbol.as_deref(),
-                Self::AppendChild { symbol, .. } => symbol.as_deref(),
-                Self::Delete { symbol } => Some(symbol),
-                Self::Create { .. } => None,
-            }
-        }
-
-        pub fn description(&self) -> Option<&str> {
-            match self {
-                Self::Update { description, .. } => Some(description),
-                Self::Create { description } => Some(description),
-                Self::InsertSiblingBefore { description, .. } => Some(description),
-                Self::InsertSiblingAfter { description, .. } => Some(description),
-                Self::PrependChild { description, .. } => Some(description),
-                Self::AppendChild { description, .. } => Some(description),
-                Self::Delete { .. } => None,
-            }
-        }
-
-        pub fn initial_insertion(&self) -> Option<InitialInsertion> {
-            match self {
-                WorkflowSuggestionKind::InsertSiblingBefore { .. } => {
-                    Some(InitialInsertion::NewlineAfter)
-                }
-                WorkflowSuggestionKind::InsertSiblingAfter { .. } => {
-                    Some(InitialInsertion::NewlineBefore)
-                }
-                WorkflowSuggestionKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
-                WorkflowSuggestionKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
-                _ => None,
-            }
-        }
-    }
 }

crates/assistant/src/inline_assistant.rs 🔗

@@ -1280,12 +1280,6 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
     })
 }
 
-#[derive(Copy, Clone, Debug, Eq, PartialEq)]
-pub enum InitialInsertion {
-    NewlineBefore,
-    NewlineAfter,
-}
-
 #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 pub struct InlineAssistId(usize);
 

crates/gpui/src/elements/img.rs 🔗

@@ -351,10 +351,13 @@ impl Asset for ImageAsset {
                     let mut body = Vec::new();
                     response.body_mut().read_to_end(&mut body).await?;
                     if !response.status().is_success() {
+                        let mut body = String::from_utf8_lossy(&body).into_owned();
+                        let first_line = body.lines().next().unwrap_or("").trim_end();
+                        body.truncate(first_line.len());
                         return Err(ImageCacheError::BadStatus {
                             uri,
                             status: response.status(),
-                            body: String::from_utf8_lossy(&body).into_owned(),
+                            body,
                         });
                     }
                     body

crates/language_model/src/language_model.rs 🔗

@@ -8,7 +8,7 @@ pub mod settings;
 
 use anyhow::Result;
 use client::{Client, UserStore};
-use futures::{future::BoxFuture, stream::BoxStream};
+use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
 use gpui::{
     AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
 };
@@ -76,7 +76,7 @@ pub trait LanguageModel: Send + Sync {
         description: String,
         schema: serde_json::Value,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>>;
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 
     #[cfg(any(test, feature = "test-support"))]
     fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
@@ -92,10 +92,11 @@ impl dyn LanguageModel {
     ) -> impl 'static + Future<Output = Result<T>> {
         let schema = schemars::schema_for!(T);
         let schema_json = serde_json::to_value(&schema).unwrap();
-        let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
+        let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
         async move {
-            let response = request.await?;
-            Ok(serde_json::from_value(response)?)
+            let stream = stream.await?;
+            let response = stream.try_collect::<String>().await?;
+            Ok(serde_json::from_str(&response)?)
         }
     }
 }

crates/language_model/src/provider/anthropic.rs 🔗

@@ -7,7 +7,7 @@ use anthropic::AnthropicError;
 use anyhow::{anyhow, Context as _, Result};
 use collections::BTreeMap;
 use editor::{Editor, EditorElement, EditorStyle};
-use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
 use gpui::{
     AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
     View, WhiteSpace,
@@ -264,29 +264,6 @@ pub fn count_anthropic_tokens(
 }
 
 impl AnthropicModel {
-    fn request_completion(
-        &self,
-        request: anthropic::Request,
-        cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<anthropic::Response>> {
-        let http_client = self.http_client.clone();
-
-        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
-            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
-            (state.api_key.clone(), settings.api_url.clone())
-        }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
-        };
-
-        async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
-            anthropic::complete(http_client.as_ref(), &api_url, &api_key, request)
-                .await
-                .context("failed to retrieve completion")
-        }
-        .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: anthropic::Request,
@@ -381,7 +358,7 @@ impl LanguageModel for AnthropicModel {
         tool_description: String,
         input_schema: serde_json::Value,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let mut request = request.into_anthropic(self.model.tool_model_id().into());
         request.tool_choice = Some(anthropic::ToolChoice::Tool {
             name: tool_name.clone(),
@@ -392,25 +369,16 @@ impl LanguageModel for AnthropicModel {
             input_schema,
         }];
 
-        let response = self.request_completion(request, cx);
+        let response = self.stream_completion(request, cx);
         self.request_limiter
             .run(async move {
                 let response = response.await?;
-                response
-                    .content
-                    .into_iter()
-                    .find_map(|content| {
-                        if let anthropic::Content::ToolUse { name, input, .. } = content {
-                            if name == tool_name {
-                                Some(input)
-                            } else {
-                                None
-                            }
-                        } else {
-                            None
-                        }
-                    })
-                    .context("tool not used")
+                Ok(anthropic::extract_tool_args_from_events(
+                    tool_name,
+                    Box::pin(response.map_err(|e| anyhow!(e))),
+                )
+                .await?
+                .boxed())
             })
             .boxed()
     }

crates/language_model/src/provider/cloud.rs 🔗

@@ -5,18 +5,21 @@ use crate::{
     LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
 use anthropic::AnthropicError;
-use anyhow::{anyhow, bail, Context as _, Result};
+use anyhow::{anyhow, Result};
 use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
 use feature_flags::{FeatureFlagAppExt, ZedPro};
-use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
+use futures::{
+    future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
+    TryStreamExt as _,
+};
 use gpui::{
     AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
     Subscription, Task,
 };
 use http_client::{AsyncBody, HttpClient, Method, Response};
 use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::value::RawValue;
 use settings::{Settings, SettingsStore};
 use smol::{
@@ -451,21 +454,9 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    let body = BufReader::new(response.into_body());
-                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                        let mut buffer = String::new();
-                        match body.read_line(&mut buffer).await {
-                            Ok(0) => Ok(None),
-                            Ok(_) => {
-                                let event: anthropic::Event = serde_json::from_str(&buffer)
-                                    .context("failed to parse Anthropic event")?;
-                                Ok(Some((event, body)))
-                            }
-                            Err(err) => Err(AnthropicError::Other(err.into())),
-                        }
-                    });
-
-                    Ok(anthropic::extract_text_from_events(stream))
+                    Ok(anthropic::extract_text_from_events(
+                        response_lines(response).map_err(AnthropicError::Other),
+                    ))
                 });
                 async move {
                     Ok(future
@@ -492,21 +483,7 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    let body = BufReader::new(response.into_body());
-                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                        let mut buffer = String::new();
-                        match body.read_line(&mut buffer).await {
-                            Ok(0) => Ok(None),
-                            Ok(_) => {
-                                let event: open_ai::ResponseStreamEvent =
-                                    serde_json::from_str(&buffer)?;
-                                Ok(Some((event, body)))
-                            }
-                            Err(e) => Err(e.into()),
-                        }
-                    });
-
-                    Ok(open_ai::extract_text_from_events(stream))
+                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                 });
                 async move { Ok(future.await?.boxed()) }.boxed()
             }
@@ -527,21 +504,9 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    let body = BufReader::new(response.into_body());
-                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                        let mut buffer = String::new();
-                        match body.read_line(&mut buffer).await {
-                            Ok(0) => Ok(None),
-                            Ok(_) => {
-                                let event: google_ai::GenerateContentResponse =
-                                    serde_json::from_str(&buffer)?;
-                                Ok(Some((event, body)))
-                            }
-                            Err(e) => Err(e.into()),
-                        }
-                    });
-
-                    Ok(google_ai::extract_text_from_events(stream))
+                    Ok(google_ai::extract_text_from_events(response_lines(
+                        response,
+                    )))
                 });
                 async move { Ok(future.await?.boxed()) }.boxed()
             }
@@ -563,21 +528,7 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    let body = BufReader::new(response.into_body());
-                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                        let mut buffer = String::new();
-                        match body.read_line(&mut buffer).await {
-                            Ok(0) => Ok(None),
-                            Ok(_) => {
-                                let event: open_ai::ResponseStreamEvent =
-                                    serde_json::from_str(&buffer)?;
-                                Ok(Some((event, body)))
-                            }
-                            Err(e) => Err(e.into()),
-                        }
-                    });
-
-                    Ok(open_ai::extract_text_from_events(stream))
+                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                 });
                 async move { Ok(future.await?.boxed()) }.boxed()
             }
@@ -591,10 +542,12 @@ impl LanguageModel for CloudLanguageModel {
         tool_description: String,
         input_schema: serde_json::Value,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let client = self.client.clone();
+        let llm_api_token = self.llm_api_token.clone();
+
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let client = self.client.clone();
                 let mut request = request.into_anthropic(model.tool_model_id().into());
                 request.tool_choice = Some(anthropic::ToolChoice::Tool {
                     name: tool_name.clone(),
@@ -605,7 +558,6 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];
 
-                let llm_api_token = self.llm_api_token.clone();
                 self.request_limiter
                     .run(async move {
                         let response = Self::perform_llm_completion(
@@ -621,70 +573,34 @@ impl LanguageModel for CloudLanguageModel {
                         )
                         .await?;
 
-                        let mut tool_use_index = None;
-                        let mut tool_input = String::new();
-                        let mut body = BufReader::new(response.into_body());
-                        let mut line = String::new();
-                        while body.read_line(&mut line).await? > 0 {
-                            let event: anthropic::Event = serde_json::from_str(&line)?;
-                            line.clear();
-
-                            match event {
-                                anthropic::Event::ContentBlockStart {
-                                    content_block,
-                                    index,
-                                } => {
-                                    if let anthropic::Content::ToolUse { name, .. } = content_block
-                                    {
-                                        if name == tool_name {
-                                            tool_use_index = Some(index);
-                                        }
-                                    }
-                                }
-                                anthropic::Event::ContentBlockDelta { index, delta } => match delta
-                                {
-                                    anthropic::ContentDelta::TextDelta { .. } => {}
-                                    anthropic::ContentDelta::InputJsonDelta { partial_json } => {
-                                        if Some(index) == tool_use_index {
-                                            tool_input.push_str(&partial_json);
-                                        }
-                                    }
-                                },
-                                anthropic::Event::ContentBlockStop { index } => {
-                                    if Some(index) == tool_use_index {
-                                        return Ok(serde_json::from_str(&tool_input)?);
-                                    }
-                                }
-                                _ => {}
-                            }
-                        }
-
-                        if tool_use_index.is_some() {
-                            Err(anyhow!("tool content incomplete"))
-                        } else {
-                            Err(anyhow!("tool not used"))
-                        }
+                        Ok(anthropic::extract_tool_args_from_events(
+                            tool_name,
+                            Box::pin(response_lines(response)),
+                        )
+                        .await?
+                        .boxed())
                     })
                     .boxed()
             }
             CloudModel::OpenAi(model) => {
                 let mut request = request.into_open_ai(model.id().into());
-                let client = self.client.clone();
-                let mut function = open_ai::FunctionDefinition {
-                    name: tool_name.clone(),
-                    description: None,
-                    parameters: None,
-                };
-                let func = open_ai::ToolDefinition::Function {
-                    function: function.clone(),
-                };
-                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
-                // Fill in description and params separately, as they're not needed for tool_choice field.
-                function.description = Some(tool_description);
-                function.parameters = Some(input_schema);
-                request.tools = vec![open_ai::ToolDefinition::Function { function }];
+                request.tool_choice = Some(open_ai::ToolChoice::Other(
+                    open_ai::ToolDefinition::Function {
+                        function: open_ai::FunctionDefinition {
+                            name: tool_name.clone(),
+                            description: None,
+                            parameters: None,
+                        },
+                    },
+                ));
+                request.tools = vec![open_ai::ToolDefinition::Function {
+                    function: open_ai::FunctionDefinition {
+                        name: tool_name.clone(),
+                        description: Some(tool_description),
+                        parameters: Some(input_schema),
+                    },
+                }];
 
-                let llm_api_token = self.llm_api_token.clone();
                 self.request_limiter
                     .run(async move {
                         let response = Self::perform_llm_completion(
@@ -700,41 +616,12 @@ impl LanguageModel for CloudLanguageModel {
                         )
                         .await?;
 
-                        let mut body = BufReader::new(response.into_body());
-                        let mut line = String::new();
-                        let mut load_state = None;
-
-                        while body.read_line(&mut line).await? > 0 {
-                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
-                            line.clear();
-
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
-                                        }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
+                        Ok(open_ai::extract_tool_args_from_events(
+                            tool_name,
+                            Box::pin(response_lines(response)),
+                        )
+                        .await?
+                        .boxed())
                     })
                     .boxed()
             }
@@ -744,22 +631,23 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::Zed(model) => {
                 // All Zed models are OpenAI-based at the time of writing.
                 let mut request = request.into_open_ai(model.id().into());
-                let client = self.client.clone();
-                let mut function = open_ai::FunctionDefinition {
-                    name: tool_name.clone(),
-                    description: None,
-                    parameters: None,
-                };
-                let func = open_ai::ToolDefinition::Function {
-                    function: function.clone(),
-                };
-                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
-                // Fill in description and params separately, as they're not needed for tool_choice field.
-                function.description = Some(tool_description);
-                function.parameters = Some(input_schema);
-                request.tools = vec![open_ai::ToolDefinition::Function { function }];
+                request.tool_choice = Some(open_ai::ToolChoice::Other(
+                    open_ai::ToolDefinition::Function {
+                        function: open_ai::FunctionDefinition {
+                            name: tool_name.clone(),
+                            description: None,
+                            parameters: None,
+                        },
+                    },
+                ));
+                request.tools = vec![open_ai::ToolDefinition::Function {
+                    function: open_ai::FunctionDefinition {
+                        name: tool_name.clone(),
+                        description: Some(tool_description),
+                        parameters: Some(input_schema),
+                    },
+                }];
 
-                let llm_api_token = self.llm_api_token.clone();
                 self.request_limiter
                     .run(async move {
                         let response = Self::perform_llm_completion(
@@ -775,40 +663,12 @@ impl LanguageModel for CloudLanguageModel {
                         )
                         .await?;
 
-                        let mut body = BufReader::new(response.into_body());
-                        let mut line = String::new();
-                        let mut load_state = None;
-
-                        while body.read_line(&mut line).await? > 0 {
-                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
-                            line.clear();
-
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
-                                        }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
+                        Ok(open_ai::extract_tool_args_from_events(
+                            tool_name,
+                            Box::pin(response_lines(response)),
+                        )
+                        .await?
+                        .boxed())
                     })
                     .boxed()
             }
@@ -816,6 +676,25 @@ impl LanguageModel for CloudLanguageModel {
     }
 }
 
+fn response_lines<T: DeserializeOwned>(
+    response: Response<AsyncBody>,
+) -> impl Stream<Item = Result<T>> {
+    futures::stream::try_unfold(
+        (String::new(), BufReader::new(response.into_body())),
+        move |(mut line, mut body)| async {
+            match body.read_line(&mut line).await {
+                Ok(0) => Ok(None),
+                Ok(_) => {
+                    let event: T = serde_json::from_str(&line)?;
+                    line.clear();
+                    Ok(Some((event, (line, body))))
+                }
+                Err(e) => Err(e.into()),
+            }
+        },
+    )
+}
+
 impl LlmApiToken {
     async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
         let lock = self.0.upgradable_read().await;

crates/language_model/src/provider/copilot_chat.rs 🔗

@@ -252,7 +252,7 @@ impl LanguageModel for CopilotChatLanguageModel {
         _description: String,
         _schema: serde_json::Value,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         future::ready(Err(anyhow!("not implemented"))).boxed()
     }
 }

crates/language_model/src/provider/fake.rs 🔗

@@ -3,16 +3,11 @@ use crate::{
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelRequest,
 };
-use anyhow::Context as _;
-use futures::{
-    channel::{mpsc, oneshot},
-    future::BoxFuture,
-    stream::BoxStream,
-    FutureExt, StreamExt,
-};
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 use http_client::Result;
 use parking_lot::Mutex;
+use serde::Serialize;
 use std::sync::Arc;
 use ui::WindowContext;
 
@@ -90,7 +85,7 @@ pub struct ToolUseRequest {
 #[derive(Default)]
 pub struct FakeLanguageModel {
     current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
-    current_tool_use_txs: Mutex<Vec<(ToolUseRequest, oneshot::Sender<Result<serde_json::Value>>)>>,
+    current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
 }
 
 impl FakeLanguageModel {
@@ -130,25 +125,11 @@ impl FakeLanguageModel {
         self.end_completion_stream(self.pending_completions().last().unwrap());
     }
 
-    pub fn respond_to_tool_use(
-        &self,
-        tool_call: &ToolUseRequest,
-        response: Result<serde_json::Value>,
-    ) {
-        let mut current_tool_call_txs = self.current_tool_use_txs.lock();
-        if let Some(index) = current_tool_call_txs
-            .iter()
-            .position(|(call, _)| call == tool_call)
-        {
-            let (_, tx) = current_tool_call_txs.remove(index);
-            tx.send(response).unwrap();
-        }
-    }
-
-    pub fn respond_to_last_tool_use(&self, response: Result<serde_json::Value>) {
+    pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
+        let response = serde_json::to_string(&response).unwrap();
         let mut current_tool_call_txs = self.current_tool_use_txs.lock();
         let (_, tx) = current_tool_call_txs.pop().unwrap();
-        tx.send(response).unwrap();
+        tx.unbounded_send(response).unwrap();
     }
 }
 
@@ -202,8 +183,8 @@ impl LanguageModel for FakeLanguageModel {
         description: String,
         schema: serde_json::Value,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
-        let (tx, rx) = oneshot::channel();
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let (tx, rx) = mpsc::unbounded();
         let tool_call = ToolUseRequest {
             request,
             name,
@@ -211,7 +192,7 @@ impl LanguageModel for FakeLanguageModel {
             schema,
         };
         self.current_tool_use_txs.lock().push((tool_call, tx));
-        async move { rx.await.context("FakeLanguageModel was dropped")? }.boxed()
+        async move { Ok(rx.map(Ok).boxed()) }.boxed()
     }
 
     fn as_fake(&self) -> &Self {

crates/language_model/src/provider/google.rs 🔗

@@ -302,7 +302,7 @@ impl LanguageModel for GoogleLanguageModel {
         _description: String,
         _schema: serde_json::Value,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
         future::ready(Err(anyhow!("not implemented"))).boxed()
     }
 }

crates/language_model/src/provider/ollama.rs 🔗

@@ -6,7 +6,6 @@ use ollama::{
     get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
     ChatResponseDelta, OllamaToolCall,
 };
-use serde_json::Value;
 use settings::{Settings, SettingsStore};
 use std::{sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, Indicator};
@@ -311,7 +310,7 @@ impl LanguageModel for OllamaLanguageModel {
         tool_description: String,
         schema: serde_json::Value,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         use ollama::{OllamaFunctionTool, OllamaTool};
         let function = OllamaFunctionTool {
             name: tool_name.clone(),
@@ -324,23 +323,19 @@ impl LanguageModel for OllamaLanguageModel {
         self.request_limiter
             .run(async move {
                 let response = response.await?;
-                let ChatMessage::Assistant {
-                    tool_calls,
-                    content,
-                } = response.message
-                else {
+                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                     bail!("message does not have an assistant role");
                 };
                 if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                     for call in tool_calls {
                         let OllamaToolCall::Function(function) = call;
                         if function.name == tool_name {
-                            return Ok(function.arguments);
+                            return Ok(futures::stream::once(async move {
+                                Ok(function.arguments.to_string())
+                            })
+                            .boxed());
                         }
                     }
-                } else if let Ok(args) = serde_json::from_str::<Value>(&content) {
-                    // Parse content as arguments.
-                    return Ok(args);
                 } else {
                     bail!("assistant message does not have any tool calls");
                 };

crates/language_model/src/provider/open_ai.rs 🔗

@@ -1,4 +1,4 @@
-use anyhow::{anyhow, bail, Result};
+use anyhow::{anyhow, Result};
 use collections::BTreeMap;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, FutureExt, StreamExt};
@@ -243,6 +243,7 @@ impl OpenAiLanguageModel {
         async move { Ok(future.await?.boxed()) }.boxed()
     }
 }
+
 impl LanguageModel for OpenAiLanguageModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
@@ -293,55 +294,32 @@ impl LanguageModel for OpenAiLanguageModel {
         tool_description: String,
         schema: serde_json::Value,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
         let mut request = request.into_open_ai(self.model.id().into());
-        let mut function = FunctionDefinition {
-            name: tool_name.clone(),
-            description: None,
-            parameters: None,
-        };
-        let func = ToolDefinition::Function {
-            function: function.clone(),
-        };
-        request.tool_choice = Some(ToolChoice::Other(func.clone()));
-        // Fill in description and params separately, as they're not needed for tool_choice field.
-        function.description = Some(tool_description);
-        function.parameters = Some(schema);
-        request.tools = vec![ToolDefinition::Function { function }];
+        request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
+            function: FunctionDefinition {
+                name: tool_name.clone(),
+                description: None,
+                parameters: None,
+            },
+        }));
+        request.tools = vec![ToolDefinition::Function {
+            function: FunctionDefinition {
+                name: tool_name.clone(),
+                description: Some(tool_description),
+                parameters: Some(schema),
+            },
+        }];
+
         let response = self.stream_completion(request, cx);
         self.request_limiter
             .run(async move {
-                let mut response = response.await?;
-
-                // Call arguments are gonna be streamed in over multiple chunks.
-                let mut load_state = None;
-                while let Some(Ok(part)) = response.next().await {
-                    for choice in part.choices {
-                        let Some(tool_calls) = choice.delta.tool_calls else {
-                            continue;
-                        };
-
-                        for call in tool_calls {
-                            if let Some(func) = call.function {
-                                if func.name.as_deref() == Some(tool_name.as_str()) {
-                                    load_state = Some((String::default(), call.index));
-                                }
-                                if let Some((arguments, (output, index))) =
-                                    func.arguments.zip(load_state.as_mut())
-                                {
-                                    if call.index == *index {
-                                        output.push_str(&arguments);
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                if let Some((arguments, _)) = load_state {
-                    return Ok(serde_json::from_str(&arguments)?);
-                } else {
-                    bail!("tool not used");
-                }
+                let response = response.await?;
+                Ok(
+                    open_ai::extract_tool_args_from_events(tool_name, Box::pin(response))
+                        .await?
+                        .boxed(),
+                )
             })
             .boxed()
     }

crates/ollama/src/ollama.rs 🔗

@@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use serde_json::Value;
+use serde_json::{value::RawValue, Value};
 use std::{convert::TryFrom, sync::Arc, time::Duration};
 
 pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@@ -92,7 +92,7 @@ impl Model {
     }
 }
 
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
 #[serde(tag = "role", rename_all = "lowercase")]
 pub enum ChatMessage {
     Assistant {
@@ -107,16 +107,16 @@ pub enum ChatMessage {
     },
 }
 
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
 #[serde(rename_all = "lowercase")]
 pub enum OllamaToolCall {
     Function(OllamaFunctionCall),
 }
 
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
 pub struct OllamaFunctionCall {
     pub name: String,
-    pub arguments: Value,
+    pub arguments: Box<RawValue>,
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]

crates/open_ai/src/open_ai.rs 🔗

@@ -6,7 +6,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
-use std::{convert::TryFrom, future::Future, time::Duration};
+use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
 use strum::EnumIter;
 
 pub use supported_countries::*;
@@ -384,6 +384,57 @@ pub fn embed<'a>(
     }
 }
 
+pub async fn extract_tool_args_from_events(
+    tool_name: String,
+    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+) -> Result<impl Send + Stream<Item = Result<String>>> {
+    let mut tool_use_index = None;
+    let mut first_chunk = None;
+    while let Some(event) = events.next().await {
+        let call = event?.choices.into_iter().find_map(|choice| {
+            choice.delta.tool_calls?.into_iter().find_map(|call| {
+                if call.function.as_ref()?.name.as_deref()? == tool_name {
+                    Some(call)
+                } else {
+                    None
+                }
+            })
+        });
+        if let Some(call) = call {
+            tool_use_index = Some(call.index);
+            first_chunk = call.function.and_then(|func| func.arguments);
+            break;
+        }
+    }
+
+    let Some(tool_use_index) = tool_use_index else {
+        return Err(anyhow!("tool not used"));
+    };
+
+    Ok(events.filter_map(move |event| {
+        let result = match event {
+            Err(error) => Some(Err(error)),
+            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
+                choice.delta.tool_calls?.into_iter().find_map(|call| {
+                    if call.index == tool_use_index {
+                        let func = call.function?;
+                        let mut arguments = func.arguments?;
+                        if let Some(mut first_chunk) = first_chunk.take() {
+                            first_chunk.push_str(&arguments);
+                            arguments = first_chunk
+                        }
+                        Some(Ok(arguments))
+                    } else {
+                        None
+                    }
+                })
+            }),
+        };
+
+        async move { result }
+    }))
+}
+
 pub fn extract_text_from_events(
     response: impl Stream<Item = Result<ResponseStreamEvent>>,
 ) -> impl Stream<Item = Result<String>> {