ollama: Add tool call support (#29563)

tidely , Antonio Scandurra , and Nathan Sobo created

The goal of this PR is to support tool calls using ollama. A lot of the
serialization work was done in
https://github.com/zed-industries/zed/pull/15803 however the abstraction
over language models always disables tools.

## Changelog:

- Use `serde_json::Value` inside `OllamaFunctionCall` just as it's used
in `OllamaFunctionCall`. This fixes deserialization of ollama tool
calls.
- Added deserialization tests using json from official ollama api docs.
- Fetch model capabilities during model enumeration from ollama provider
- Added `supports_tools` setting to manually configure if a model
supports tools

## TODO:

- [x] Fix tool call serialization/deserialization
- [x] Fetch model capabilities from ollama api
- [x] Add tests for parsing model capabilities 
- [ ] Documentation for `supports_tools` field for ollama language model
config
- [ ] Convert between generic language model types
- [x] Pass tools to ollama

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>

Change summary

crates/assistant_settings/src/assistant_settings.rs |   7 
crates/language_models/src/provider/ollama.rs       | 170 +++++++--
crates/ollama/src/ollama.rs                         | 271 ++++++++++++--
3 files changed, 360 insertions(+), 88 deletions(-)

Detailed changes

crates/assistant_settings/src/assistant_settings.rs 🔗

@@ -315,7 +315,12 @@ impl AssistantSettingsContent {
                                 _ => None,
                             };
                             settings.provider = Some(AssistantProviderContentV1::Ollama {
-                                default_model: Some(ollama::Model::new(&model, None, None)),
+                                default_model: Some(ollama::Model::new(
+                                    &model,
+                                    None,
+                                    None,
+                                    language_model.supports_tools(),
+                                )),
                                 api_url,
                             });
                         }

crates/language_models/src/provider/ollama.rs 🔗

@@ -1,9 +1,11 @@
 use anyhow::{Result, anyhow};
 use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
+use futures::{Stream, TryFutureExt, stream};
 use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
+    LanguageModelRequestTool, LanguageModelToolUse, LanguageModelToolUseId, StopReason,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -11,12 +13,14 @@ use language_model::{
     LanguageModelRequest, RateLimiter, Role,
 };
 use ollama::{
-    ChatMessage, ChatOptions, ChatRequest, KeepAlive, get_models, preload_model,
-    stream_chat_completion,
+    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
+    OllamaToolCall, get_models, preload_model, show_model, stream_chat_completion,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::sync::atomic::{AtomicU64, Ordering};
 use std::{collections::BTreeMap, sync::Arc};
 use ui::{ButtonLike, Indicator, List, prelude::*};
 use util::ResultExt;
@@ -47,6 +51,8 @@ pub struct AvailableModel {
     pub max_tokens: usize,
     /// The number of seconds to keep the connection open after the last request
     pub keep_alive: Option<KeepAlive>,
+    /// Whether the model supports tools
+    pub supports_tools: bool,
 }
 
 pub struct OllamaLanguageModelProvider {
@@ -68,26 +74,44 @@ impl State {
 
     fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
         let settings = &AllLanguageModelSettings::get_global(cx).ollama;
-        let http_client = self.http_client.clone();
+        let http_client = Arc::clone(&self.http_client);
         let api_url = settings.api_url.clone();
 
         // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
         cx.spawn(async move |this, cx| {
             let models = get_models(http_client.as_ref(), &api_url, None).await?;
 
-            let mut models: Vec<ollama::Model> = models
+            let tasks = models
                 .into_iter()
                 // Since there is no metadata from the Ollama API
                 // indicating which models are embedding models,
                 // simply filter out models with "-embed" in their name
                 .filter(|model| !model.name.contains("-embed"))
-                .map(|model| ollama::Model::new(&model.name, None, None))
-                .collect();
+                .map(|model| {
+                    let http_client = Arc::clone(&http_client);
+                    let api_url = api_url.clone();
+                    async move {
+                        let name = model.name.as_str();
+                        let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
+                        let ollama_model =
+                            ollama::Model::new(name, None, None, capabilities.supports_tools());
+                        Ok(ollama_model)
+                    }
+                });
+
+            // Rate-limit capability fetches
+            // since there is an arbitrary number of models available
+            let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
+                .buffer_unordered(5)
+                .collect::<Vec<Result<_>>>()
+                .await
+                .into_iter()
+                .collect::<Result<Vec<_>>>()?;
 
-            models.sort_by(|a, b| a.name.cmp(&b.name));
+            ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
 
             this.update(cx, |this, cx| {
-                this.available_models = models;
+                this.available_models = ollama_models;
                 cx.notify();
             })
         })
@@ -189,6 +213,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
                     keep_alive: model.keep_alive.clone(),
+                    supports_tools: model.supports_tools,
                 },
             );
         }
