1use anyhow::{Context, Result};
2
3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
4use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7pub use settings::KeepAlive;
8
/// Default base URL of a locally running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
10
/// A user-configurable Ollama model entry.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// The model identifier as known to Ollama (e.g. "llama3.2:latest").
    pub name: String,
    /// Optional human-friendly name for UI display; falls back to `name`.
    pub display_name: Option<String>,
    /// Maximum context window size, in tokens.
    pub max_tokens: u64,
    /// How long Ollama should keep the model loaded after a request.
    pub keep_alive: Option<KeepAlive>,
    /// Whether the model supports tool calling; `None` when unknown.
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image inputs; `None` when unknown.
    pub supports_vision: Option<bool>,
    /// Whether the model can emit "thinking" output; `None` when unknown.
    pub supports_thinking: Option<bool>,
}
22
/// Fallback context-window size (in tokens) for a model whose real context
/// length is not known. The model name is currently unused.
fn get_max_tokens(_name: &str) -> u64 {
    4096
}
27
28impl Model {
29 pub fn new(
30 name: &str,
31 display_name: Option<&str>,
32 max_tokens: Option<u64>,
33 supports_tools: Option<bool>,
34 supports_vision: Option<bool>,
35 supports_thinking: Option<bool>,
36 ) -> Self {
37 Self {
38 name: name.to_owned(),
39 display_name: display_name
40 .map(ToString::to_string)
41 .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
42 max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
43 keep_alive: Some(KeepAlive::indefinite()),
44 supports_tools,
45 supports_vision,
46 supports_thinking,
47 }
48 }
49
50 pub fn id(&self) -> &str {
51 &self.name
52 }
53
54 pub fn display_name(&self) -> &str {
55 self.display_name.as_ref().unwrap_or(&self.name)
56 }
57
58 pub fn max_token_count(&self) -> u64 {
59 self.max_tokens
60 }
61}
62
/// A single message in an Ollama chat conversation, tagged by `role` in JSON.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// A model-generated message, possibly carrying tool calls and thinking text.
    Assistant {
        content: String,
        /// Tool invocations requested by the model, if any.
        tool_calls: Option<Vec<OllamaToolCall>>,
        /// Base64-encoded images; omitted from JSON when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// "Thinking" text from models that support it.
        thinking: Option<String>,
    },
    /// A user-authored message, optionally with base64-encoded images.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    /// A system prompt.
    System {
        content: String,
    },
    /// The result of a tool invocation, reported back to the model.
    Tool {
        tool_name: String,
        content: String,
    },
}
86
/// A tool invocation requested by the model.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    /// Identifier correlating this call with its eventual tool result.
    pub id: String,
    /// The function to invoke and its arguments.
    pub function: OllamaFunctionCall,
}
92
/// The function name and JSON arguments of a tool call.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    /// Arbitrary JSON arguments as produced by the model.
    pub arguments: Value,
}
98
/// Declaration of a function the model is allowed to call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema describing the function's parameters, if any.
    pub parameters: Option<Value>,
}
105
/// A tool made available to the model; currently only function tools exist.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
111
/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    /// When true, the server streams newline-delimited JSON deltas.
    pub stream: bool,
    /// How long the model should stay loaded after this request.
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    /// Enables "thinking" output on models that support it.
    pub think: Option<bool>,
}
122
/// Per-request sampling and runtime options. Every field is omitted from the
/// serialized JSON when `None`, letting Ollama fall back to the model's
/// Modelfile defaults.
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    /// Context window size, in tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_ctx: Option<u64>,
    /// Maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_predict: Option<isize>,
    /// Stop sequences; omitted when `None` so model defaults apply.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
