assistant: Use tools in other providers (#15803)

Piotr Osiewicz created

- [x] OpenAI
- [ ] ~Google~ Moved into a separate branch at:
https://github.com/zed-industries/zed/tree/tool-calls-in-google-ai I've
ran into issues with having the API digest our schema without tripping
over itself - the function call parameters are malformed and whatnot. We
can resume from that branch if needed.
- [x] Ollama
- [x] Cloud
- [ ] ~Copilot Chat (?)~

Release Notes:

- Added tool calling capabilities to OpenAI and Ollama models.

Change summary

crates/language_model/src/provider/cloud.rs   | 139 ++++++++++++++++++++
crates/language_model/src/provider/ollama.rs  |  73 +++++++++-
crates/language_model/src/provider/open_ai.rs | 133 ++++++++++++++-----
crates/ollama/src/ollama.rs                   |  86 ++++++++++++
crates/open_ai/src/open_ai.rs                 |  25 ++-
5 files changed, 392 insertions(+), 64 deletions(-)

Detailed changes

crates/language_model/src/provider/cloud.rs 🔗

@@ -4,7 +4,7 @@ use crate::{
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
-use anyhow::{anyhow, Context as _, Result};
+use anyhow::{anyhow, bail, Context as _, Result};
 use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
 use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
@@ -634,14 +634,143 @@ impl LanguageModel for CloudLanguageModel {
                     })
                     .boxed()
             }