@@ -269,7 +294,7 @@ impl OllamaLanguageModel {
                 temperature: request.temperature.or(Some(1.0)),
                 ..Default::default()
             }),
-            tools: vec![],
+            tools: request.tools.into_iter().map(tool_into_ollama).collect(),
         }
     }
 }
@@ -292,7 +317,7 @@ impl LanguageModel for OllamaLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        self.model.supports_tools
     }
 
     fn telemetry_id(&self) -> String {
@@ -341,39 +366,100 @@ impl LanguageModel for OllamaLanguageModel {
         };
 
         let future = self.request_limiter.stream(async move {
-            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(delta) => {
-                            let content = match delta.message {
-                                ChatMessage::User { content } => content,
-                                ChatMessage::Assistant { content, .. } => content,
-                                ChatMessage::System { content } => content,
-                            };
-                            Some(Ok(content))
-                        }
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
+            let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
+            let stream = map_to_language_model_completion_events(stream);
             Ok(stream)
         });
 
-        async move {
-            Ok(future
-                .await?
-                .map(|result| {
-                    result
-                        .map(LanguageModelCompletionEvent::Text)
-                        .map_err(LanguageModelCompletionError::Other)
-                })
-                .boxed())
-        }
-        .boxed()
+        future.map_ok(|f| f.boxed()).boxed()
     }
 }
 
+fn map_to_language_model_completion_events(
+    stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+    // Used for creating unique tool use ids
+    static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+
+    struct State {
+        stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
+        used_tools: bool,
+    }
+
+    // We need to create a ToolUse and Stop event from a single
+    // response from the original stream
+    let stream = stream::unfold(
+        State {
+            stream,
+            used_tools: false,
+        },
+        async move |mut state| {
+            let response = state.stream.next().await?;
+
+            let delta = match response {
+                Ok(delta) => delta,
+                Err(e) => {
+                    let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
+                    return Some((vec![event], state));
+                }
+            };
+
+            let mut events = Vec::new();
+
+            match delta.message {
+                ChatMessage::User { content } => {
+                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                }
+                ChatMessage::System { content } => {
+                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                }
+                ChatMessage::Assistant {
+                    content,
+                    tool_calls,
+                } => {
+                    // Check for tool calls
+                    if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
+                        match tool_call {
+                            OllamaToolCall::Function(function) => {
+                                let tool_id = format!(
+                                    "{}-{}",
+                                    &function.name,
+                                    TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed)
+                                );
+                                let event =
+                                    LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                                        id: LanguageModelToolUseId::from(tool_id),
+                                        name: Arc::from(function.name),
+                                        raw_input: function.arguments.to_string(),
+                                        input: function.arguments,
+                                        is_input_complete: true,
+                                    });
+                                events.push(Ok(event));
+                                state.used_tools = true;
+                            }
+                        }
+                    } else {
+                        events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                    }
+                }
+            };
+
+            if delta.done {
+                if state.used_tools {
+                    state.used_tools = false;
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+                } else {
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+                }
+            }
+
+            Some((events, state))
+        },
+    );
+
+    stream.flat_map(futures::stream::iter)
+}
+
 struct ConfigurationView {
     state: gpui::Entity<State>,
     loading_models_task: Option<Task<()>>,
@@ -509,3 +595,13 @@ impl Render for ConfigurationView {
         }
     }
 }
+
+fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
+    ollama::OllamaTool::Function {
+        function: OllamaFunctionTool {
+            name: tool.name,
+            description: Some(tool.description),
+            parameters: Some(tool.input_schema),
+        },
+    }
+}

crates/ollama/src/ollama.rs 🔗

@@ -2,42 +2,11 @@ use anyhow::{Context as _, Result, anyhow};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
 use serde::{Deserialize, Serialize};
