Ollama max_tokens settings (#17025)

Peter Tripp created

- Support `available_models` for Ollama
- Clamp default max tokens (context length) to 16384.
- Add documentation for ollama context configuration.

Change summary

crates/assistant/src/assistant_settings.rs   |  3 
crates/language_model/src/provider/ollama.rs | 43 +++++++++++++++++++-
crates/language_model/src/settings.rs        |  8 +++
crates/ollama/src/ollama.rs                  | 45 ++++++++++-----------
docs/src/assistant/configuration.md          | 28 ++++++++++--
5 files changed, 92 insertions(+), 35 deletions(-)

Detailed changes

crates/assistant/src/assistant_settings.rs 🔗

@@ -135,6 +135,7 @@ impl AssistantSettingsContent {
                                         Some(language_model::settings::OllamaSettingsContent {
                                             api_url,
                                             low_speed_timeout_in_seconds,
+                                            available_models: None,
                                         });
                                 }
                             },
@@ -295,7 +296,7 @@ impl AssistantSettingsContent {
                             _ => (None, None),
                         };
                         settings.provider = Some(AssistantProviderContentV1::Ollama {
-                            default_model: Some(ollama::Model::new(&model)),
+                            default_model: Some(ollama::Model::new(&model, None, None)),
                             api_url,
                             low_speed_timeout_in_seconds,
                         });

crates/language_model/src/provider/ollama.rs 🔗

@@ -6,8 +6,10 @@ use ollama::{
     get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
     ChatResponseDelta, OllamaToolCall,
 };
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{collections::BTreeMap, sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, Indicator};
 use util::ResultExt;
 