137
/// One line of Ollama's streaming chat response (or an entire non-streaming
/// response).
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    pub created_at: String,
    /// The incremental message content carried by this delta.
    pub message: ChatMessage,
    /// Why generation finished (e.g. "stop"); absent on intermediate deltas.
    pub done_reason: Option<String>,
    /// True on the final delta of a stream.
    pub done: bool,
    /// Prompt tokens evaluated; typically present only when `done` is true.
    pub prompt_eval_count: Option<u64>,
    /// Tokens generated; typically present only when `done` is true.
    pub eval_count: Option<u64>,
}
148
/// Response from `/api/tags`: the models installed locally.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
153
/// A single locally installed model as reported by `/api/tags`.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    /// On-disk size, in bytes.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
162
/// Detailed model record including its Modelfile and prompt template.
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
170
/// Format/architecture metadata shared by model listings and details.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    /// Model file format (e.g. "gguf").
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    /// Human-readable parameter count (e.g. "3.2B").
    pub parameter_size: String,
    /// Quantization label (e.g. "Q4_K_M").
    pub quantization_level: String,
}
179
/// Subset of the `/api/show` response used to detect model capabilities.
#[derive(Debug)]
pub struct ModelShow {
    /// Capability tags reported by Ollama (e.g. "completion", "tools").
    pub capabilities: Vec<String>,
    /// Context window size; a user-set `num_ctx` parameter takes precedence
    /// over the architecture's default context length.
    pub context_length: Option<u64>,
    /// Model architecture from `model_info` (e.g. "llama").
    pub architecture: Option<String>,
}
186
187impl<'de> Deserialize<'de> for ModelShow {
188 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
189 where
190 D: serde::Deserializer<'de>,
191 {
192 use serde::de::{self, MapAccess, Visitor};
193 use std::fmt;
194
195 struct ModelShowVisitor;
196
197 impl<'de> Visitor<'de> for ModelShowVisitor {
198 type Value = ModelShow;
199
200 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
201 formatter.write_str("a ModelShow object")
202 }
203
204 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
205 where
206 A: MapAccess<'de>,
207 {
208 let mut capabilities: Vec<String> = Vec::new();
209 let mut architecture: Option<String> = None;
210 let mut context_length: Option<u64> = None;
211 let mut num_ctx: Option<u64> = None;
212
213 while let Some(key) = map.next_key::<String>()? {
214 match key.as_str() {
215 "capabilities" => {
216 capabilities = map.next_value()?;
217 }
218 "parameters" => {
219 let params_str: String = map.next_value()?;
220 for line in params_str.lines() {
221 if let Some(start) = line.find("num_ctx") {
222 let value_part = &line[start + 7..];
223 if let Ok(value) = value_part.trim().parse::<u64>() {
224 num_ctx = Some(value);
225 break;
226 }
227 }
228 }
229 }
230 "model_info" => {
231 let model_info: Value = map.next_value()?;
232 if let Value::Object(obj) = model_info {
233 architecture = obj
234 .get("general.architecture")
235 .and_then(|v| v.as_str())
236 .map(String::from);
237
238 if let Some(arch) = &architecture {
239 context_length = obj
240 .get(&format!("{}.context_length", arch))
241 .and_then(|v| v.as_u64());
242 }
243 }
244 }
245 _ => {
246 let _: de::IgnoredAny = map.next_value()?;
247 }
248 }
249 }
250
251 let context_length = num_ctx.or(context_length);
252 Ok(ModelShow {
253 capabilities,
254 context_length,
255 architecture,
256 })
257 }
258 }
259
260 deserializer.deserialize_map(ModelShowVisitor)
261 }
262}
263
264impl ModelShow {
265 pub fn supports_tools(&self) -> bool {
266 // .contains expects &String, which would require an additional allocation
267 self.capabilities.iter().any(|v| v == "tools")
268 }
269
270 pub fn supports_vision(&self) -> bool {
271 self.capabilities.iter().any(|v| v == "vision")
272 }
273
274 pub fn supports_thinking(&self) -> bool {
275 self.capabilities.iter().any(|v| v == "thinking")
276 }
277}
278
/// Sends a chat request to Ollama's `/api/chat` endpoint and returns a stream
/// of newline-delimited JSON deltas, one [`ChatResponseDelta`] per line.
///
/// On a non-success HTTP status, the response body is read and folded into
/// the returned error.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        // Only attach Authorization when an API key was provided.
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        // The server streams one JSON object per line; parse each lazily as
        // the stream is consumed.
        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}