-use serde_json::{Value, value::RawValue};
-use std::{convert::TryFrom, sync::Arc, time::Duration};
+use serde_json::Value;
+use std::{sync::Arc, time::Duration};
 
 pub const OLLAMA_API_URL: &str = "http://localhost:11434";
 
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    User,
-    Assistant,
-    System,
-}
-
-impl TryFrom<String> for Role {
-    type Error = anyhow::Error;
-
-    fn try_from(value: String) -> Result<Self> {
-        match value.as_str() {
-            "user" => Ok(Self::User),
-            "assistant" => Ok(Self::Assistant),
-            "system" => Ok(Self::System),
-            _ => Err(anyhow!("invalid role '{value}'")),
-        }
-    }
-}
-
-impl From<Role> for String {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => "user".to_owned(),
-            Role::Assistant => "assistant".to_owned(),
-            Role::System => "system".to_owned(),
-        }
-    }
-}
-
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(untagged)]
@@ -68,6 +37,7 @@ pub struct Model {
     pub display_name: Option<String>,
     pub max_tokens: usize,
     pub keep_alive: Option<KeepAlive>,
+    pub supports_tools: bool,
 }
 
 fn get_max_tokens(name: &str) -> usize {
@@ -93,7 +63,12 @@ fn get_max_tokens(name: &str) -> usize {
 }
 
 impl Model {
-    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
+    pub fn new(
+        name: &str,
+        display_name: Option<&str>,
+        max_tokens: Option<usize>,
+        supports_tools: bool,
+    ) -> Self {
         Self {
             name: name.to_owned(),
             display_name: display_name
@@ -101,6 +76,7 @@ impl Model {
                 .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
             max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
             keep_alive: Some(KeepAlive::indefinite()),
+            supports_tools,
         }
     }
 
@@ -141,7 +117,7 @@ pub enum OllamaToolCall {
 #[derive(Serialize, Deserialize, Debug)]
 pub struct OllamaFunctionCall {
     pub name: String,
-    pub arguments: Box<RawValue>,
+    pub arguments: Value,
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
@@ -229,6 +205,19 @@ pub struct ModelDetails {
     pub quantization_level: String,
 }
 
+#[derive(Deserialize, Debug)]
+pub struct ModelShow {
+    #[serde(default)]
+    pub capabilities: Vec<String>,
+}
+
+impl ModelShow {
+    pub fn supports_tools(&self) -> bool {
+        // .contains expects &String, which would require an additional allocation
+        self.capabilities.iter().any(|v| v == "tools")
+    }
+}
+
 pub async fn complete(
     client: &dyn HttpClient,
     api_url: &str,
@@ -244,14 +233,14 @@ pub async fn complete(
     let request = request_builder.body(AsyncBody::from(serialized_request))?;
 
     let mut response = client.send(request).await?;
+
+    let mut body = Vec::new();
+    response.body_mut().read_to_end(&mut body).await?;
+
     if response.status().is_success() {
-        let mut body = Vec::new();
-        response.body_mut().read_to_end(&mut body).await?;
         let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
         Ok(response_message)
     } else {
-        let mut body = Vec::new();
-        response.body_mut().read_to_end(&mut body).await?;
         let body_str = std::str::from_utf8(&body)?;
         Err(anyhow!(
             "Failed to connect to API: {} {}",
@@ -279,13 +268,9 @@ pub async fn stream_chat_completion(
 
         Ok(reader
             .lines()
-            .filter_map(move |line| async move {
-                match line {
-                    Ok(line) => {
-                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
-                    }
-                    Err(e) => Some(Err(e.into())),
-                }
+            .map(|line| match line {
+                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
+                Err(e) => Err(e.into()),
             })
             .boxed())
     } else {
@@ -332,6 +317,33 @@ pub async fn get_models(
     }
 }
 
+/// Fetch details of a model, used to determine model capabilities
+pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
+    let uri = format!("{api_url}/api/show");
+    let request = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json")
+        .body(AsyncBody::from(
+            serde_json::json!({ "model": model }).to_string(),
+        ))?;
+
+    let mut response = client.send(request).await?;
+    let mut body = String::new();
+    response.body_mut().read_to_string(&mut body).await?;
+
+    if response.status().is_success() {
+        let details: ModelShow = serde_json::from_str(body.as_str())?;
+        Ok(details)
+    } else {
+        Err(anyhow!(
+            "Failed to connect to Ollama API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
+
 /// Sends an empty request to Ollama to trigger loading the model
 pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
     let uri = format!("{api_url}/api/generate");
@@ -339,12 +351,13 @@ pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &s
         .method(Method::POST)
         .uri(uri)
         .header("Content-Type", "application/json")
-        .body(AsyncBody::from(serde_json::to_string(
-            &serde_json::json!({
+        .body(AsyncBody::from(
+            serde_json::json!({
                 "model": model,
                 "keep_alive": "15m",
-            }),
-        )?))?;
+            })
+            .to_string(),
+        ))?;
 
     let mut response = client.send(request).await?;
 
@@ -361,3 +374,161 @@ pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &s
         ))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn parse_completion() {
+        let response = serde_json::json!({
+        "model": "llama3.2",
+        "created_at": "2023-12-12T14:13:43.416799Z",
+        "message": {
+            "role": "assistant",
+            "content": "Hello! How are you today?"
+        },
+        "done": true,
+        "total_duration": 5191566416u64,
+        "load_duration": 2154458,
+        "prompt_eval_count": 26,
+        "prompt_eval_duration": 383809000,
+        "eval_count": 298,
+        "eval_duration": 4799921000u64
+        });
+        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
+    }
+
+    #[test]
+    fn parse_streaming_completion() {
+        let partial = serde_json::json!({
+        "model": "llama3.2",
+        "created_at": "2023-08-04T08:52:19.385406455-07:00",
+        "message": {
+            "role": "assistant",
+            "content": "The",
+            "images": null
+        },
+        "done": false
+        });
+
+        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();
+
+        let last = serde_json::json!({
+        "model": "llama3.2",
+        "created_at": "2023-08-04T19:22:45.499127Z",
+        "message": {
+            "role": "assistant",
+            "content": ""
+        },
+        "done": true,
+        "total_duration": 4883583458u64,
+        "load_duration": 1334875,
+        "prompt_eval_count": 26,
+        "prompt_eval_duration": 342546000,
+        "eval_count": 282,
+        "eval_duration": 4535599000u64
+        });
+
+        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
+    }
+
+    #[test]
+    fn parse_tool_call() {
+        let response = serde_json::json!({
+            "model": "llama3.2:3b",
+            "created_at": "2025-04-28T20:02:02.140489Z",
+            "message": {
+                "role": "assistant",
+                "content": "",
+                "tool_calls": [
+                    {
+                        "function": {
+                            "name": "weather",
+                            "arguments": {
+                                "city": "london",
+                            }
+                        }
+                    }
+                ]
+            },
+            "done_reason": "stop",
+            "done": true,
+            "total_duration": 2758629166u64,
+            "load_duration": 1770059875,
+            "prompt_eval_count": 147,
+            "prompt_eval_duration": 684637583,
+            "eval_count": 16,
+            "eval_duration": 302561917,
+        });
+
+        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
+        match result.message {
+            ChatMessage::Assistant {
+                content,
+                tool_calls,
+            } => {
+                assert!(content.is_empty());
+                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
+            }
+            _ => panic!("Deserialized wrong role"),
+        }
+    }
+
+    #[test]
+    fn parse_show_model() {
+        let response = serde_json::json!({
+            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
+            "details": {
+                "parent_model": "",
+                "format": "gguf",
+                "family": "llama",
+                "families": ["llama"],
+                "parameter_size": "3.2B",
+                "quantization_level": "Q4_K_M"
+            },
+            "model_info": {
+                "general.architecture": "llama",
+                "general.basename": "Llama-3.2",
+                "general.file_type": 15,
+                "general.finetune": "Instruct",
+                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
+                "general.parameter_count": 3212749888u64,
+                "general.quantization_version": 2,
+                "general.size_label": "3B",
+                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
+                "general.type": "model",
+                "llama.attention.head_count": 24,
+                "llama.attention.head_count_kv": 8,
+                "llama.attention.key_length": 128,
+                "llama.attention.layer_norm_rms_epsilon": 0.00001,
+                "llama.attention.value_length": 128,
+                "llama.block_count": 28,
+                "llama.context_length": 131072,
+                "llama.embedding_length": 3072,
+                "llama.feed_forward_length": 8192,
+                "llama.rope.dimension_count": 128,
+                "llama.rope.freq_base": 500000,
+                "llama.vocab_size": 128256,
+                "tokenizer.ggml.bos_token_id": 128000,
+                "tokenizer.ggml.eos_token_id": 128009,
+                "tokenizer.ggml.merges": null,
+                "tokenizer.ggml.model": "gpt2",
+                "tokenizer.ggml.pre": "llama-bpe",
+                "tokenizer.ggml.token_type": null,
+                "tokenizer.ggml.tokens": null
+            },
+            "tensors": [
+                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
+                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
+            ],
+            "capabilities": ["completion", "tools"],
+            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
+        });
+
+        let result: ModelShow = serde_json::from_value(response).unwrap();
+        assert!(result.supports_tools());
+        assert!(result.capabilities.contains(&"tools".to_string()));
+        assert!(result.capabilities.contains(&"completion".to_string()));
+    }
+}