-            CloudModel::OpenAi(_) => {
-                future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
+            CloudModel::OpenAi(model) => {
+                let mut request = request.into_open_ai(model.id().into());
+                let client = self.client.clone();
+                let mut function = open_ai::FunctionDefinition {
+                    name: tool_name.clone(),
+                    description: None,
+                    parameters: None,
+                };
+                let func = open_ai::ToolDefinition::Function {
+                    function: function.clone(),
+                };
+                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
+                // Fill in description and params separately, as they're not needed for tool_choice field.
+                function.description = Some(tool_description);
+                function.parameters = Some(input_schema);
+                request.tools = vec![open_ai::ToolDefinition::Function { function }];
+                self.request_limiter
+                    .run(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let response = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::OpenAi as i32,
+                                request,
+                            })
+                            .await?;
+                        // Call arguments are gonna be streamed in over multiple chunks.
+                        let mut load_state = None;
+                        let mut response = response.map(
+                            |item: Result<
+                                proto::StreamCompleteWithLanguageModelResponse,
+                                anyhow::Error,
+                            >| {
+                                Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+                                    serde_json::from_str(&item?.event)?,
+                                )
+                            },
+                        );
+                        while let Some(Ok(part)) = response.next().await {
+                            for choice in part.choices {
+                                let Some(tool_calls) = choice.delta.tool_calls else {
+                                    continue;
+                                };
+
+                                for call in tool_calls {
+                                    if let Some(func) = call.function {
+                                        if func.name.as_deref() == Some(tool_name.as_str()) {
+                                            load_state = Some((String::default(), call.index));
+                                        }
+                                        if let Some((arguments, (output, index))) =
+                                            func.arguments.zip(load_state.as_mut())
+                                        {
+                                            if call.index == *index {
+                                                output.push_str(&arguments);
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                        if let Some((arguments, _)) = load_state {
+                            return Ok(serde_json::from_str(&arguments)?);
+                        } else {
+                            bail!("tool not used");
+                        }
+                    })
+                    .boxed()
             }
             CloudModel::Google(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
             }
-            CloudModel::Zed(_) => {
-                future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
+            CloudModel::Zed(model) => {
+                // All Zed models are OpenAI-based at the time of writing.
+                let mut request = request.into_open_ai(model.id().into());
+                let client = self.client.clone();
+                let mut function = open_ai::FunctionDefinition {
+                    name: tool_name.clone(),
+                    description: None,
+                    parameters: None,
+                };
+                let func = open_ai::ToolDefinition::Function {
+                    function: function.clone(),
+                };
+                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
+                // Fill in description and params separately, as they're not needed for tool_choice field.
+                function.description = Some(tool_description);
+                function.parameters = Some(input_schema);
+                request.tools = vec![open_ai::ToolDefinition::Function { function }];
+                self.request_limiter
+                    .run(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let response = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::OpenAi as i32,
+                                request,
+                            })
+                            .await?;
+                        // Call arguments are gonna be streamed in over multiple chunks.
+                        let mut load_state = None;
+                        let mut response = response.map(
+                            |item: Result<
+                                proto::StreamCompleteWithLanguageModelResponse,
+                                anyhow::Error,
+                            >| {
+                                Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+                                    serde_json::from_str(&item?.event)?,
+                                )
+                            },
+                        );
+                        while let Some(Ok(part)) = response.next().await {
+                            for choice in part.choices {
+                                let Some(tool_calls) = choice.delta.tool_calls else {
+                                    continue;
+                                };
+
+                                for call in tool_calls {
+                                    if let Some(func) = call.function {
+                                        if func.name.as_deref() == Some(tool_name.as_str()) {
+                                            load_state = Some((String::default(), call.index));
+                                        }
+                                        if let Some((arguments, (output, index))) =
+                                            func.arguments.zip(load_state.as_mut())
+                                        {
+                                            if call.index == *index {
+                                                output.push_str(&arguments);
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                        if let Some((arguments, _)) = load_state {
+                            return Ok(serde_json::from_str(&arguments)?);
+                        } else {
+                            bail!("tool not used");
+                        }
+                    })
+                    .boxed()
             }
         }
     }

crates/language_model/src/provider/ollama.rs 🔗

@@ -1,12 +1,14 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, bail, Result};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
 use http_client::HttpClient;
 use ollama::{
     get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
+    ChatResponseDelta, OllamaToolCall,
 };
+use serde_json::Value;
 use settings::{Settings, SettingsStore};
-use std::{future, sync::Arc, time::Duration};
+use std::{sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, Indicator};
 use util::ResultExt;
 
@@ -184,6 +186,7 @@ impl OllamaLanguageModel {
                     },
                     Role::Assistant => ChatMessage::Assistant {
                         content: msg.content,
+                        tool_calls: None,
                     },
                     Role::System => ChatMessage::System {
                         content: msg.content,
@@ -198,8 +201,25 @@ impl OllamaLanguageModel {
                 temperature: Some(request.temperature),
                 ..Default::default()
             }),
+            tools: vec![],
         }
     }
+    fn request_completion(
+        &self,
+        request: ChatRequest,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
+        let http_client = self.http_client.clone();
+
+        let Ok(api_url) = cx.update(|cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
+            settings.api_url.clone()
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
+    }
 }
 
 impl LanguageModel for OllamaLanguageModel {
@@ -269,7 +289,7 @@ impl LanguageModel for OllamaLanguageModel {
                         Ok(delta) => {
                             let content = match delta.message {
                                 ChatMessage::User { content } => content,
-                                ChatMessage::Assistant { content } => content,
+                                ChatMessage::Assistant { content, .. } => content,
                                 ChatMessage::System { content } => content,
                             };
                             Some(Ok(content))
@@ -286,13 +306,48 @@ impl LanguageModel for OllamaLanguageModel {
 
     fn use_any_tool(
         &self,
-        _request: LanguageModelRequest,
-        _name: String,
-        _description: String,
-        _schema: serde_json::Value,
-        _cx: &AsyncAppContext,
+        request: LanguageModelRequest,
+        tool_name: String,
+        tool_description: String,
+        schema: serde_json::Value,
+        cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<serde_json::Value>> {
-        future::ready(Err(anyhow!("not implemented"))).boxed()
+        use ollama::{OllamaFunctionTool, OllamaTool};
+        let function = OllamaFunctionTool {
+            name: tool_name.clone(),
+            description: Some(tool_description),
+            parameters: Some(schema),
+        };
+        let tools = vec![OllamaTool::Function { function }];
+        let request = self.to_ollama_request(request).with_tools(tools);
+        let response = self.request_completion(request, cx);
+        self.request_limiter
+            .run(async move {
+                let response = response.await?;
+                let ChatMessage::Assistant {
+                    tool_calls,
+                    content,
+                } = response.message
+                else {
+                    bail!("message does not have an assistant role");
+                };
+                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
+                    for call in tool_calls {
+                        let OllamaToolCall::Function(function) = call;
+                        if function.name == tool_name {
+                            return Ok(function.arguments);
+                        }
+                    }
+                } else if let Ok(args) = serde_json::from_str::<Value>(&content) {
+                    // Parse content as arguments.
+                    return Ok(args);
+                } else {
+                    bail!("assistant message does not have any tool calls");
+                };
+
+                bail!("tool not used")
+            })
+            .boxed()
     }
 }
 

crates/language_model/src/provider/open_ai.rs 🔗

@@ -1,4 +1,4 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, bail, Result};
 use collections::BTreeMap;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, FutureExt, StreamExt};
@@ -7,11 +7,13 @@ use gpui::{
     View, WhiteSpace,
 };
 use http_client::HttpClient;
-use open_ai::stream_completion;
+use open_ai::{
+    stream_completion, FunctionDefinition, ResponseStreamEvent, ToolChoice, ToolDefinition,
+};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
-use std::{future, sync::Arc, time::Duration};
+use std::{sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{prelude::*, Indicator};
@@ -206,6 +208,41 @@ pub struct OpenAiLanguageModel {
     request_limiter: RateLimiter,
 }
 
+impl OpenAiLanguageModel {
+    fn stream_completion(
+        &self,
+        request: open_ai::Request,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
+    {
+        let http_client = self.http_client.clone();
+        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).openai;
+            (
+                state.api_key.clone(),
+                settings.api_url.clone(),
+                settings.low_speed_timeout,
+            )
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        let future = self.request_limiter.stream(async move {
+            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let request = stream_completion(
+                http_client.as_ref(),
+                &api_url,
+                &api_key,
+                request,
+                low_speed_timeout,
+            );
+            let response = request.await?;
+            Ok(response)
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
+    }
+}
 impl LanguageModel for OpenAiLanguageModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
@@ -245,44 +282,68 @@ impl LanguageModel for OpenAiLanguageModel {
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
         let request = request.into_open_ai(self.model.id().into());
-
-        let http_client = self.http_client.clone();
-        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
-            let settings = &AllLanguageModelSettings::get_global(cx).openai;
-            (
-                state.api_key.clone(),
-                settings.api_url.clone(),
-                settings.low_speed_timeout,
-            )
-        }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
-        };
-
-        let future = self.request_limiter.stream(async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
-            let request = stream_completion(
-                http_client.as_ref(),
-                &api_url,
-                &api_key,
-                request,
-                low_speed_timeout,
-            );
-            let response = request.await?;
-            Ok(open_ai::extract_text_from_events(response).boxed())
-        });
-
-        async move { Ok(future.await?.boxed()) }.boxed()
+        let completions = self.stream_completion(request, cx);
+        async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
     }
 
     fn use_any_tool(
         &self,
-        _request: LanguageModelRequest,
-        _name: String,
-        _description: String,
-        _schema: serde_json::Value,
-        _cx: &AsyncAppContext,
+        request: LanguageModelRequest,
+        tool_name: String,
+        tool_description: String,
+        schema: serde_json::Value,
+        cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<serde_json::Value>> {
-        future::ready(Err(anyhow!("not implemented"))).boxed()
+        let mut request = request.into_open_ai(self.model.id().into());
+        let mut function = FunctionDefinition {
+            name: tool_name.clone(),
+            description: None,
+            parameters: None,
+        };
+        let func = ToolDefinition::Function {
+            function: function.clone(),
+        };
+        request.tool_choice = Some(ToolChoice::Other(func.clone()));
+        // Fill in description and params separately, as they're not needed for tool_choice field.
+        function.description = Some(tool_description);
+        function.parameters = Some(schema);
+        request.tools = vec![ToolDefinition::Function { function }];
+        let response = self.stream_completion(request, cx);
+        self.request_limiter
+            .run(async move {
+                let mut response = response.await?;
+
+                // Call arguments are gonna be streamed in over multiple chunks.
+                let mut load_state = None;
+                while let Some(Ok(part)) = response.next().await {
+                    for choice in part.choices {
+                        let Some(tool_calls) = choice.delta.tool_calls else {
+                            continue;
+                        };
+
+                        for call in tool_calls {
+                            if let Some(func) = call.function {
+                                if func.name.as_deref() == Some(tool_name.as_str()) {
+                                    load_state = Some((String::default(), call.index));
+                                }
+                                if let Some((arguments, (output, index))) =
+                                    func.arguments.zip(load_state.as_mut())
+                                {
+                                    if call.index == *index {
+                                        output.push_str(&arguments);
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                if let Some((arguments, _)) = load_state {
+                    return Ok(serde_json::from_str(&arguments)?);
+                } else {
+                    bail!("tool not used");
+                }
+            })
+            .boxed()
     }
 }
 

crates/ollama/src/ollama.rs 🔗

@@ -4,6 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
 use std::{convert::TryFrom, sync::Arc, time::Duration};
 
 pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@@ -94,22 +95,63 @@ impl Model {
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(tag = "role", rename_all = "lowercase")]
 pub enum ChatMessage {
-    Assistant { content: String },
-    User { content: String },
-    System { content: String },
+    Assistant {
+        content: String,
+        tool_calls: Option<Vec<OllamaToolCall>>,
+    },
+    User {
+        content: String,
+    },
+    System {
+        content: String,
+    },
 }
 
-#[derive(Serialize)]
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum OllamaToolCall {
+    Function(OllamaFunctionCall),
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct OllamaFunctionCall {
+    pub name: String,
+    pub arguments: Value,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct OllamaFunctionTool {
+    pub name: String,
+    pub description: Option<String>,
+    pub parameters: Option<Value>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum OllamaTool {
+    Function { function: OllamaFunctionTool },
+}
+
+#[derive(Serialize, Debug)]
 pub struct ChatRequest {
     pub model: String,
     pub messages: Vec<ChatMessage>,
     pub stream: bool,
     pub keep_alive: KeepAlive,
     pub options: Option<ChatOptions>,
+    pub tools: Vec<OllamaTool>,
+}
+
+impl ChatRequest {
+    pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
+        self.stream = false;
+        self.tools = tools;
+        self
+    }
 }
 
 // https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
-#[derive(Serialize, Default)]
+#[derive(Serialize, Default, Debug)]
 pub struct ChatOptions {
     pub num_ctx: Option<usize>,
     pub num_predict: Option<isize>,
@@ -118,7 +160,7 @@ pub struct ChatOptions {
     pub top_p: Option<f32>,
 }
 
-#[derive(Deserialize)]
+#[derive(Deserialize, Debug)]
 pub struct ChatResponseDelta {
     #[allow(unused)]
     pub model: String,
@@ -162,6 +204,38 @@ pub struct ModelDetails {
     pub quantization_level: String,
 }
 
+pub async fn complete(
+    client: &dyn HttpClient,
+    api_url: &str,
+    request: ChatRequest,
+) -> Result<ChatResponseDelta> {
+    let uri = format!("{api_url}/api/chat");
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json");
+
+    let serialized_request = serde_json::to_string(&request)?;
+    let request = request_builder.body(AsyncBody::from(serialized_request))?;
+
+    let mut response = client.send(request).await?;
+    if response.status().is_success() {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
+        Ok(response_message)
+    } else {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let body_str = std::str::from_utf8(&body)?;
+        Err(anyhow!(
+            "Failed to connect to API: {} {}",
+            response.status(),
+            body_str
+        ))
+    }
+}
+
 pub async fn stream_chat_completion(
     client: &dyn HttpClient,
     api_url: &str,

crates/open_ai/src/open_ai.rs 🔗

@@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
-use serde_json::{Map, Value};
+use serde_json::Value;
 use std::{convert::TryFrom, future::Future, time::Duration};
 use strum::EnumIter;
 
@@ -121,25 +121,34 @@ pub struct Request {
     pub stop: Vec<String>,
     pub temperature: f32,
     #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub tool_choice: Option<String>,
+    pub tool_choice: Option<ToolChoice>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<ToolDefinition>,
 }
 
-#[derive(Debug, Deserialize, Serialize)]
-pub struct FunctionDefinition {
-    pub name: String,
-    pub description: Option<String>,
-    pub parameters: Option<Map<String, Value>>,
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+    Auto,
+    Required,
+    None,
+    Other(ToolDefinition),
 }
 
-#[derive(Deserialize, Serialize, Debug)]
+#[derive(Clone, Deserialize, Serialize, Debug)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum ToolDefinition {
     #[allow(dead_code)]
     Function { function: FunctionDefinition },
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct FunctionDefinition {
+    pub name: String,
+    pub description: Option<String>,
+    pub parameters: Option<Value>,
+}
+
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(tag = "role", rename_all = "lowercase")]
 pub enum RequestMessage {