316
317pub async fn get_models(
318 client: &dyn HttpClient,
319 api_url: &str,
320 api_key: Option<&str>,
321) -> Result<Vec<LocalModelListing>> {
322 let uri = format!("{api_url}/api/tags");
323 let request = HttpRequest::builder()
324 .method(Method::GET)
325 .uri(uri)
326 .header("Accept", "application/json")
327 .when_some(api_key, |builder, api_key| {
328 builder.header("Authorization", format!("Bearer {api_key}"))
329 })
330 .body(AsyncBody::default())?;
331
332 let mut response = client.send(request).await?;
333
334 let mut body = String::new();
335 response.body_mut().read_to_string(&mut body).await?;
336
337 anyhow::ensure!(
338 response.status().is_success(),
339 "Failed to connect to Ollama API: {} {}",
340 response.status(),
341 body,
342 );
343 let response: LocalModelsResponse =
344 serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
345 Ok(response.models)
346}
347
348/// Fetch details of a model, used to determine model capabilities
349pub async fn show_model(
350 client: &dyn HttpClient,
351 api_url: &str,
352 api_key: Option<&str>,
353 model: &str,
354) -> Result<ModelShow> {
355 let uri = format!("{api_url}/api/show");
356 let request = HttpRequest::builder()
357 .method(Method::POST)
358 .uri(uri)
359 .header("Content-Type", "application/json")
360 .when_some(api_key, |builder, api_key| {
361 builder.header("Authorization", format!("Bearer {api_key}"))
362 })
363 .body(AsyncBody::from(
364 serde_json::json!({ "model": model }).to_string(),
365 ))?;
366
367 let mut response = client.send(request).await?;
368 let mut body = String::new();
369 response.body_mut().read_to_string(&mut body).await?;
370
371 anyhow::ensure!(
372 response.status().is_success(),
373 "Failed to connect to Ollama API: {} {}",
374 response.status(),
375 body,
376 );
377 let details: ModelShow = serde_json::from_str(body.as_str())?;
378 Ok(details)
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    // A complete, non-streaming chat response deserializes cleanly.
    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    // Both an intermediate delta and the final "done" delta of a stream parse.
    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    // An assistant message carrying tool_calls deserializes into the
    // Assistant variant with a non-empty tool-call list.
    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // /api/show response without a "parameters" key: capabilities,
    // architecture, and the default context length come from model_info.
    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    // A user-set num_ctx in "parameters" overrides the architecture default.
    #[test]
    fn parse_show_model_with_num_ctx_preference() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "parameters": "num_ctx 32768\npresence_penalty 1.5\ntemperature 1\ntop_k 20\ntop_p 0.95",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();

        assert_eq!(result.context_length, Some(32768));
    }

    // When "parameters" exists but has no num_ctx line, the context length
    // falls back to model_info's architecture default.
    #[test]
    fn parse_show_model_without_num_ctx_in_parameters_fallback() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "parameters": "presence_penalty 1.5\ntemperature 1\ntop_k 20\ntop_p 0.95",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();

        assert_eq!(result.context_length, Some(131072));
    }

    // User messages with images serialize the images field.
    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    // With no images, the images key is omitted entirely.
    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    // Round-trips the serialized request to check the exact JSON shape of
    // the images array.
    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }

    // ChatOptions fields must be omitted (not serialized as null) when None.
    #[test]
    fn test_chat_options_serialization() {
        // When stop is None, it should not appear in JSON at all
        // This allows Ollama to use the model's default stop tokens
        let options_no_stop = ChatOptions {
            num_ctx: Some(4096),
            stop: None,
            temperature: Some(0.7),
            ..Default::default()
        };
        let serialized = serde_json::to_string(&options_no_stop).unwrap();
        assert!(
            !serialized.contains("stop"),
            "stop should not be in JSON when None"
        );
        assert!(serialized.contains("num_ctx"));
        assert!(serialized.contains("temperature"));

        // When stop has values, they should be serialized
        let options_with_stop = ChatOptions {
            stop: Some(vec!["<|eot_id|>".to_string()]),
            ..Default::default()
        };
        let serialized = serde_json::to_string(&options_with_stop).unwrap();
        assert!(serialized.contains("stop"));
        assert!(serialized.contains("<|eot_id|>"));

        // All None options should result in empty object
        let options_all_none = ChatOptions::default();
        let serialized = serde_json::to_string(&options_all_none).unwrap();
        assert_eq!(serialized, "{}");
    }

    // Explicit stop tokens are serialized inside options.stop.
    #[test]
    fn test_chat_request_with_stop_tokens() {
        let request = ChatRequest {
            model: "rnj-1:8b".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello".to_string(),
                images: None,
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: Some(ChatOptions {
                stop: Some(vec!["<|eot_id|>".to_string(), "<|end|>".to_string()]),
                ..Default::default()
            }),
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();

        let stop = parsed["options"]["stop"].as_array().unwrap();
        assert_eq!(stop.len(), 2);
        assert_eq!(stop[0].as_str().unwrap(), "<|eot_id|>");
        assert_eq!(stop[1].as_str().unwrap(), "<|end|>");
    }

    #[test]
    fn test_chat_request_without_stop_tokens_omits_field() {
        // This tests the fix for issue #47798
        // When no stop tokens are provided, the field should be omitted
        // so Ollama uses the model's default stop tokens from Modelfile
        let request = ChatRequest {
            model: "rnj-1:8b".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello".to_string(),
                images: None,
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: Some(ChatOptions {
                num_ctx: Some(4096),
                stop: None, // No stop tokens - should be omitted from JSON
                ..Default::default()
            }),
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        // The key check: "stop" should not appear in the serialized JSON
        assert!(
            !serialized.contains("\"stop\""),
            "stop field should be omitted when None, got: {}",
            serialized
        );
    }
}