From 769ec59162b316318233c90165dfeca173a716a3 Mon Sep 17 00:00:00 2001
From: tidely <43219534+tidely@users.noreply.github.com>
Date: Mon, 5 May 2025 19:52:23 +0200
Subject: [PATCH] ollama: Add tool call support (#29563)

The goal of this PR is to support tool calls using ollama. A lot of the
serialization work was done in https://github.com/zed-industries/zed/pull/15803,
but the abstraction over language models always disables tools.

## Changelog:

- Use `serde_json::Value` for the `arguments` of `OllamaFunctionCall`, just as
  it is already used in `OllamaFunctionTool`. This fixes deserialization of
  ollama tool calls.
- Added deserialization tests using JSON from the official ollama API docs.
- Fetch model capabilities during model enumeration in the ollama provider.
- Added a `supports_tools` setting to manually configure whether a model
  supports tools.

## TODO:

- [x] Fix tool call serialization/deserialization
- [x] Fetch model capabilities from the ollama API
- [x] Add tests for parsing model capabilities
- [ ] Documentation for the `supports_tools` field in the ollama language model config
- [ ] Convert between generic language model types
- [x] Pass tools to ollama

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra
Co-authored-by: Nathan Sobo
---
 .../src/assistant_settings.rs                 |   7 +-
 crates/language_models/src/provider/ollama.rs | 170 ++++++++---
 crates/ollama/src/ollama.rs                   | 271 ++++++++++++++----
 3 files changed, 360 insertions(+), 88 deletions(-)

diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs
index 50d3b7bdc5bb10392eca4abc96096d92564f6863..74206d437b4c136c04fa8d0fe6034e2d7c77f2be 100644
--- a/crates/assistant_settings/src/assistant_settings.rs
+++ b/crates/assistant_settings/src/assistant_settings.rs
@@ -315,7 +315,12 @@ impl AssistantSettingsContent {
                         _ => None,
                     };
                     settings.provider = Some(AssistantProviderContentV1::Ollama {
-                        default_model: Some(ollama::Model::new(&model, None, None)),
+                        default_model: Some(ollama::Model::new(
+                            &model,
+                            None,
+                            None,
+                            language_model.supports_tools(),
+                        )),
                         api_url,
                     });
                 }
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index 28586b89b08c07fefaeace6bb0936a30395b5039..0273891c3486fa8e83a48e8e6c33d3a420f5f7e5 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -1,9 +1,11 @@
 use anyhow::{Result, anyhow};
 use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
+use futures::{Stream, TryFutureExt, stream};
 use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
