Ollama improvements (#12921)

Created by Kyle Kelley

Attempt to load the model early when the user switches models.

This is a follow-up to #12902.
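For reference, the preload amounts to an /api/generate request with only "model" and "keep_alive" set and no prompt, which makes Ollama load the model into memory without generating anything. The sketch below illustrates that request shape only; it uses reqwest and a placeholder model name rather than the HttpClient used in this PR:

// Hedged sketch of the preload request; not the code in this PR.
async fn warm_model(api_url: &str, model: &str) -> anyhow::Result<()> {
    let body = serde_json::json!({
        "model": model,        // e.g. "llama3" (placeholder)
        "keep_alive": "15m",   // keep the model resident after loading
    });
    // With no "prompt" field, Ollama just loads the model and returns.
    reqwest::Client::new()
        .post(format!("{api_url}/api/generate"))
        .json(&body)
        .send()
        .await?
        .error_for_status()?;
    Ok(())
}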

Release Notes:

- N/A

Change summary

crates/assistant/src/completion_provider.rs        |  3 +
crates/assistant/src/completion_provider/ollama.rs | 25 ++++++++
crates/ollama/src/ollama.rs                        | 46 ++++++++++++++-
3 files changed, 67 insertions(+), 7 deletions(-)

Detailed changes

crates/assistant/src/completion_provider.rs 🔗

@@ -62,6 +62,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
             client.http_client(),
             low_speed_timeout_in_seconds.map(Duration::from_secs),
             settings_version,
+            cx,
         )),
     };
     cx.set_global(provider);
@@ -114,6 +115,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                         api_url.clone(),
                         low_speed_timeout_in_seconds.map(Duration::from_secs),
                         settings_version,
+                        cx,
                     );
                 }
 
@@ -174,6 +176,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                         client.http_client(),
                         low_speed_timeout_in_seconds.map(Duration::from_secs),
                         settings_version,
+                        cx,
                     ));
                 }
             }

crates/assistant/src/completion_provider/ollama.rs 🔗

@@ -7,7 +7,8 @@ use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
 use gpui::{AnyView, AppContext, Task};
 use http::HttpClient;
 use ollama::{
-    get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, Role as OllamaRole,
+    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
+    Role as OllamaRole,
 };
 use std::sync::Arc;
 use std::time::Duration;
@@ -31,7 +32,17 @@ impl OllamaCompletionProvider {
         http_client: Arc<dyn HttpClient>,
         low_speed_timeout: Option<Duration>,
         settings_version: usize,
+        cx: &AppContext,
     ) -> Self {
+        cx.spawn({
+            let api_url = api_url.clone();
+            let client = http_client.clone();
+            let model = model.name.clone();
+
+            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
+        })
+        .detach_and_log_err(cx);
+
         Self {
             api_url,
             model,
@@ -48,7 +59,17 @@ impl OllamaCompletionProvider {
         api_url: String,
         low_speed_timeout: Option<Duration>,
         settings_version: usize,
+        cx: &AppContext,
     ) {
+        cx.spawn({
+            let api_url = api_url.clone();
+            let client = self.http_client.clone();
+            let model = model.name.clone();
+
+            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
+        })
+        .detach_and_log_err(cx);
+
         self.model = model;
         self.api_url = api_url;
         self.low_speed_timeout = low_speed_timeout;
@@ -93,7 +114,7 @@ impl OllamaCompletionProvider {
                 // indicating which models are embedding models,
                 // simply filter out models with "-embed" in their name
                 .filter(|model| !model.name.contains("-embed"))
-                .map(|model| OllamaModel::new(&model.name, &model.details.parameter_size))
+                .map(|model| OllamaModel::new(&model.name))
                 .collect();
 
             models.sort_by(|a, b| a.name.cmp(&b.name));

crates/ollama/src/ollama.rs 🔗

@@ -42,18 +42,14 @@ impl From<Role> for String {
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 pub struct Model {
     pub name: String,
-    pub parameter_size: String,
     pub max_tokens: usize,
     pub keep_alive: Option<String>,
 }
 
 impl Model {
-    pub fn new(name: &str, parameter_size: &str) -> Self {
+    pub fn new(name: &str) -> Self {
         Self {
             name: name.to_owned(),
-            parameter_size: parameter_size.to_owned(),
-            // todo: determine if there's an endpoint to find the max tokens
-            //       I'm not seeing it in the API docs but it's on the model cards
             max_tokens: 2048,
             keep_alive: Some("10m".to_owned()),
         }
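With parameter_size removed, a Model is constructed from its name alone and falls back to the defaults shown above. A small illustrative sketch (the model name is just an example):

// Illustrative only, based on the struct as changed in this PR.
let model = ollama::Model::new("llama3");
assert_eq!(model.max_tokens, 2048);                   // hard-coded default
assert_eq!(model.keep_alive, Some("10m".to_owned())); // kept loaded for 10 minutes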
@@ -222,3 +218,43 @@ pub async fn get_models(
         ))
     }
 }
+
+/// Sends an empty request to Ollama to trigger loading the model
+pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
+    let uri = format!("{api_url}/api/generate");
+    let request = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json")
+        .body(AsyncBody::from(serde_json::to_string(
+            &serde_json::json!({
+                "model": model,
+                "keep_alive": "15m",
+            }),
+        )?))?;
+
+    let mut response = match client.send(request).await {
+        Ok(response) => response,
+        Err(err) => {
+            // Be ok with a timeout during preload of the model
+            if err.is_timeout() {
+                return Ok(());
+            } else {
+                return Err(err.into());
+            }
+        }
+    };
+
+    if response.status().is_success() {
+        Ok(())
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        Err(anyhow!(
+            "Failed to connect to Ollama API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
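A possible way to call the new helper directly, shown as a hedged sketch (the URL and model name are placeholders, and `client` is assumed to be any implementation of the HttpClient trait; in this PR the call is spawned from the completion provider instead):

// Sketch only: warm a model ahead of the first completion request.
async fn warm_default(client: &dyn HttpClient) {
    if let Err(err) = preload_model(client, "http://localhost:11434", "llama3").await {
        // Preload failures are non-fatal; the model still loads lazily
        // on the first real completion request.
        log::warn!("failed to preload ollama model: {err}");
    }
}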