Detailed changes
@@ -359,6 +359,7 @@ dependencies = [
"log",
"menu",
"multi_buffer",
+ "ollama",
"open_ai",
"ordered-float 2.10.0",
"parking_lot",
@@ -6921,6 +6922,19 @@ dependencies = [
"cc",
]
+[[package]]
+name = "ollama"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.28",
+ "http 0.1.0",
+ "isahc",
+ "schemars",
+ "serde",
+ "serde_json",
+]
+
[[package]]
name = "once_cell"
version = "1.19.0"
@@ -61,6 +61,7 @@ members = [
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",
+ "crates/ollama",
"crates/open_ai",
"crates/outline",
"crates/picker",
@@ -207,6 +208,7 @@ menu = { path = "crates/menu" }
multi_buffer = { path = "crates/multi_buffer" }
node_runtime = { path = "crates/node_runtime" }
notifications = { path = "crates/notifications" }
+ollama = { path = "crates/ollama" }
open_ai = { path = "crates/open_ai" }
outline = { path = "crates/outline" }
picker = { path = "crates/picker" }
@@ -35,6 +35,7 @@ language.workspace = true
log.workspace = true
menu.workspace = true
multi_buffer.workspace = true
+ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
ordered-float.workspace = true
parking_lot.workspace = true
@@ -12,7 +12,7 @@ mod streaming_diff;
pub use assistant_panel::AssistantPanel;
-use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel};
+use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
@@ -91,6 +91,7 @@ pub enum LanguageModel {
Cloud(CloudModel),
OpenAi(OpenAiModel),
Anthropic(AnthropicModel),
+ Ollama(OllamaModel),
}
impl Default for LanguageModel {
@@ -105,6 +106,7 @@ impl LanguageModel {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
+ LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
}
}
@@ -113,6 +115,7 @@ impl LanguageModel {
LanguageModel::OpenAi(model) => model.display_name().into(),
LanguageModel::Anthropic(model) => model.display_name().into(),
LanguageModel::Cloud(model) => model.display_name().into(),
+ LanguageModel::Ollama(model) => model.display_name().into(),
}
}
@@ -121,6 +124,7 @@ impl LanguageModel {
LanguageModel::OpenAi(model) => model.max_token_count(),
LanguageModel::Anthropic(model) => model.max_token_count(),
LanguageModel::Cloud(model) => model.max_token_count(),
+ LanguageModel::Ollama(model) => model.max_token_count(),
}
}
@@ -129,6 +133,7 @@ impl LanguageModel {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::Anthropic(model) => model.id(),
LanguageModel::Cloud(model) => model.id(),
+ LanguageModel::Ollama(model) => model.id(),
}
}
}
@@ -179,6 +184,7 @@ impl LanguageModelRequest {
match &self.model {
LanguageModel::OpenAi(_) => {}
LanguageModel::Anthropic(_) => {}
+ LanguageModel::Ollama(_) => {}
LanguageModel::Cloud(model) => match model {
CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
preprocess_anthropic_request(self);
@@ -2,6 +2,7 @@ use std::fmt;
pub use anthropic::Model as AnthropicModel;
use gpui::Pixels;
+pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
@@ -168,6 +169,11 @@ pub enum AssistantProvider {
api_url: String,
low_speed_timeout_in_seconds: Option<u64>,
},
+ Ollama {
+ model: OllamaModel,
+ api_url: String,
+ low_speed_timeout_in_seconds: Option<u64>,
+ },
}
impl Default for AssistantProvider {
@@ -197,6 +203,12 @@ pub enum AssistantProviderContent {
api_url: Option<String>,
low_speed_timeout_in_seconds: Option<u64>,
},
+ #[serde(rename = "ollama")]
+ Ollama {
+ default_model: Option<OllamaModel>,
+ api_url: Option<String>,
+ low_speed_timeout_in_seconds: Option<u64>,
+ },
}
#[derive(Debug, Default)]
@@ -328,6 +340,13 @@ impl AssistantSettingsContent {
low_speed_timeout_in_seconds: None,
})
}
+ LanguageModel::Ollama(model) => {
+ *provider = Some(AssistantProviderContent::Ollama {
+ default_model: Some(model),
+ api_url: None,
+ low_speed_timeout_in_seconds: None,
+ })
+ }
},
},
},
@@ -472,6 +491,27 @@ impl Settings for AssistantSettings {
Some(low_speed_timeout_in_seconds_override);
}
}
+ (
+ AssistantProvider::Ollama {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ },
+ AssistantProviderContent::Ollama {
+ default_model: model_override,
+ api_url: api_url_override,
+ low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
+ },
+ ) => {
+ merge(model, model_override);
+ merge(api_url, api_url_override);
+ if let Some(low_speed_timeout_in_seconds_override) =
+ low_speed_timeout_in_seconds_override
+ {
+ *low_speed_timeout_in_seconds =
+ Some(low_speed_timeout_in_seconds_override);
+ }
+ }
(
AssistantProvider::Anthropic {
model,
@@ -519,6 +559,15 @@ impl Settings for AssistantSettings {
.unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
low_speed_timeout_in_seconds,
},
+ AssistantProviderContent::Ollama {
+ default_model: model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => AssistantProvider::Ollama {
+ model: model.unwrap_or_default(),
+ api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
+ low_speed_timeout_in_seconds,
+ },
};
}
}
@@ -2,12 +2,14 @@ mod anthropic;
mod cloud;
#[cfg(test)]
mod fake;
+mod ollama;
mod open_ai;
pub use anthropic::*;
pub use cloud::*;
#[cfg(test)]
pub use fake::*;
+pub use ollama::*;
pub use open_ai::*;
use crate::{
@@ -50,6 +52,17 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
)),
+ AssistantProvider::Ollama {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
+ model.clone(),
+ api_url.clone(),
+ client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ )),
};
cx.set_global(provider);
@@ -87,6 +100,23 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings_version,
);
}
+
+ (
+ CompletionProvider::Ollama(provider),
+ AssistantProvider::Ollama {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ },
+ ) => {
+ provider.update(
+ model.clone(),
+ api_url.clone(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ );
+ }
+
(CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
provider.update(model.clone(), settings_version);
}
@@ -130,6 +160,22 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings_version,
));
}
+ (
+ _,
+ AssistantProvider::Ollama {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ },
+ ) => {
+ *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
+ model.clone(),
+ api_url.clone(),
+ client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ ));
+ }
}
})
})
@@ -142,6 +188,7 @@ pub enum CompletionProvider {
Cloud(CloudCompletionProvider),
#[cfg(test)]
Fake(FakeCompletionProvider),
+ Ollama(OllamaCompletionProvider),
}
impl gpui::Global for CompletionProvider {}
@@ -165,6 +212,10 @@ impl CompletionProvider {
.available_models()
.map(LanguageModel::Cloud)
.collect(),
+ CompletionProvider::Ollama(provider) => provider
+ .available_models()
+ .map(|model| LanguageModel::Ollama(model.clone()))
+ .collect(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
@@ -175,6 +226,7 @@ impl CompletionProvider {
CompletionProvider::OpenAi(provider) => provider.settings_version(),
CompletionProvider::Anthropic(provider) => provider.settings_version(),
CompletionProvider::Cloud(provider) => provider.settings_version(),
+ CompletionProvider::Ollama(provider) => provider.settings_version(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
@@ -185,6 +237,7 @@ impl CompletionProvider {
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
CompletionProvider::Cloud(provider) => provider.is_authenticated(),
+ CompletionProvider::Ollama(provider) => provider.is_authenticated(),
#[cfg(test)]
CompletionProvider::Fake(_) => true,
}
@@ -195,6 +248,7 @@ impl CompletionProvider {
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
CompletionProvider::Cloud(provider) => provider.authenticate(cx),
+ CompletionProvider::Ollama(provider) => provider.authenticate(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
}
@@ -205,6 +259,7 @@ impl CompletionProvider {
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
+ CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
@@ -215,6 +270,7 @@ impl CompletionProvider {
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
CompletionProvider::Cloud(_) => Task::ready(Ok(())),
+ CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
}
@@ -225,6 +281,7 @@ impl CompletionProvider {
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
+ CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
#[cfg(test)]
CompletionProvider::Fake(_) => LanguageModel::default(),
}
@@ -239,6 +296,7 @@ impl CompletionProvider {
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
+ CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
#[cfg(test)]
CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
}
@@ -252,6 +310,7 @@ impl CompletionProvider {
CompletionProvider::OpenAi(provider) => provider.complete(request),
CompletionProvider::Anthropic(provider) => provider.complete(request),
CompletionProvider::Cloud(provider) => provider.complete(request),
+ CompletionProvider::Ollama(provider) => provider.complete(request),
#[cfg(test)]
CompletionProvider::Fake(provider) => provider.complete(),
}
@@ -0,0 +1,246 @@
+use crate::{
+ assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
+};
+use anyhow::Result;
+use futures::StreamExt as _;
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
+use gpui::{AnyView, AppContext, Task};
+use http::HttpClient;
+use ollama::{
+ get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, Role as OllamaRole,
+};
+use std::sync::Arc;
+use std::time::Duration;
+use ui::{prelude::*, ButtonLike, ElevationIndex};
+
+const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
+
+pub struct OllamaCompletionProvider {
+ api_url: String,
+ model: OllamaModel,
+ http_client: Arc<dyn HttpClient>,
+ low_speed_timeout: Option<Duration>,
+ settings_version: usize,
+ available_models: Vec<OllamaModel>,
+}
+
+impl OllamaCompletionProvider {
+ pub fn new(
+ model: OllamaModel,
+ api_url: String,
+ http_client: Arc<dyn HttpClient>,
+ low_speed_timeout: Option<Duration>,
+ settings_version: usize,
+ ) -> Self {
+ Self {
+ api_url,
+ model,
+ http_client,
+ low_speed_timeout,
+ settings_version,
+ available_models: Default::default(),
+ }
+ }
+
+ pub fn update(
+ &mut self,
+ model: OllamaModel,
+ api_url: String,
+ low_speed_timeout: Option<Duration>,
+ settings_version: usize,
+ ) {
+ self.model = model;
+ self.api_url = api_url;
+ self.low_speed_timeout = low_speed_timeout;
+ self.settings_version = settings_version;
+ }
+
+ pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
+ self.available_models.iter()
+ }
+
+ pub fn settings_version(&self) -> usize {
+ self.settings_version
+ }
+
+ pub fn is_authenticated(&self) -> bool {
+ !self.available_models.is_empty()
+ }
+
+ pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+ if self.is_authenticated() {
+ Task::ready(Ok(()))
+ } else {
+ self.fetch_models(cx)
+ }
+ }
+
+ pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+ self.fetch_models(cx)
+ }
+
+ pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
+ let http_client = self.http_client.clone();
+ let api_url = self.api_url.clone();
+
+        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the available models
+ cx.spawn(|mut cx| async move {
+ let models = get_models(http_client.as_ref(), &api_url, None).await?;
+
+ let mut models: Vec<OllamaModel> = models
+ .into_iter()
+ // Since there is no metadata from the Ollama API
+ // indicating which models are embedding models,
+ // simply filter out models with "-embed" in their name
+ .filter(|model| !model.name.contains("-embed"))
+ .map(|model| OllamaModel::new(&model.name, &model.details.parameter_size))
+ .collect();
+
+ models.sort_by(|a, b| a.name.cmp(&b.name));
+
+ cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+ if let CompletionProvider::Ollama(provider) = provider {
+ provider.available_models = models;
+ }
+ })
+ })
+ }
+
+ pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+ cx.new_view(|cx| DownloadOllamaMessage::new(cx)).into()
+ }
+
+ pub fn model(&self) -> OllamaModel {
+ self.model.clone()
+ }
+
+ pub fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ _cx: &AppContext,
+ ) -> BoxFuture<'static, Result<usize>> {
+ // There is no endpoint for this _yet_ in Ollama
+ // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
+ let token_count = request
+ .messages
+ .iter()
+ .map(|msg| msg.content.chars().count())
+ .sum::<usize>()
+ / 4;
+
+ async move { Ok(token_count) }.boxed()
+ }
+
+ pub fn complete(
+ &self,
+ request: LanguageModelRequest,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let request = self.to_ollama_request(request);
+
+ let http_client = self.http_client.clone();
+ let api_url = self.api_url.clone();
+ let low_speed_timeout = self.low_speed_timeout;
+ async move {
+ let request =
+ stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
+ let response = request.await?;
+ let stream = response
+ .filter_map(|response| async move {
+ match response {
+ Ok(delta) => {
+ let content = match delta.message {
+ ChatMessage::User { content } => content,
+ ChatMessage::Assistant { content } => content,
+ ChatMessage::System { content } => content,
+ };
+ Some(Ok(content))
+ }
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed();
+ Ok(stream)
+ }
+ .boxed()
+ }
+
+ fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
+ let model = match request.model {
+ LanguageModel::Ollama(model) => model,
+ _ => self.model(),
+ };
+
+ ChatRequest {
+ model: model.name,
+ messages: request
+ .messages
+ .into_iter()
+ .map(|msg| match msg.role {
+ Role::User => ChatMessage::User {
+ content: msg.content,
+ },
+ Role::Assistant => ChatMessage::Assistant {
+ content: msg.content,
+ },
+ Role::System => ChatMessage::System {
+ content: msg.content,
+ },
+ })
+ .collect(),
+ keep_alive: model.keep_alive,
+ stream: true,
+ options: Some(ChatOptions {
+ num_ctx: Some(model.max_tokens),
+ stop: Some(request.stop),
+ temperature: Some(request.temperature),
+ ..Default::default()
+ }),
+ }
+ }
+}
+
+impl From<Role> for ollama::Role {
+ fn from(val: Role) -> Self {
+ match val {
+ Role::User => OllamaRole::User,
+ Role::Assistant => OllamaRole::Assistant,
+ Role::System => OllamaRole::System,
+ }
+ }
+}
+
+struct DownloadOllamaMessage {}
+
+impl DownloadOllamaMessage {
+ pub fn new(_cx: &mut ViewContext<Self>) -> Self {
+ Self {}
+ }
+
+ fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
+ ButtonLike::new("download_ollama_button")
+ .style(ButtonStyle::Filled)
+ .size(ButtonSize::Large)
+ .layer(ElevationIndex::ModalSurface)
+ .child(Label::new("Get Ollama"))
+ .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
+ }
+}
+
+impl Render for DownloadOllamaMessage {
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ v_flex()
+ .p_4()
+ .size_full()
+ .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine.").size(LabelSize::Large))
+ .child(
+ h_flex()
+ .w_full()
+ .p_4()
+ .justify_center()
+ .child(
+ self.render_download_button(cx)
+ )
+ )
+ .into_any()
+ }
+}
@@ -0,0 +1,22 @@
+[package]
+name = "ollama"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lib]
+path = "src/ollama.rs"
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+http.workspace = true
+isahc.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
@@ -0,0 +1,224 @@
+use anyhow::{anyhow, Context, Result};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use isahc::config::Configurable;
+use serde::{Deserialize, Serialize};
+use std::{convert::TryFrom, time::Duration};
+
+pub const OLLAMA_API_URL: &str = "http://localhost:11434";
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+ System,
+}
+
+impl TryFrom<String> for Role {
+ type Error = anyhow::Error;
+
+ fn try_from(value: String) -> Result<Self> {
+ match value.as_str() {
+ "user" => Ok(Self::User),
+ "assistant" => Ok(Self::Assistant),
+ "system" => Ok(Self::System),
+ _ => Err(anyhow!("invalid role '{value}'")),
+ }
+ }
+}
+
+impl From<Role> for String {
+ fn from(val: Role) -> Self {
+ match val {
+ Role::User => "user".to_owned(),
+ Role::Assistant => "assistant".to_owned(),
+ Role::System => "system".to_owned(),
+ }
+ }
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct Model {
+ pub name: String,
+ pub parameter_size: String,
+ pub max_tokens: usize,
+ pub keep_alive: Option<String>,
+}
+
+impl Model {
+ pub fn new(name: &str, parameter_size: &str) -> Self {
+ Self {
+ name: name.to_owned(),
+ parameter_size: parameter_size.to_owned(),
+            // TODO: determine whether there's an endpoint to find the max tokens.
+            // It's not in the API docs, but it is listed on the model cards.
+ max_tokens: 2048,
+ keep_alive: Some("10m".to_owned()),
+ }
+ }
+
+ pub fn id(&self) -> &str {
+ &self.name
+ }
+
+ pub fn display_name(&self) -> &str {
+ &self.name
+ }
+
+ pub fn max_token_count(&self) -> usize {
+ self.max_tokens
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum ChatMessage {
+ Assistant { content: String },
+ User { content: String },
+ System { content: String },
+}
+
+#[derive(Serialize)]
+pub struct ChatRequest {
+ pub model: String,
+ pub messages: Vec<ChatMessage>,
+ pub stream: bool,
+ pub keep_alive: Option<String>,
+ pub options: Option<ChatOptions>,
+}
+
+// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+#[derive(Serialize, Default)]
+pub struct ChatOptions {
+ pub num_ctx: Option<usize>,
+ pub num_predict: Option<isize>,
+ pub stop: Option<Vec<String>>,
+ pub temperature: Option<f32>,
+ pub top_p: Option<f32>,
+}
+
+#[derive(Deserialize)]
+pub struct ChatResponseDelta {
+ #[allow(unused)]
+ pub model: String,
+ #[allow(unused)]
+ pub created_at: String,
+ pub message: ChatMessage,
+ #[allow(unused)]
+ pub done_reason: Option<String>,
+ #[allow(unused)]
+ pub done: bool,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct LocalModelsResponse {
+ pub models: Vec<LocalModelListing>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct LocalModelListing {
+ pub name: String,
+ pub modified_at: String,
+ pub size: u64,
+ pub digest: String,
+ pub details: ModelDetails,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct LocalModel {
+ pub modelfile: String,
+ pub parameters: String,
+ pub template: String,
+ pub details: ModelDetails,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct ModelDetails {
+ pub format: String,
+ pub family: String,
+ pub families: Option<Vec<String>>,
+ pub parameter_size: String,
+ pub quantization_level: String,
+}
+
+pub async fn stream_chat_completion(
+ client: &dyn HttpClient,
+ api_url: &str,
+ request: ChatRequest,
+ low_speed_timeout: Option<Duration>,
+) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
+ let uri = format!("{api_url}/api/chat");
+ let mut request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json");
+
+ if let Some(low_speed_timeout) = low_speed_timeout {
+ request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+ };
+
+ let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ let mut response = client.send(request).await?;
+ if response.status().is_success() {
+ let reader = BufReader::new(response.into_body());
+
+ Ok(reader
+ .lines()
+ .filter_map(|line| async move {
+ match line {
+ Ok(line) => {
+ Some(serde_json::from_str(&line).context("Unable to parse chat response"))
+ }
+ Err(e) => Some(Err(e.into())),
+ }
+ })
+ .boxed())
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ Err(anyhow!(
+ "Failed to connect to Ollama API: {} {}",
+ response.status(),
+ body,
+ ))
+ }
+}
+
+pub async fn get_models(
+ client: &dyn HttpClient,
+ api_url: &str,
+ low_speed_timeout: Option<Duration>,
+) -> Result<Vec<LocalModelListing>> {
+ let uri = format!("{api_url}/api/tags");
+ let mut request_builder = HttpRequest::builder()
+ .method(Method::GET)
+ .uri(uri)
+ .header("Accept", "application/json");
+
+ if let Some(low_speed_timeout) = low_speed_timeout {
+ request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+ };
+
+ let request = request_builder.body(AsyncBody::default())?;
+
+ let mut response = client.send(request).await?;
+
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ if response.status().is_success() {
+ let response: LocalModelsResponse =
+ serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
+
+ Ok(response.models)
+ } else {
+ Err(anyhow!(
+ "Failed to connect to Ollama API: {} {}",
+ response.status(),
+ body,
+ ))
+ }
+}
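
For reference, here is a minimal sketch of how the new `ollama` crate's API fits together outside of the assistant wiring above. It assumes the caller already has a `&dyn HttpClient` from the workspace `http` crate and a running Ollama server at the default address; the `demo` function and the `"llama3"` fallback model name are illustrative, not part of the change.

```rust
use anyhow::Result;
use futures::StreamExt;
use http::HttpClient;
use ollama::{get_models, stream_chat_completion, ChatMessage, ChatRequest, OLLAMA_API_URL};

// Hypothetical driver: list local models, then stream one chat completion.
async fn demo(client: &dyn HttpClient) -> Result<()> {
    // GET /api/tags: the same call the provider uses as its "is Ollama up?" check.
    let models = get_models(client, OLLAMA_API_URL, None).await?;
    let model = models
        .first()
        .map(|m| m.name.clone())
        .unwrap_or_else(|| "llama3".to_string()); // placeholder fallback

    // POST /api/chat with streaming enabled.
    let request = ChatRequest {
        model,
        messages: vec![ChatMessage::User {
            content: "Hello!".to_string(),
        }],
        stream: true,
        keep_alive: Some("10m".to_string()),
        options: None,
    };
    let mut stream = stream_chat_completion(client, OLLAMA_API_URL, request, None).await?;

    // Each line of the response body deserializes into a ChatResponseDelta.
    while let Some(delta) = stream.next().await {
        if let ChatMessage::Assistant { content } = delta?.message {
            print!("{content}");
        }
    }
    Ok(())
}
```

As in `OllamaCompletionProvider` above, the model listing doubles as the health/"authentication" check, and per-request options such as `num_ctx`, `stop`, and `temperature` are passed via `ChatOptions`.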