@@ -28,6 +30,17 @@ const PROVIDER_NAME: &str = "Ollama";
 pub struct OllamaSettings {
     pub api_url: String,
     pub low_speed_timeout: Option<Duration>,
+    pub available_models: Vec<AvailableModel>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+    /// The model name in the Ollama API (e.g. "llama3.1:latest")
+    pub name: String,
+    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
+    pub display_name: Option<String>,
+    /// The Context Length parameter to the model (aka num_ctx or n_ctx)
+    pub max_tokens: usize,
 }
 
 pub struct OllamaLanguageModelProvider {
@@ -61,7 +74,7 @@ impl State {
                 // indicating which models are embedding models,
                 // simply filter out models with "-embed" in their name
                 .filter(|model| !model.name.contains("-embed"))
-                .map(|model| ollama::Model::new(&model.name))
+                .map(|model| ollama::Model::new(&model.name, None, None))
                 .collect();
 
             models.sort_by(|a, b| a.name.cmp(&b.name));
@@ -123,10 +136,32 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
     }
 
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
-        self.state
-            .read(cx)
+        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();
+
+        // Add models from the Ollama API
+        for model in self.state.read(cx).available_models.iter() {
+            models.insert(model.name.clone(), model.clone());
+        }
+
+        // Override with available models from settings
+        for model in AllLanguageModelSettings::get_global(cx)
+            .ollama
             .available_models
             .iter()
+        {
+            models.insert(
+                model.name.clone(),
+                ollama::Model {
+                    name: model.name.clone(),
+                    display_name: model.display_name.clone(),
+                    max_tokens: model.max_tokens,
+                    keep_alive: None,
+                },
+            );
+        }
+
+        models
+            .into_values()
             .map(|model| {
                 Arc::new(OllamaLanguageModel {
                     id: LanguageModelId::from(model.name.clone()),

crates/language_model/src/settings.rs 🔗

@@ -152,6 +152,7 @@ pub struct AnthropicSettingsContentV1 {
 pub struct OllamaSettingsContent {
     pub api_url: Option<String>,
     pub low_speed_timeout_in_seconds: Option<u64>,
+    pub available_models: Option<Vec<provider::ollama::AvailableModel>>,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@@ -276,6 +277,9 @@ impl settings::Settings for AllLanguageModelSettings {
                 anthropic.as_ref().and_then(|s| s.available_models.clone()),
             );
 
+            // Ollama
+            let ollama = value.ollama.clone();
+
             merge(
                 &mut settings.ollama.api_url,
                 value.ollama.as_ref().and_then(|s| s.api_url.clone()),
@@ -288,6 +292,10 @@ impl settings::Settings for AllLanguageModelSettings {
                 settings.ollama.low_speed_timeout =
                     Some(Duration::from_secs(low_speed_timeout_in_seconds));
             }
+            merge(
+                &mut settings.ollama.available_models,
+                ollama.as_ref().and_then(|s| s.available_models.clone()),
+            );
 
             // OpenAI
             let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {

crates/ollama/src/ollama.rs 🔗

@@ -66,40 +66,37 @@ impl Default for KeepAlive {
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 pub struct Model {
     pub name: String,
+    pub display_name: Option<String>,
     pub max_tokens: usize,
     pub keep_alive: Option<KeepAlive>,
 }
 
-// This could be dynamically retrieved via the API (1 call per model)
-// curl -s http://localhost:11434/api/show -d '{"model": "llama3.1:latest"}' | jq '.model_info."llama.context_length"'
 fn get_max_tokens(name: &str) -> usize {
-    match name {
-        "dolphin-llama3:8b-256k" => 262144, // 256K
-        _ => match name.split(':').next().unwrap() {
-            "mistral-nemo" => 1024000,                                      // 1M
-            "deepseek-coder-v2" => 163840,                                  // 160K
-            "llama3.1" | "phi3" | "command-r" | "command-r-plus" => 131072, // 128K
-            "codeqwen" => 65536,                                            // 64K
-            "mistral" | "mistral-large" | "dolphin-mistral" | "codestral"   // 32K
-            | "mistral-openorca" | "dolphin-mixtral" | "mixstral" | "llava"
-            | "qwen" | "qwen2" | "wizardlm2" | "wizard-math" => 32768,
-            "codellama" | "stable-code" | "deepseek-coder" | "starcoder2"   // 16K
-            | "wizardcoder" => 16384,
-            "llama3" | "gemma2" | "gemma" | "codegemma" | "dolphin-llama3"  // 8K
-            | "llava-llama3" | "starcoder" | "openchat" | "aya" => 8192,
-            "llama2" | "yi" | "llama2-chinese" | "vicuna" | "nous-hermes2"  // 4K
-            | "stablelm2" => 4096,
-            "phi" | "orca-mini" | "tinyllama" | "granite-code" => 2048,     // 2K
-            _ => 2048,                                                      // 2K (default)
-        },
+    /// Default context length for unknown models.
+    const DEFAULT_TOKENS: usize = 2048;
+    /// Magic number. Lets many Ollama models work with ~16GB of ram.
+    const MAXIMUM_TOKENS: usize = 16384;
+
+    match name.split(':').next().unwrap() {
+        "phi" | "tinyllama" | "granite-code" => 2048,
+        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
+        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
+        "codellama" | "starcoder2" => 16384,
+        "mistral" | "codestral" | "mixstral" | "llava" | "qwen2" | "dolphin-mixtral" => 32768,
+        "llama3.1" | "phi3" | "phi3.5" | "command-r" | "deepseek-coder-v2" => 128000,
+        _ => DEFAULT_TOKENS,
     }
+    .clamp(1, MAXIMUM_TOKENS)
 }
 
 impl Model {
-    pub fn new(name: &str) -> Self {
+    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
         Self {
             name: name.to_owned(),
-            max_tokens: get_max_tokens(name),
+            display_name: display_name
+                .map(ToString::to_string)
+                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
+            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
             keep_alive: Some(KeepAlive::indefinite()),
         }
     }
@@ -109,7 +106,7 @@ impl Model {
     }
 
     pub fn display_name(&self) -> &str {
-        &self.name
+        self.display_name.as_ref().unwrap_or(&self.name)
     }
 
     pub fn max_token_count(&self) -> usize {

docs/src/assistant/configuration.md 🔗

@@ -108,33 +108,49 @@ Custom models will be listed in the model dropdown in the assistant panel.
 
 Download and install Ollama from [ollama.com/download](https://ollama.com/download) (Linux or macOS) and ensure it's running with `ollama --version`.
 
-You can use Ollama with the Zed assistant by making Ollama appear as an OpenAPI endpoint.
-
-1. Download, for example, the `mistral` model with Ollama:
+1. Download one of the [available models](https://ollama.com/models), for example, for `mistral`:
 
    ```sh
    ollama pull mistral
    ```
 
-2. Make sure that the Ollama server is running. You can start it either via running the Ollama app, or launching:
+2. Make sure that the Ollama server is running. You can start it either via running Ollama.app (MacOS) or launching:
 
    ```sh
    ollama serve
    ```
 
 3. In the assistant panel, select one of the Ollama models using the model dropdown.
-4. (Optional) If you want to change the default URL that is used to access the Ollama server, you can do so by adding the following settings:
+
+4. (Optional) Specify a [custom api_url](#custom-endpoint) or [custom `low_speed_timeout_in_seconds`](#provider-timeout) if required.
+
+#### Ollama Context Length {#ollama-context}}
+
+Zed has pre-configured maximum context lengths (`max_tokens`) to match the capabilities of common models. Zed API requests to Ollama include this as `num_ctx` parameter, but the default values do not exceed `16384` so users with ~16GB of ram are able to use most models out of the box. See [get_max_tokens in ollama.rs](https://github.com/zed-industries/zed/blob/main/crates/ollama/src/ollama.rs) for a complete set of defaults.
+
+**Note**: Tokens counts displayed in the assistant panel are only estimates and will differ from the models native tokenizer.
+
+Depending on your hardware or use-case you may wish to limit or increase the context length for a specific model via settings.json:
 
 ```json
 {
   "language_models": {
     "ollama": {
-      "api_url": "http://localhost:11434"
+      "low_speed_timeout_in_seconds": 120,
+      "available_models": [
+        {
+          "provider": "ollama",
+          "name": "mistral:latest",
+          "max_tokens": 32768
+        }
+      ]
     }
   }
 }
 ```
 
+If you specify a context length that is too large for your hardware, Ollama will log an error. You can watch these logs by running: `tail -f ~/.ollama/logs/ollama.log` (MacOS) or `journalctl -u ollama -f` (Linux). Depending on the memory available on your machine, you may need to adjust the context length to a smaller value.
+
 ### OpenAI {#openai}
 
 1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys)