From 4cb8d6f40ed1427353dcba5a10fcc4b22da1a365 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Tue, 11 Jun 2024 17:35:27 -0700 Subject: [PATCH] Ollama Provider for Assistant (#12902) Closes #4424. A few design decisions that may need some rethinking or later PRs: * Other providers have a check for authentication. I use this opportunity to fetch the models which doubles as a way of finding out if the Ollama server is running. * Ollama has _no_ API for getting the max tokens per model * Ollama has _no_ API for getting the current token count https://github.com/ollama/ollama/issues/1716 * Ollama does allow setting the `num_ctx` so I've defaulted this to 4096. It can be overridden in settings. * Ollama models will be "slow" to start inference because they're loading the model into memory. It's faster after that. There's no UI affordance to show that the model is being loaded. Release Notes: - Added an Ollama Provider for the assistant. If you have [Ollama](https://ollama.com/) running locally on your machine, you can enable it in your settings under: ```jsonc "assistant": { "version": "1", "provider": { "name": "ollama", // Recommended setting to allow for model startup "low_speed_timeout_in_seconds": 30, } } ``` Chat like usual image Interact with any model from the [Ollama Library](https://ollama.com/library) image Open up the terminal to download new models via `ollama pull`: ![image](https://github.com/zed-industries/zed/assets/836375/af7ec411-76bf-41c7-ba81-64bbaeea98a8) --- Cargo.lock | 14 + Cargo.toml | 2 + crates/assistant/Cargo.toml | 1 + crates/assistant/src/assistant.rs | 8 +- crates/assistant/src/assistant_settings.rs | 49 ++++ crates/assistant/src/completion_provider.rs | 59 +++++ .../src/completion_provider/ollama.rs | 246 ++++++++++++++++++ crates/ollama/Cargo.toml | 22 ++ crates/ollama/src/ollama.rs | 224 ++++++++++++++++ 9 files changed, 624 insertions(+), 1 deletion(-) create mode 100644 crates/assistant/src/completion_provider/ollama.rs create mode 100644 crates/ollama/Cargo.toml create mode 100644 crates/ollama/src/ollama.rs diff --git a/Cargo.lock b/Cargo.lock index c18c904f6815811338198fc0e5a7418ba7fa95c4..fee065736954d94f4bbe2367830b4e4c64403a66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -359,6 +359,7 @@ dependencies = [ "log", "menu", "multi_buffer", + "ollama", "open_ai", "ordered-float 2.10.0", "parking_lot", @@ -6921,6 +6922,19 @@ dependencies = [ "cc", ] +[[package]] +name = "ollama" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.28", + "http 0.1.0", + "isahc", + "schemars", + "serde", + "serde_json", +] + [[package]] name = "once_cell" version = "1.19.0" diff --git a/Cargo.toml b/Cargo.toml index 5f001e5d294d2c3ff8298d38f21038cc4b5ea407..79510e808e7fe0b12f8744990f000d2d0b425057 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ members = [ "crates/multi_buffer", "crates/node_runtime", "crates/notifications", + "crates/ollama", "crates/open_ai", "crates/outline", "crates/picker", @@ -207,6 +208,7 @@ menu = { path = "crates/menu" } multi_buffer = { path = "crates/multi_buffer" } node_runtime = { path = "crates/node_runtime" } notifications = { path = "crates/notifications" } +ollama = { path = "crates/ollama" } open_ai = { path = "crates/open_ai" } outline = { path = "crates/outline" } picker = { path = "crates/picker" } diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index eaa6fc5b730b912d7d4d59348ec04955c0bdc9ba..77f0bc4ae0ad3aa19d243ace4b7ffff2e57adbb1 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -35,6 +35,7 @@ language.workspace = true log.workspace = true menu.workspace = true multi_buffer.workspace = true +ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } ordered-float.workspace = true parking_lot.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index bb795b10344316bfa6c512d56556dde0365f6a34..07488fdc5b8f2f87c39e30f125cb38d2b29d5683 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -12,7 +12,7 @@ mod streaming_diff; pub use assistant_panel::AssistantPanel; -use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel}; +use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel}; use assistant_slash_command::SlashCommandRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; @@ -91,6 +91,7 @@ pub enum LanguageModel { Cloud(CloudModel), OpenAi(OpenAiModel), Anthropic(AnthropicModel), + Ollama(OllamaModel), } impl Default for LanguageModel { @@ -105,6 +106,7 @@ impl LanguageModel { LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()), LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()), + LanguageModel::Ollama(model) => format!("ollama/{}", model.id()), } } @@ -113,6 +115,7 @@ impl LanguageModel { LanguageModel::OpenAi(model) => model.display_name().into(), LanguageModel::Anthropic(model) => model.display_name().into(), LanguageModel::Cloud(model) => model.display_name().into(), + LanguageModel::Ollama(model) => model.display_name().into(), } } @@ -121,6 +124,7 @@ impl LanguageModel { LanguageModel::OpenAi(model) => model.max_token_count(), LanguageModel::Anthropic(model) => model.max_token_count(), LanguageModel::Cloud(model) => model.max_token_count(), + LanguageModel::Ollama(model) => model.max_token_count(), } } @@ -129,6 +133,7 @@ impl LanguageModel { LanguageModel::OpenAi(model) => model.id(), LanguageModel::Anthropic(model) => model.id(), LanguageModel::Cloud(model) => model.id(), + LanguageModel::Ollama(model) => model.id(), } } } @@ -179,6 +184,7 @@ impl LanguageModelRequest { match &self.model { LanguageModel::OpenAi(_) => {} LanguageModel::Anthropic(_) => {} + LanguageModel::Ollama(_) => {} LanguageModel::Cloud(model) => match model { CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => { preprocess_anthropic_request(self); diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 3efaff100d6ddfda7b7aaa5c684e3072622f9838..efc726fe2232ed329a519709e460c05591eaffde 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -2,6 +2,7 @@ use std::fmt; pub use anthropic::Model as AnthropicModel; use gpui::Pixels; +pub use ollama::Model as OllamaModel; pub use open_ai::Model as OpenAiModel; use schemars::{ schema::{InstanceType, Metadata, Schema, SchemaObject}, @@ -168,6 +169,11 @@ pub enum AssistantProvider { api_url: String, low_speed_timeout_in_seconds: Option, }, + Ollama { + model: OllamaModel, + api_url: String, + low_speed_timeout_in_seconds: Option, + }, } impl Default for AssistantProvider { @@ -197,6 +203,12 @@ pub enum AssistantProviderContent { api_url: Option, low_speed_timeout_in_seconds: Option, }, + #[serde(rename = "ollama")] + Ollama { + default_model: Option, + api_url: Option, + low_speed_timeout_in_seconds: Option, + }, } #[derive(Debug, Default)] @@ -328,6 +340,13 @@ impl AssistantSettingsContent { low_speed_timeout_in_seconds: None, }) } + LanguageModel::Ollama(model) => { + *provider = Some(AssistantProviderContent::Ollama { + default_model: Some(model), + api_url: None, + low_speed_timeout_in_seconds: None, + }) + } }, }, }, @@ -472,6 +491,27 @@ impl Settings for AssistantSettings { Some(low_speed_timeout_in_seconds_override); } } + ( + AssistantProvider::Ollama { + model, + api_url, + low_speed_timeout_in_seconds, + }, + AssistantProviderContent::Ollama { + default_model: model_override, + api_url: api_url_override, + low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override, + }, + ) => { + merge(model, model_override); + merge(api_url, api_url_override); + if let Some(low_speed_timeout_in_seconds_override) = + low_speed_timeout_in_seconds_override + { + *low_speed_timeout_in_seconds = + Some(low_speed_timeout_in_seconds_override); + } + } ( AssistantProvider::Anthropic { model, @@ -519,6 +559,15 @@ impl Settings for AssistantSettings { .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()), low_speed_timeout_in_seconds, }, + AssistantProviderContent::Ollama { + default_model: model, + api_url, + low_speed_timeout_in_seconds, + } => AssistantProvider::Ollama { + model: model.unwrap_or_default(), + api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()), + low_speed_timeout_in_seconds, + }, }; } } diff --git a/crates/assistant/src/completion_provider.rs b/crates/assistant/src/completion_provider.rs index 01ea6325ad3e39ab675054a212707480d2449904..78b22556aca7c306bc326770daa0553fed646bf4 100644 --- a/crates/assistant/src/completion_provider.rs +++ b/crates/assistant/src/completion_provider.rs @@ -2,12 +2,14 @@ mod anthropic; mod cloud; #[cfg(test)] mod fake; +mod ollama; mod open_ai; pub use anthropic::*; pub use cloud::*; #[cfg(test)] pub use fake::*; +pub use ollama::*; pub use open_ai::*; use crate::{ @@ -50,6 +52,17 @@ pub fn init(client: Arc, cx: &mut AppContext) { low_speed_timeout_in_seconds.map(Duration::from_secs), settings_version, )), + AssistantProvider::Ollama { + model, + api_url, + low_speed_timeout_in_seconds, + } => CompletionProvider::Ollama(OllamaCompletionProvider::new( + model.clone(), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + )), }; cx.set_global(provider); @@ -87,6 +100,23 @@ pub fn init(client: Arc, cx: &mut AppContext) { settings_version, ); } + + ( + CompletionProvider::Ollama(provider), + AssistantProvider::Ollama { + model, + api_url, + low_speed_timeout_in_seconds, + }, + ) => { + provider.update( + model.clone(), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + ); + } + (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => { provider.update(model.clone(), settings_version); } @@ -130,6 +160,22 @@ pub fn init(client: Arc, cx: &mut AppContext) { settings_version, )); } + ( + _, + AssistantProvider::Ollama { + model, + api_url, + low_speed_timeout_in_seconds, + }, + ) => { + *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new( + model.clone(), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + )); + } } }) }) @@ -142,6 +188,7 @@ pub enum CompletionProvider { Cloud(CloudCompletionProvider), #[cfg(test)] Fake(FakeCompletionProvider), + Ollama(OllamaCompletionProvider), } impl gpui::Global for CompletionProvider {} @@ -165,6 +212,10 @@ impl CompletionProvider { .available_models() .map(LanguageModel::Cloud) .collect(), + CompletionProvider::Ollama(provider) => provider + .available_models() + .map(|model| LanguageModel::Ollama(model.clone())) + .collect(), #[cfg(test)] CompletionProvider::Fake(_) => unimplemented!(), } @@ -175,6 +226,7 @@ impl CompletionProvider { CompletionProvider::OpenAi(provider) => provider.settings_version(), CompletionProvider::Anthropic(provider) => provider.settings_version(), CompletionProvider::Cloud(provider) => provider.settings_version(), + CompletionProvider::Ollama(provider) => provider.settings_version(), #[cfg(test)] CompletionProvider::Fake(_) => unimplemented!(), } @@ -185,6 +237,7 @@ impl CompletionProvider { CompletionProvider::OpenAi(provider) => provider.is_authenticated(), CompletionProvider::Anthropic(provider) => provider.is_authenticated(), CompletionProvider::Cloud(provider) => provider.is_authenticated(), + CompletionProvider::Ollama(provider) => provider.is_authenticated(), #[cfg(test)] CompletionProvider::Fake(_) => true, } @@ -195,6 +248,7 @@ impl CompletionProvider { CompletionProvider::OpenAi(provider) => provider.authenticate(cx), CompletionProvider::Anthropic(provider) => provider.authenticate(cx), CompletionProvider::Cloud(provider) => provider.authenticate(cx), + CompletionProvider::Ollama(provider) => provider.authenticate(cx), #[cfg(test)] CompletionProvider::Fake(_) => Task::ready(Ok(())), } @@ -205,6 +259,7 @@ impl CompletionProvider { CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx), CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx), CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx), + CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx), #[cfg(test)] CompletionProvider::Fake(_) => unimplemented!(), } @@ -215,6 +270,7 @@ impl CompletionProvider { CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx), CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx), CompletionProvider::Cloud(_) => Task::ready(Ok(())), + CompletionProvider::Ollama(provider) => provider.reset_credentials(cx), #[cfg(test)] CompletionProvider::Fake(_) => Task::ready(Ok(())), } @@ -225,6 +281,7 @@ impl CompletionProvider { CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()), CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()), CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()), + CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()), #[cfg(test)] CompletionProvider::Fake(_) => LanguageModel::default(), } @@ -239,6 +296,7 @@ impl CompletionProvider { CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx), CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx), CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx), + CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx), #[cfg(test)] CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))), } @@ -252,6 +310,7 @@ impl CompletionProvider { CompletionProvider::OpenAi(provider) => provider.complete(request), CompletionProvider::Anthropic(provider) => provider.complete(request), CompletionProvider::Cloud(provider) => provider.complete(request), + CompletionProvider::Ollama(provider) => provider.complete(request), #[cfg(test)] CompletionProvider::Fake(provider) => provider.complete(), } diff --git a/crates/assistant/src/completion_provider/ollama.rs b/crates/assistant/src/completion_provider/ollama.rs new file mode 100644 index 0000000000000000000000000000000000000000..74524da6dd000b037eb6439a1d762c646c4cd748 --- /dev/null +++ b/crates/assistant/src/completion_provider/ollama.rs @@ -0,0 +1,246 @@ +use crate::{ + assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role, +}; +use anyhow::Result; +use futures::StreamExt as _; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt}; +use gpui::{AnyView, AppContext, Task}; +use http::HttpClient; +use ollama::{ + get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, Role as OllamaRole, +}; +use std::sync::Arc; +use std::time::Duration; +use ui::{prelude::*, ButtonLike, ElevationIndex}; + +const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download"; + +pub struct OllamaCompletionProvider { + api_url: String, + model: OllamaModel, + http_client: Arc, + low_speed_timeout: Option, + settings_version: usize, + available_models: Vec, +} + +impl OllamaCompletionProvider { + pub fn new( + model: OllamaModel, + api_url: String, + http_client: Arc, + low_speed_timeout: Option, + settings_version: usize, + ) -> Self { + Self { + api_url, + model, + http_client, + low_speed_timeout, + settings_version, + available_models: Default::default(), + } + } + + pub fn update( + &mut self, + model: OllamaModel, + api_url: String, + low_speed_timeout: Option, + settings_version: usize, + ) { + self.model = model; + self.api_url = api_url; + self.low_speed_timeout = low_speed_timeout; + self.settings_version = settings_version; + } + + pub fn available_models(&self) -> impl Iterator { + self.available_models.iter() + } + + pub fn settings_version(&self) -> usize { + self.settings_version + } + + pub fn is_authenticated(&self) -> bool { + !self.available_models.is_empty() + } + + pub fn authenticate(&self, cx: &AppContext) -> Task> { + if self.is_authenticated() { + Task::ready(Ok(())) + } else { + self.fetch_models(cx) + } + } + + pub fn reset_credentials(&self, cx: &AppContext) -> Task> { + self.fetch_models(cx) + } + + pub fn fetch_models(&self, cx: &AppContext) -> Task> { + let http_client = self.http_client.clone(); + let api_url = self.api_url.clone(); + + // As a proxy for the server being "authenticated", we'll check if its up by fetching the models + cx.spawn(|mut cx| async move { + let models = get_models(http_client.as_ref(), &api_url, None).await?; + + let mut models: Vec = models + .into_iter() + // Since there is no metadata from the Ollama API + // indicating which models are embedding models, + // simply filter out models with "-embed" in their name + .filter(|model| !model.name.contains("-embed")) + .map(|model| OllamaModel::new(&model.name, &model.details.parameter_size)) + .collect(); + + models.sort_by(|a, b| a.name.cmp(&b.name)); + + cx.update_global::(|provider, _cx| { + if let CompletionProvider::Ollama(provider) = provider { + provider.available_models = models; + } + }) + }) + } + + pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|cx| DownloadOllamaMessage::new(cx)).into() + } + + pub fn model(&self) -> OllamaModel { + self.model.clone() + } + + pub fn count_tokens( + &self, + request: LanguageModelRequest, + _cx: &AppContext, + ) -> BoxFuture<'static, Result> { + // There is no endpoint for this _yet_ in Ollama + // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582 + let token_count = request + .messages + .iter() + .map(|msg| msg.content.chars().count()) + .sum::() + / 4; + + async move { Ok(token_count) }.boxed() + } + + pub fn complete( + &self, + request: LanguageModelRequest, + ) -> BoxFuture<'static, Result>>> { + let request = self.to_ollama_request(request); + + let http_client = self.http_client.clone(); + let api_url = self.api_url.clone(); + let low_speed_timeout = self.low_speed_timeout; + async move { + let request = + stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout); + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(delta) => { + let content = match delta.message { + ChatMessage::User { content } => content, + ChatMessage::Assistant { content } => content, + ChatMessage::System { content } => content, + }; + Some(Ok(content)) + } + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } + + fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest { + let model = match request.model { + LanguageModel::Ollama(model) => model, + _ => self.model(), + }; + + ChatRequest { + model: model.name, + messages: request + .messages + .into_iter() + .map(|msg| match msg.role { + Role::User => ChatMessage::User { + content: msg.content, + }, + Role::Assistant => ChatMessage::Assistant { + content: msg.content, + }, + Role::System => ChatMessage::System { + content: msg.content, + }, + }) + .collect(), + keep_alive: model.keep_alive, + stream: true, + options: Some(ChatOptions { + num_ctx: Some(model.max_tokens), + stop: Some(request.stop), + temperature: Some(request.temperature), + ..Default::default() + }), + } + } +} + +impl From for ollama::Role { + fn from(val: Role) -> Self { + match val { + Role::User => OllamaRole::User, + Role::Assistant => OllamaRole::Assistant, + Role::System => OllamaRole::System, + } + } +} + +struct DownloadOllamaMessage {} + +impl DownloadOllamaMessage { + pub fn new(_cx: &mut ViewContext) -> Self { + Self {} + } + + fn render_download_button(&self, _cx: &mut ViewContext) -> impl IntoElement { + ButtonLike::new("download_ollama_button") + .style(ButtonStyle::Filled) + .size(ButtonSize::Large) + .layer(ElevationIndex::ModalSurface) + .child(Label::new("Get Ollama")) + .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL)) + } +} + +impl Render for DownloadOllamaMessage { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + v_flex() + .p_4() + .size_full() + .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine.").size(LabelSize::Large)) + .child( + h_flex() + .w_full() + .p_4() + .justify_center() + .child( + self.render_download_button(cx) + ) + ) + .into_any() + } +} diff --git a/crates/ollama/Cargo.toml b/crates/ollama/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..2ff329df00305c301892a67d4a1a016b96ee7930 --- /dev/null +++ b/crates/ollama/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "ollama" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lib] +path = "src/ollama.rs" + +[features] +default = [] +schemars = ["dep:schemars"] + +[dependencies] +anyhow.workspace = true +futures.workspace = true +http.workspace = true +isahc.workspace = true +schemars = { workspace = true, optional = true } +serde.workspace = true +serde_json.workspace = true diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs new file mode 100644 index 0000000000000000000000000000000000000000..141d7fe000a88942fbdde6a210edd8694f83554c --- /dev/null +++ b/crates/ollama/src/ollama.rs @@ -0,0 +1,224 @@ +use anyhow::{anyhow, Context, Result}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use isahc::config::Configurable; +use serde::{Deserialize, Serialize}; +use std::{convert::TryFrom, time::Duration}; + +pub const OLLAMA_API_URL: &str = "http://localhost:11434"; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +impl TryFrom for Role { + type Error = anyhow::Error; + + fn try_from(value: String) -> Result { + match value.as_str() { + "user" => Ok(Self::User), + "assistant" => Ok(Self::Assistant), + "system" => Ok(Self::System), + _ => Err(anyhow!("invalid role '{value}'")), + } + } +} + +impl From for String { + fn from(val: Role) -> Self { + match val { + Role::User => "user".to_owned(), + Role::Assistant => "assistant".to_owned(), + Role::System => "system".to_owned(), + } + } +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct Model { + pub name: String, + pub parameter_size: String, + pub max_tokens: usize, + pub keep_alive: Option, +} + +impl Model { + pub fn new(name: &str, parameter_size: &str) -> Self { + Self { + name: name.to_owned(), + parameter_size: parameter_size.to_owned(), + // todo: determine if there's an endpoint to find the max tokens + // I'm not seeing it in the API docs but it's on the model cards + max_tokens: 2048, + keep_alive: Some("10m".to_owned()), + } + } + + pub fn id(&self) -> &str { + &self.name + } + + pub fn display_name(&self) -> &str { + &self.name + } + + pub fn max_token_count(&self) -> usize { + self.max_tokens + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "role", rename_all = "lowercase")] +pub enum ChatMessage { + Assistant { content: String }, + User { content: String }, + System { content: String }, +} + +#[derive(Serialize)] +pub struct ChatRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + pub keep_alive: Option, + pub options: Option, +} + +// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values +#[derive(Serialize, Default)] +pub struct ChatOptions { + pub num_ctx: Option, + pub num_predict: Option, + pub stop: Option>, + pub temperature: Option, + pub top_p: Option, +} + +#[derive(Deserialize)] +pub struct ChatResponseDelta { + #[allow(unused)] + pub model: String, + #[allow(unused)] + pub created_at: String, + pub message: ChatMessage, + #[allow(unused)] + pub done_reason: Option, + #[allow(unused)] + pub done: bool, +} + +#[derive(Serialize, Deserialize)] +pub struct LocalModelsResponse { + pub models: Vec, +} + +#[derive(Serialize, Deserialize)] +pub struct LocalModelListing { + pub name: String, + pub modified_at: String, + pub size: u64, + pub digest: String, + pub details: ModelDetails, +} + +#[derive(Serialize, Deserialize)] +pub struct LocalModel { + pub modelfile: String, + pub parameters: String, + pub template: String, + pub details: ModelDetails, +} + +#[derive(Serialize, Deserialize)] +pub struct ModelDetails { + pub format: String, + pub family: String, + pub families: Option>, + pub parameter_size: String, + pub quantization_level: String, +} + +pub async fn stream_chat_completion( + client: &dyn HttpClient, + api_url: &str, + request: ChatRequest, + low_speed_timeout: Option, +) -> Result>> { + let uri = format!("{api_url}/api/chat"); + let mut request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json"); + + if let Some(low_speed_timeout) = low_speed_timeout { + request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); + }; + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let mut response = client.send(request).await?; + if response.status().is_success() { + let reader = BufReader::new(response.into_body()); + + Ok(reader + .lines() + .filter_map(|line| async move { + match line { + Ok(line) => { + Some(serde_json::from_str(&line).context("Unable to parse chat response")) + } + Err(e) => Some(Err(e.into())), + } + }) + .boxed()) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Err(anyhow!( + "Failed to connect to Ollama API: {} {}", + response.status(), + body, + )) + } +} + +pub async fn get_models( + client: &dyn HttpClient, + api_url: &str, + low_speed_timeout: Option, +) -> Result> { + let uri = format!("{api_url}/api/tags"); + let mut request_builder = HttpRequest::builder() + .method(Method::GET) + .uri(uri) + .header("Accept", "application/json"); + + if let Some(low_speed_timeout) = low_speed_timeout { + request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); + }; + + let request = request_builder.body(AsyncBody::default())?; + + let mut response = client.send(request).await?; + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + if response.status().is_success() { + let response: LocalModelsResponse = + serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?; + + Ok(response.models) + } else { + Err(anyhow!( + "Failed to connect to Ollama API: {} {}", + response.status(), + body, + )) + } +}