@@ -1,9 +1,11 @@
use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
+use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelRequestTool, LanguageModelToolUse, LanguageModelToolUseId, StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -11,12 +13,14 @@ use language_model::{
LanguageModelRequest, RateLimiter, Role,
};
use ollama::{
- ChatMessage, ChatOptions, ChatRequest, KeepAlive, get_models, preload_model,
- stream_chat_completion,
+ ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
+ OllamaToolCall, get_models, preload_model, show_model, stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::sync::atomic::{AtomicU64, Ordering};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;
@@ -47,6 +51,8 @@ pub struct AvailableModel {
pub max_tokens: usize,
/// The number of seconds to keep the connection open after the last request
pub keep_alive: Option<KeepAlive>,
+ /// Whether the model supports tools
+ pub supports_tools: bool,
}
pub struct OllamaLanguageModelProvider {
@@ -68,26 +74,44 @@ impl State {
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
- let http_client = self.http_client.clone();
+ let http_client = Arc::clone(&self.http_client);
let api_url = settings.api_url.clone();
    // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
cx.spawn(async move |this, cx| {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
- let mut models: Vec<ollama::Model> = models
+ let tasks = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
- .map(|model| ollama::Model::new(&model.name, None, None))
- .collect();
+ .map(|model| {
+ let http_client = Arc::clone(&http_client);
+ let api_url = api_url.clone();
+ async move {
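+                    // Look up each model's capabilities so we know whether it supports tools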
+ let name = model.name.as_str();
+ let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
+ let ollama_model =
+ ollama::Model::new(name, None, None, capabilities.supports_tools());
+ Ok(ollama_model)
+ }
+ });
+
+ // Rate-limit capability fetches
+ // since there is an arbitrary number of models available
+ let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
+ .buffer_unordered(5)
+ .collect::<Vec<Result<_>>>()
+ .await
+ .into_iter()
+ .collect::<Result<Vec<_>>>()?;
- models.sort_by(|a, b| a.name.cmp(&b.name));
+ ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
this.update(cx, |this, cx| {
- this.available_models = models;
+ this.available_models = ollama_models;
cx.notify();
})
})
@@ -189,6 +213,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
keep_alive: model.keep_alive.clone(),
+ supports_tools: model.supports_tools,
},
);
}
@@ -269,7 +294,7 @@ impl OllamaLanguageModel {
temperature: request.temperature.or(Some(1.0)),
..Default::default()
}),
- tools: vec![],
+ tools: request.tools.into_iter().map(tool_into_ollama).collect(),
}
}
}
@@ -292,7 +317,7 @@ impl LanguageModel for OllamaLanguageModel {
}
fn supports_tools(&self) -> bool {
- false
+ self.model.supports_tools
}
fn telemetry_id(&self) -> String {
@@ -341,39 +366,100 @@ impl LanguageModel for OllamaLanguageModel {
};
let future = self.request_limiter.stream(async move {
- let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
- let stream = response
- .filter_map(|response| async move {
- match response {
- Ok(delta) => {
- let content = match delta.message {
- ChatMessage::User { content } => content,
- ChatMessage::Assistant { content, .. } => content,
- ChatMessage::System { content } => content,
- };
- Some(Ok(content))
- }
- Err(error) => Some(Err(error)),
- }
- })
- .boxed();
+ let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
+ let stream = map_to_language_model_completion_events(stream);
Ok(stream)
});
- async move {
- Ok(future
- .await?
- .map(|result| {
- result
- .map(LanguageModelCompletionEvent::Text)
- .map_err(LanguageModelCompletionError::Other)
- })
- .boxed())
- }
- .boxed()
+ future.map_ok(|f| f.boxed()).boxed()
}
}
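+/// Maps Ollama's streaming `ChatResponseDelta`s to `LanguageModelCompletionEvent`s,
+/// emitting tool-use events and an appropriate stop reason once the stream reports `done`.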
+fn map_to_language_model_completion_events(
+ stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ // Used for creating unique tool use ids
+ static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+
+ struct State {
+ stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
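+        // Set when a tool call is emitted so the final stop reason becomes ToolUse
+        // instead of EndTurn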
+ used_tools: bool,
+ }
+
+    // A single delta from the underlying stream may need to produce both a ToolUse
+    // event and a Stop event, so each unfold step returns a batch of events
+ let stream = stream::unfold(
+ State {
+ stream,
+ used_tools: false,
+ },
+ async move |mut state| {
+ let response = state.stream.next().await?;
+
+ let delta = match response {
+ Ok(delta) => delta,
+ Err(e) => {
+ let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
+ return Some((vec![event], state));
+ }
+ };
+
+ let mut events = Vec::new();
+
+ match delta.message {
+ ChatMessage::User { content } => {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+ ChatMessage::System { content } => {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+ ChatMessage::Assistant {
+ content,
+ tool_calls,
+ } => {
+                    // Check for tool calls; only the first tool call in the delta is handled
+ if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
+ match tool_call {
+ OllamaToolCall::Function(function) => {
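+                            // Ollama's response carries no tool-call id, so build a
+                            // unique one from the function name and a global counter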
+ let tool_id = format!(
+ "{}-{}",
+ &function.name,
+ TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed)
+ );
+ let event =
+ LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+ id: LanguageModelToolUseId::from(tool_id),
+ name: Arc::from(function.name),
+ raw_input: function.arguments.to_string(),
+ input: function.arguments,
+ is_input_complete: true,
+ });
+ events.push(Ok(event));
+ state.used_tools = true;
+ }
+ }
+ } else {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+ }
+ };
+
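+            // Once the final delta arrives, emit a stop event reflecting whether
+            // any tool was used during the turn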
+ if delta.done {
+ if state.used_tools {
+ state.used_tools = false;
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+ } else {
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ }
+
+ Some((events, state))
+ },
+ );
+
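+    // Flatten the per-delta batches of events into a single event stream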
+ stream.flat_map(futures::stream::iter)
+}
+
struct ConfigurationView {
state: gpui::Entity<State>,
loading_models_task: Option<Task<()>>,
@@ -509,3 +595,13 @@ impl Render for ConfigurationView {
}
}
}
+
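+/// Converts a `LanguageModelRequestTool` into Ollama's function-tool representation.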
+fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
+ ollama::OllamaTool::Function {
+ function: OllamaFunctionTool {
+ name: tool.name,
+ description: Some(tool.description),
+ parameters: Some(tool.input_schema),
+ },
+ }
+}
@@ -2,42 +2,11 @@ use anyhow::{Context as _, Result, anyhow};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
-use serde_json::{Value, value::RawValue};
-use std::{convert::TryFrom, sync::Arc, time::Duration};
+use serde_json::Value;
+use std::{sync::Arc, time::Duration};
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
- User,
- Assistant,
- System,
-}
-
-impl TryFrom<String> for Role {
- type Error = anyhow::Error;
-
- fn try_from(value: String) -> Result<Self> {
- match value.as_str() {
- "user" => Ok(Self::User),
- "assistant" => Ok(Self::Assistant),
- "system" => Ok(Self::System),
- _ => Err(anyhow!("invalid role '{value}'")),
- }
- }
-}
-
-impl From<Role> for String {
- fn from(val: Role) -> Self {
- match val {
- Role::User => "user".to_owned(),
- Role::Assistant => "assistant".to_owned(),
- Role::System => "system".to_owned(),
- }
- }
-}
-
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
@@ -68,6 +37,7 @@ pub struct Model {
pub display_name: Option<String>,
pub max_tokens: usize,
pub keep_alive: Option<KeepAlive>,
+ pub supports_tools: bool,
}
fn get_max_tokens(name: &str) -> usize {
@@ -93,7 +63,12 @@ fn get_max_tokens(name: &str) -> usize {
}
impl Model {
- pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
+ pub fn new(
+ name: &str,
+ display_name: Option<&str>,
+ max_tokens: Option<usize>,
+ supports_tools: bool,
+ ) -> Self {
Self {
name: name.to_owned(),
display_name: display_name
@@ -101,6 +76,7 @@ impl Model {
.or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
keep_alive: Some(KeepAlive::indefinite()),
+ supports_tools,
}
}
@@ -141,7 +117,7 @@ pub enum OllamaToolCall {
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
pub name: String,
- pub arguments: Box<RawValue>,
+ pub arguments: Value,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
@@ -229,6 +205,19 @@ pub struct ModelDetails {
pub quantization_level: String,
}
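+/// Subset of the `/api/show` response; only the capability list is used here.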
+#[derive(Deserialize, Debug)]
+pub struct ModelShow {
+ #[serde(default)]
+ pub capabilities: Vec<String>,
+}
+
+impl ModelShow {
+ pub fn supports_tools(&self) -> bool {
+ // .contains expects &String, which would require an additional allocation
+ self.capabilities.iter().any(|v| v == "tools")
+ }
+}
+
pub async fn complete(
client: &dyn HttpClient,
api_url: &str,
@@ -244,14 +233,14 @@ pub async fn complete(
let request = request_builder.body(AsyncBody::from(serialized_request))?;
let mut response = client.send(request).await?;
+
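+    // Read the body up front so it is available for both the success and error paths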
+ let mut body = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+
if response.status().is_success() {
- let mut body = Vec::new();
- response.body_mut().read_to_end(&mut body).await?;
let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
Ok(response_message)
} else {
- let mut body = Vec::new();
- response.body_mut().read_to_end(&mut body).await?;
let body_str = std::str::from_utf8(&body)?;
Err(anyhow!(
"Failed to connect to API: {} {}",
@@ -279,13 +268,9 @@ pub async fn stream_chat_completion(
Ok(reader
.lines()
- .filter_map(move |line| async move {
- match line {
- Ok(line) => {
- Some(serde_json::from_str(&line).context("Unable to parse chat response"))
- }
- Err(e) => Some(Err(e.into())),
- }
+ .map(|line| match line {
+ Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
+ Err(e) => Err(e.into()),
})
.boxed())
} else {
@@ -332,6 +317,33 @@ pub async fn get_models(
}
}
+/// Fetch a model's details from `/api/show`, used to determine its capabilities (e.g. tool support)
+pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
+ let uri = format!("{api_url}/api/show");
+ let request = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json")
+ .body(AsyncBody::from(
+ serde_json::json!({ "model": model }).to_string(),
+ ))?;
+
+ let mut response = client.send(request).await?;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ if response.status().is_success() {
+ let details: ModelShow = serde_json::from_str(body.as_str())?;
+ Ok(details)
+ } else {
+ Err(anyhow!(
+ "Failed to connect to Ollama API: {} {}",
+ response.status(),
+ body,
+ ))
+ }
+}
+
/// Sends an empty request to Ollama to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
let uri = format!("{api_url}/api/generate");
@@ -339,12 +351,13 @@ pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
- .body(AsyncBody::from(serde_json::to_string(
- &serde_json::json!({
+ .body(AsyncBody::from(
+ serde_json::json!({
"model": model,
"keep_alive": "15m",
- }),
- )?))?;
+ })
+ .to_string(),
+ ))?;
let mut response = client.send(request).await?;
@@ -361,3 +374,161 @@ pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
))
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn parse_completion() {
+ let response = serde_json::json!({
+ "model": "llama3.2",
+ "created_at": "2023-12-12T14:13:43.416799Z",
+ "message": {
+ "role": "assistant",
+ "content": "Hello! How are you today?"
+ },
+ "done": true,
+ "total_duration": 5191566416u64,
+ "load_duration": 2154458,
+ "prompt_eval_count": 26,
+ "prompt_eval_duration": 383809000,
+ "eval_count": 298,
+ "eval_duration": 4799921000u64
+ });
+ let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
+ }
+
+ #[test]
+ fn parse_streaming_completion() {
+ let partial = serde_json::json!({
+ "model": "llama3.2",
+ "created_at": "2023-08-04T08:52:19.385406455-07:00",
+ "message": {
+ "role": "assistant",
+ "content": "The",
+ "images": null
+ },
+ "done": false
+ });
+
+ let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();
+
+ let last = serde_json::json!({
+ "model": "llama3.2",
+ "created_at": "2023-08-04T19:22:45.499127Z",
+ "message": {
+ "role": "assistant",
+ "content": ""
+ },
+ "done": true,
+ "total_duration": 4883583458u64,
+ "load_duration": 1334875,
+ "prompt_eval_count": 26,
+ "prompt_eval_duration": 342546000,
+ "eval_count": 282,
+ "eval_duration": 4535599000u64
+ });
+
+ let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
+ }
+
+ #[test]
+ fn parse_tool_call() {
+ let response = serde_json::json!({
+ "model": "llama3.2:3b",
+ "created_at": "2025-04-28T20:02:02.140489Z",
+ "message": {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {
+ "function": {
+ "name": "weather",
+ "arguments": {
+ "city": "london",
+ }
+ }
+ }
+ ]
+ },
+ "done_reason": "stop",
+ "done": true,
+ "total_duration": 2758629166u64,
+ "load_duration": 1770059875,
+ "prompt_eval_count": 147,
+ "prompt_eval_duration": 684637583,
+ "eval_count": 16,
+ "eval_duration": 302561917,
+ });
+
+ let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
+ match result.message {
+ ChatMessage::Assistant {
+ content,
+ tool_calls,
+ } => {
+ assert!(content.is_empty());
+ assert!(tool_calls.is_some_and(|v| !v.is_empty()));
+ }
+ _ => panic!("Deserialized wrong role"),
+ }
+ }
+
+ #[test]
+ fn parse_show_model() {
+ let response = serde_json::json!({
+ "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
+ "details": {
+ "parent_model": "",
+ "format": "gguf",
+ "family": "llama",
+ "families": ["llama"],
+ "parameter_size": "3.2B",
+ "quantization_level": "Q4_K_M"
+ },
+ "model_info": {
+ "general.architecture": "llama",
+ "general.basename": "Llama-3.2",
+ "general.file_type": 15,
+ "general.finetune": "Instruct",
+ "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
+ "general.parameter_count": 3212749888u64,
+ "general.quantization_version": 2,
+ "general.size_label": "3B",
+ "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
+ "general.type": "model",
+ "llama.attention.head_count": 24,
+ "llama.attention.head_count_kv": 8,
+ "llama.attention.key_length": 128,
+ "llama.attention.layer_norm_rms_epsilon": 0.00001,
+ "llama.attention.value_length": 128,
+ "llama.block_count": 28,
+ "llama.context_length": 131072,
+ "llama.embedding_length": 3072,
+ "llama.feed_forward_length": 8192,
+ "llama.rope.dimension_count": 128,
+ "llama.rope.freq_base": 500000,
+ "llama.vocab_size": 128256,
+ "tokenizer.ggml.bos_token_id": 128000,
+ "tokenizer.ggml.eos_token_id": 128009,
+ "tokenizer.ggml.merges": null,
+ "tokenizer.ggml.model": "gpt2",
+ "tokenizer.ggml.pre": "llama-bpe",
+ "tokenizer.ggml.token_type": null,
+ "tokenizer.ggml.tokens": null
+ },
+ "tensors": [
+ { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
+ { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
+ ],
+ "capabilities": ["completion", "tools"],
+ "modified_at": "2025-04-29T21:24:41.445877632+03:00"
+ });
+
+ let result: ModelShow = serde_json::from_value(response).unwrap();
+ assert!(result.supports_tools());
+ assert!(result.capabilities.contains(&"tools".to_string()));
+ assert!(result.capabilities.contains(&"completion".to_string()));
+ }
+}