+    LanguageModelRequestTool, LanguageModelToolUse, LanguageModelToolUseId, StopReason,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -11,12 +13,14 @@ use language_model::{
     LanguageModelRequest, RateLimiter, Role,
 };
 use ollama::{
-    ChatMessage, ChatOptions, ChatRequest, KeepAlive, get_models, preload_model,
-    stream_chat_completion,
+    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
+    OllamaToolCall, get_models, preload_model, show_model, stream_chat_completion,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::sync::atomic::{AtomicU64, Ordering};
 use std::{collections::BTreeMap, sync::Arc};
 use
ui::{ButtonLike, Indicator, List, prelude::*}; use util::ResultExt; @@ -47,6 +51,8 @@ pub struct AvailableModel { pub max_tokens: usize, /// The number of seconds to keep the connection open after the last request pub keep_alive: Option, + /// Whether the model supports tools + pub supports_tools: bool, } pub struct OllamaLanguageModelProvider { @@ -68,26 +74,44 @@ impl State { fn fetch_models(&mut self, cx: &mut Context) -> Task> { let settings = &AllLanguageModelSettings::get_global(cx).ollama; - let http_client = self.http_client.clone(); + let http_client = Arc::clone(&self.http_client); let api_url = settings.api_url.clone(); // As a proxy for the server being "authenticated", we'll check if its up by fetching the models cx.spawn(async move |this, cx| { let models = get_models(http_client.as_ref(), &api_url, None).await?; - let mut models: Vec = models + let tasks = models .into_iter() // Since there is no metadata from the Ollama API // indicating which models are embedding models, // simply filter out models with "-embed" in their name .filter(|model| !model.name.contains("-embed")) - .map(|model| ollama::Model::new(&model.name, None, None)) - .collect(); + .map(|model| { + let http_client = Arc::clone(&http_client); + let api_url = api_url.clone(); + async move { + let name = model.name.as_str(); + let capabilities = show_model(http_client.as_ref(), &api_url, name).await?; + let ollama_model = + ollama::Model::new(name, None, None, capabilities.supports_tools()); + Ok(ollama_model) + } + }); + + // Rate-limit capability fetches + // since there is an arbitrary number of models available + let mut ollama_models: Vec<_> = futures::stream::iter(tasks) + .buffer_unordered(5) + .collect::>>() + .await + .into_iter() + .collect::>>()?; - models.sort_by(|a, b| a.name.cmp(&b.name)); + ollama_models.sort_by(|a, b| a.name.cmp(&b.name)); this.update(cx, |this, cx| { - this.available_models = models; + this.available_models = ollama_models; cx.notify(); }) }) @@ -189,6 +213,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { display_name: model.display_name.clone(), max_tokens: model.max_tokens, keep_alive: model.keep_alive.clone(), + supports_tools: model.supports_tools, }, ); } @@ -269,7 +294,7 @@ impl OllamaLanguageModel { temperature: request.temperature.or(Some(1.0)), ..Default::default() }), - tools: vec![], + tools: request.tools.into_iter().map(tool_into_ollama).collect(), } } } @@ -292,7 +317,7 @@ impl LanguageModel for OllamaLanguageModel { } fn supports_tools(&self) -> bool { - false + self.model.supports_tools } fn telemetry_id(&self) -> String { @@ -341,39 +366,100 @@ impl LanguageModel for OllamaLanguageModel { }; let future = self.request_limiter.stream(async move { - let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(delta) => { - let content = match delta.message { - ChatMessage::User { content } => content, - ChatMessage::Assistant { content, .. } => content, - ChatMessage::System { content } => content, - }; - Some(Ok(content)) - } - Err(error) => Some(Err(error)), - } - }) - .boxed(); + let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?; + let stream = map_to_language_model_completion_events(stream); Ok(stream) }); - async move { - Ok(future - .await? 
- .map(|result| { - result - .map(LanguageModelCompletionEvent::Text) - .map_err(LanguageModelCompletionError::Other) - }) - .boxed()) - } - .boxed() + future.map_ok(|f| f.boxed()).boxed() } } +fn map_to_language_model_completion_events( + stream: Pin> + Send>>, +) -> impl Stream> { + // Used for creating unique tool use ids + static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); + + struct State { + stream: Pin> + Send>>, + used_tools: bool, + } + + // We need to create a ToolUse and Stop event from a single + // response from the original stream + let stream = stream::unfold( + State { + stream, + used_tools: false, + }, + async move |mut state| { + let response = state.stream.next().await?; + + let delta = match response { + Ok(delta) => delta, + Err(e) => { + let event = Err(LanguageModelCompletionError::Other(anyhow!(e))); + return Some((vec![event], state)); + } + }; + + let mut events = Vec::new(); + + match delta.message { + ChatMessage::User { content } => { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + ChatMessage::System { content } => { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + ChatMessage::Assistant { + content, + tool_calls, + } => { + // Check for tool calls + if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) { + match tool_call { + OllamaToolCall::Function(function) => { + let tool_id = format!( + "{}-{}", + &function.name, + TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed) + ); + let event = + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + id: LanguageModelToolUseId::from(tool_id), + name: Arc::from(function.name), + raw_input: function.arguments.to_string(), + input: function.arguments, + is_input_complete: true, + }); + events.push(Ok(event)); + state.used_tools = true; + } + } + } else { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + } + }; + + if delta.done { + if state.used_tools { + state.used_tools = false; + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } else { + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + } + + Some((events, state)) + }, + ); + + stream.flat_map(futures::stream::iter) +} + struct ConfigurationView { state: gpui::Entity, loading_models_task: Option>, @@ -509,3 +595,13 @@ impl Render for ConfigurationView { } } } + +fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool { + ollama::OllamaTool::Function { + function: OllamaFunctionTool { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + }, + } +} diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index 209ba94c700027491eb08f7e53621d811881bda8..10436a35818565868b47f14463735277a66ff1eb 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -2,42 +2,11 @@ use anyhow::{Context as _, Result, anyhow}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http}; use serde::{Deserialize, Serialize}; -use serde_json::{Value, value::RawValue}; -use std::{convert::TryFrom, sync::Arc, time::Duration}; +use serde_json::Value; +use std::{sync::Arc, time::Duration}; pub const OLLAMA_API_URL: &str = "http://localhost:11434"; -#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, - System, -} - -impl TryFrom for Role { - type Error = 
anyhow::Error; - - fn try_from(value: String) -> Result { - match value.as_str() { - "user" => Ok(Self::User), - "assistant" => Ok(Self::Assistant), - "system" => Ok(Self::System), - _ => Err(anyhow!("invalid role '{value}'")), - } - } -} - -impl From for String { - fn from(val: Role) -> Self { - match val { - Role::User => "user".to_owned(), - Role::Assistant => "assistant".to_owned(), - Role::System => "system".to_owned(), - } - } -} - #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(untagged)] @@ -68,6 +37,7 @@ pub struct Model { pub display_name: Option, pub max_tokens: usize, pub keep_alive: Option, + pub supports_tools: bool, } fn get_max_tokens(name: &str) -> usize { @@ -93,7 +63,12 @@ fn get_max_tokens(name: &str) -> usize { } impl Model { - pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option) -> Self { + pub fn new( + name: &str, + display_name: Option<&str>, + max_tokens: Option, + supports_tools: bool, + ) -> Self { Self { name: name.to_owned(), display_name: display_name @@ -101,6 +76,7 @@ impl Model { .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)), max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)), keep_alive: Some(KeepAlive::indefinite()), + supports_tools, } } @@ -141,7 +117,7 @@ pub enum OllamaToolCall { #[derive(Serialize, Deserialize, Debug)] pub struct OllamaFunctionCall { pub name: String, - pub arguments: Box, + pub arguments: Value, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] @@ -229,6 +205,19 @@ pub struct ModelDetails { pub quantization_level: String, } +#[derive(Deserialize, Debug)] +pub struct ModelShow { + #[serde(default)] + pub capabilities: Vec, +} + +impl ModelShow { + pub fn supports_tools(&self) -> bool { + // .contains expects &String, which would require an additional allocation + self.capabilities.iter().any(|v| v == "tools") + } +} + pub async fn complete( client: &dyn HttpClient, api_url: &str, @@ -244,14 +233,14 @@ pub async fn complete( let request = request_builder.body(AsyncBody::from(serialized_request))?; let mut response = client.send(request).await?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + if response.status().is_success() { - let mut body = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; let response_message: ChatResponseDelta = serde_json::from_slice(&body)?; Ok(response_message) } else { - let mut body = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; let body_str = std::str::from_utf8(&body)?; Err(anyhow!( "Failed to connect to API: {} {}", @@ -279,13 +268,9 @@ pub async fn stream_chat_completion( Ok(reader .lines() - .filter_map(move |line| async move { - match line { - Ok(line) => { - Some(serde_json::from_str(&line).context("Unable to parse chat response")) - } - Err(e) => Some(Err(e.into())), - } + .map(|line| match line { + Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"), + Err(e) => Err(e.into()), }) .boxed()) } else { @@ -332,6 +317,33 @@ pub async fn get_models( } } +/// Fetch details of a model, used to determine model capabilities +pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result { + let uri = format!("{api_url}/api/show"); + let request = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json") + .body(AsyncBody::from( + serde_json::json!({ "model": model }).to_string(), + ))?; + + let 
mut response = client.send(request).await?; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + if response.status().is_success() { + let details: ModelShow = serde_json::from_str(body.as_str())?; + Ok(details) + } else { + Err(anyhow!( + "Failed to connect to Ollama API: {} {}", + response.status(), + body, + )) + } +} + /// Sends an empty request to Ollama to trigger loading the model pub async fn preload_model(client: Arc, api_url: &str, model: &str) -> Result<()> { let uri = format!("{api_url}/api/generate"); @@ -339,12 +351,13 @@ pub async fn preload_model(client: Arc, api_url: &str, model: &s .method(Method::POST) .uri(uri) .header("Content-Type", "application/json") - .body(AsyncBody::from(serde_json::to_string( - &serde_json::json!({ + .body(AsyncBody::from( + serde_json::json!({ "model": model, "keep_alive": "15m", - }), - )?))?; + }) + .to_string(), + ))?; let mut response = client.send(request).await?; @@ -361,3 +374,161 @@ pub async fn preload_model(client: Arc, api_url: &str, model: &s )) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_completion() { + let response = serde_json::json!({ + "model": "llama3.2", + "created_at": "2023-12-12T14:13:43.416799Z", + "message": { + "role": "assistant", + "content": "Hello! How are you today?" + }, + "done": true, + "total_duration": 5191566416u64, + "load_duration": 2154458, + "prompt_eval_count": 26, + "prompt_eval_duration": 383809000, + "eval_count": 298, + "eval_duration": 4799921000u64 + }); + let _: ChatResponseDelta = serde_json::from_value(response).unwrap(); + } + + #[test] + fn parse_streaming_completion() { + let partial = serde_json::json!({ + "model": "llama3.2", + "created_at": "2023-08-04T08:52:19.385406455-07:00", + "message": { + "role": "assistant", + "content": "The", + "images": null + }, + "done": false + }); + + let _: ChatResponseDelta = serde_json::from_value(partial).unwrap(); + + let last = serde_json::json!({ + "model": "llama3.2", + "created_at": "2023-08-04T19:22:45.499127Z", + "message": { + "role": "assistant", + "content": "" + }, + "done": true, + "total_duration": 4883583458u64, + "load_duration": 1334875, + "prompt_eval_count": 26, + "prompt_eval_duration": 342546000, + "eval_count": 282, + "eval_duration": 4535599000u64 + }); + + let _: ChatResponseDelta = serde_json::from_value(last).unwrap(); + } + + #[test] + fn parse_tool_call() { + let response = serde_json::json!({ + "model": "llama3.2:3b", + "created_at": "2025-04-28T20:02:02.140489Z", + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "weather", + "arguments": { + "city": "london", + } + } + } + ] + }, + "done_reason": "stop", + "done": true, + "total_duration": 2758629166u64, + "load_duration": 1770059875, + "prompt_eval_count": 147, + "prompt_eval_duration": 684637583, + "eval_count": 16, + "eval_duration": 302561917, + }); + + let result: ChatResponseDelta = serde_json::from_value(response).unwrap(); + match result.message { + ChatMessage::Assistant { + content, + tool_calls, + } => { + assert!(content.is_empty()); + assert!(tool_calls.is_some_and(|v| !v.is_empty())); + } + _ => panic!("Deserialized wrong role"), + } + } + + #[test] + fn parse_show_model() { + let response = serde_json::json!({ + "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...", + "details": { + "parent_model": "", + "format": "gguf", + "family": "llama", + "families": ["llama"], + "parameter_size": "3.2B", + "quantization_level": "Q4_K_M" + }, + 
"model_info": { + "general.architecture": "llama", + "general.basename": "Llama-3.2", + "general.file_type": 15, + "general.finetune": "Instruct", + "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"], + "general.parameter_count": 3212749888u64, + "general.quantization_version": 2, + "general.size_label": "3B", + "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"], + "general.type": "model", + "llama.attention.head_count": 24, + "llama.attention.head_count_kv": 8, + "llama.attention.key_length": 128, + "llama.attention.layer_norm_rms_epsilon": 0.00001, + "llama.attention.value_length": 128, + "llama.block_count": 28, + "llama.context_length": 131072, + "llama.embedding_length": 3072, + "llama.feed_forward_length": 8192, + "llama.rope.dimension_count": 128, + "llama.rope.freq_base": 500000, + "llama.vocab_size": 128256, + "tokenizer.ggml.bos_token_id": 128000, + "tokenizer.ggml.eos_token_id": 128009, + "tokenizer.ggml.merges": null, + "tokenizer.ggml.model": "gpt2", + "tokenizer.ggml.pre": "llama-bpe", + "tokenizer.ggml.token_type": null, + "tokenizer.ggml.tokens": null + }, + "tensors": [ + { "name": "rope_freqs.weight", "type": "F32", "shape": [64] }, + { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] } + ], + "capabilities": ["completion", "tools"], + "modified_at": "2025-04-29T21:24:41.445877632+03:00" + }); + + let result: ModelShow = serde_json::from_value(response).unwrap(); + assert!(result.supports_tools()); + assert!(result.capabilities.contains(&"tools".to_string())); + assert!(result.capabilities.contains(&"completion".to_string())); + } +}