@@ -119,16 +119,16 @@ impl State {
let api_key = api_key.clone();
async move {
let name = model.name.as_str();
- let capabilities =
+ let model =
show_model(http_client.as_ref(), &api_url, api_key.as_deref(), name)
.await?;
let ollama_model = ollama::Model::new(
name,
None,
- None,
- Some(capabilities.supports_tools()),
- Some(capabilities.supports_vision()),
- Some(capabilities.supports_thinking()),
+ model.context_length,
+ Some(model.supports_tools()),
+ Some(model.supports_vision()),
+ Some(model.supports_thinking()),
);
Ok(ollama_model)
}
@@ -189,10 +189,74 @@ pub struct ModelDetails {
pub quantization_level: String,
}
-#[derive(Deserialize, Debug)]
+#[derive(Debug)]
pub struct ModelShow {
- #[serde(default)]
pub capabilities: Vec<String>,
+ pub context_length: Option<u64>,
+ pub architecture: Option<String>,
+}
+
+impl<'de> Deserialize<'de> for ModelShow {
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ use serde::de::{self, MapAccess, Visitor};
+ use std::fmt;
+
+ struct ModelShowVisitor;
+
+ impl<'de> Visitor<'de> for ModelShowVisitor {
+ type Value = ModelShow;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("a ModelShow object")
+ }
+
+ fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
+ where
+ A: MapAccess<'de>,
+ {
+ let mut capabilities: Vec<String> = Vec::new();
+ let mut architecture: Option<String> = None;
+ let mut context_length: Option<u64> = None;
+
+ while let Some(key) = map.next_key::<String>()? {
+ match key.as_str() {
+ "capabilities" => {
+ capabilities = map.next_value()?;
+ }
+ "model_info" => {
+ let model_info: Value = map.next_value()?;
+ if let Value::Object(obj) = model_info {
+ architecture = obj
+ .get("general.architecture")
+ .and_then(|v| v.as_str())
+ .map(String::from);
+
+ if let Some(arch) = &architecture {
+ context_length = obj
+ .get(&format!("{}.context_length", arch))
+ .and_then(|v| v.as_u64());
+ }
+ }
+ }
+ _ => {
+ let _: de::IgnoredAny = map.next_value()?;
+ }
+ }
+ }
+
+ Ok(ModelShow {
+ capabilities,
+ context_length,
+ architecture,
+ })
+ }
+ }
+
+ deserializer.deserialize_map(ModelShowVisitor)
+ }
}
impl ModelShow {
@@ -470,6 +534,9 @@ mod tests {
assert!(result.supports_tools());
assert!(result.capabilities.contains(&"tools".to_string()));
assert!(result.capabilities.contains(&"completion".to_string()));
+
+ assert_eq!(result.architecture, Some("llama".to_string()));
+ assert_eq!(result.context_length, Some(131072));
}
#[test]