ollama: Resolve context window size via API (#39941)

Bennet Bo Fenner created

Previously we were guessing the context window size here:
https://github.com/zed-industries/zed/blob/8c3f09e31e3588a2494042dfe77bb3dccab1f7ba/crates/ollama/src/ollama.rs#L22

This is inaccurate and must be updated manually. This PR ensures that we
extract the context window size from the request in the same way that
the Ollama CLI does when running `ollama show <model-name>` (Relevant
code is
[here](https://github.com/ollama/ollama/blob/3d32249c749c6f77c1dc8a7cb55ae74fc2f4c08b/cmd/cmd.go#L860))

The format looks like this:

```json
{
  "model_info": {
    "general.architecture": "llama",
    "llama.context_length": 132000
  }
}
```

Once this PR is merged we could technically remove the old code
https://github.com/zed-industries/zed/blob/8c3f09e31e3588a2494042dfe77bb3dccab1f7ba/crates/ollama/src/ollama.rs#L22
I decided to keep it for now, as it is unclear if the necessary fields
are available via the API on older Ollama versions.

Release Notes:

- Fixed an issue where Ollama models would use the wrong context window
size

Change summary

crates/language_models/src/provider/ollama.rs | 10 +-
crates/ollama/src/ollama.rs                   | 71 ++++++++++++++++++++
2 files changed, 74 insertions(+), 7 deletions(-)

Detailed changes

crates/language_models/src/provider/ollama.rs 🔗

@@ -119,16 +119,16 @@ impl State {
                     let api_key = api_key.clone();
                     async move {
                         let name = model.name.as_str();
-                        let capabilities =
+                        let model =
                             show_model(http_client.as_ref(), &api_url, api_key.as_deref(), name)
                                 .await?;
                         let ollama_model = ollama::Model::new(
                             name,
                             None,
-                            None,
-                            Some(capabilities.supports_tools()),
-                            Some(capabilities.supports_vision()),
-                            Some(capabilities.supports_thinking()),
+                            model.context_length,
+                            Some(model.supports_tools()),
+                            Some(model.supports_vision()),
+                            Some(model.supports_thinking()),
                         );
                         Ok(ollama_model)
                     }

crates/ollama/src/ollama.rs 🔗

@@ -189,10 +189,74 @@ pub struct ModelDetails {
     pub quantization_level: String,
 }
 
-#[derive(Deserialize, Debug)]
+#[derive(Debug)]
 pub struct ModelShow {
-    #[serde(default)]
     pub capabilities: Vec<String>,
+    pub context_length: Option<u64>,
+    pub architecture: Option<String>,
+}
+
+impl<'de> Deserialize<'de> for ModelShow {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        use serde::de::{self, MapAccess, Visitor};
+        use std::fmt;
+
+        struct ModelShowVisitor;
+
+        impl<'de> Visitor<'de> for ModelShowVisitor {
+            type Value = ModelShow;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("a ModelShow object")
+            }
+
+            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
+            where
+                A: MapAccess<'de>,
+            {
+                let mut capabilities: Vec<String> = Vec::new();
+                let mut architecture: Option<String> = None;
+                let mut context_length: Option<u64> = None;
+
+                while let Some(key) = map.next_key::<String>()? {
+                    match key.as_str() {
+                        "capabilities" => {
+                            capabilities = map.next_value()?;
+                        }
+                        "model_info" => {
+                            let model_info: Value = map.next_value()?;
+                            if let Value::Object(obj) = model_info {
+                                architecture = obj
+                                    .get("general.architecture")
+                                    .and_then(|v| v.as_str())
+                                    .map(String::from);
+
+                                if let Some(arch) = &architecture {
+                                    context_length = obj
+                                        .get(&format!("{}.context_length", arch))
+                                        .and_then(|v| v.as_u64());
+                                }
+                            }
+                        }
+                        _ => {
+                            let _: de::IgnoredAny = map.next_value()?;
+                        }
+                    }
+                }
+
+                Ok(ModelShow {
+                    capabilities,
+                    context_length,
+                    architecture,
+                })
+            }
+        }
+
+        deserializer.deserialize_map(ModelShowVisitor)
+    }
 }
 
 impl ModelShow {
@@ -470,6 +534,9 @@ mod tests {
         assert!(result.supports_tools());
         assert!(result.capabilities.contains(&"tools".to_string()));
         assert!(result.capabilities.contains(&"completion".to_string()));
+
+        assert_eq!(result.architecture, Some("llama".to_string()));
+        assert_eq!(result.context_length, Some(131072));
     }
 
     #[test]