Detailed changes
@@ -62,6 +62,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
+ cx,
)),
};
cx.set_global(provider);
@@ -114,6 +115,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
+ cx,
);
}
@@ -174,6 +176,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
+ cx,
));
}
}
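These three hunks thread the `AppContext` through the Ollama call sites in `init`, so the provider can kick off a model preload as soon as it is constructed or its settings change. For orientation, a hedged sketch of the resulting call shape; the enum wrapper and the first two arguments do not appear in the hunks above, so `CompletionProvider::Ollama`, `model`, and `api_url` are assumptions:

```rust
// Hedged sketch of one call site after this change (wrapper and first two
// argument names assumed; the rest is visible in the hunks above).
CompletionProvider::Ollama(OllamaCompletionProvider::new(
    model.clone(),
    api_url.clone(),
    client.http_client(),
    low_speed_timeout_in_seconds.map(Duration::from_secs),
    settings_version,
    cx, // newly threaded through so the provider can spawn its preload task
))
```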
@@ -7,7 +7,8 @@ use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task};
use http::HttpClient;
use ollama::{
- get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, Role as OllamaRole,
+ get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
+ Role as OllamaRole,
};
use std::sync::Arc;
use std::time::Duration;
@@ -31,7 +32,17 @@ impl OllamaCompletionProvider {
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
+ cx: &AppContext,
) -> Self {
+ cx.spawn({
+ let api_url = api_url.clone();
+ let client = http_client.clone();
+ let model = model.name.clone();
+
+ |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
+ })
+ .detach_and_log_err(cx);
+
Self {
api_url,
model,
@@ -48,7 +59,17 @@ impl OllamaCompletionProvider {
api_url: String,
low_speed_timeout: Option<Duration>,
settings_version: usize,
+ cx: &AppContext,
) {
+ cx.spawn({
+ let api_url = api_url.clone();
+ let client = self.http_client.clone();
+ let model = model.name.clone();
+
+ |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
+ })
+ .detach_and_log_err(cx);
+
self.model = model;
self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout;
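Both `new` and `update` now spawn the same fire-and-forget preload. A minimal annotated restatement of that pattern, with editorial comments only and no behavior beyond what the hunks above add:

```rust
// Owned clones are moved into the closure because the `async move` future may
// outlive the borrowed `api_url`, `http_client`, and `model` it was built from.
cx.spawn({
    let api_url = api_url.clone();
    let client = http_client.clone();
    let model = model.name.clone();

    |_cx| async move { preload_model(client.as_ref(), &api_url, &model).await }
})
// Detach rather than await: a failed preload is only logged and never blocks
// constructing or updating the provider.
.detach_and_log_err(cx);
```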
@@ -93,7 +114,7 @@ impl OllamaCompletionProvider {
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
- .map(|model| OllamaModel::new(&model.name, &model.details.parameter_size))
+ .map(|model| OllamaModel::new(&model.name))
.collect();

models.sort_by(|a, b| a.name.cmp(&b.name));
@@ -42,18 +42,14 @@ impl From<Role> for String {
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
pub name: String,
- pub parameter_size: String,
pub max_tokens: usize,
pub keep_alive: Option<String>,
}

impl Model {
- pub fn new(name: &str, parameter_size: &str) -> Self {
+ pub fn new(name: &str) -> Self {
Self {
name: name.to_owned(),
- parameter_size: parameter_size.to_owned(),
- // todo: determine if there's an endpoint to find the max tokens
- // I'm not seeing it in the API docs but it's on the model cards
max_tokens: 2048,
keep_alive: Some("10m".to_owned()),
}
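With `parameter_size` gone, `Model::new` takes only the model name and fills the remaining fields with fixed defaults. A hedged usage sketch ("llama3" is just an illustrative model name):

```rust
let model = Model::new("llama3");
assert_eq!(model.name, "llama3");
assert_eq!(model.max_tokens, 2048);                   // fixed default; not queried from Ollama
assert_eq!(model.keep_alive.as_deref(), Some("10m")); // how long Ollama keeps the model loaded
```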
@@ -222,3 +218,43 @@ pub async fn get_models(
))
}
}
+
+/// Sends an empty request to Ollama to trigger loading the model
+pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
+ let uri = format!("{api_url}/api/generate");
+ let request = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json")
+ .body(AsyncBody::from(serde_json::to_string(
+ &serde_json::json!({
+ "model": model,
+ "keep_alive": "15m",
+ }),
+ )?))?;
+
+ let mut response = match client.send(request).await {
+ Ok(response) => response,
+ Err(err) => {
+ // Be ok with a timeout during preload of the model
+ if err.is_timeout() {
+ return Ok(());
+ } else {
+ return Err(err.into());
+ }
+ }
+ };
+
+ if response.status().is_success() {
+ Ok(())
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ Err(anyhow!(
+ "Failed to connect to Ollama API: {} {}",
+ response.status(),
+ body,
+ ))
+ }
+}
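For reference, a hedged sketch of how the new function is exercised and what it puts on the wire. The URL is Ollama's default local address, "llama3" is illustrative, and obtaining an `HttpClient` instance is outside this diff:

```rust
// Body serialized by preload_model and POSTed to {api_url}/api/generate:
//   {"model":"llama3","keep_alive":"15m"}
// A generate request with no prompt asks Ollama to load the model into memory
// and keep it resident for the given keep_alive, which is the warm-up we want.
preload_model(client.as_ref(), "http://localhost:11434", "llama3").await?;
```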