1use anyhow::{Context, Result};
2
3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
4use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7pub use settings::KeepAlive;
8
/// Default base URL of a locally running Ollama server.
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
10
/// A user-configurable Ollama model entry.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// The model identifier as known to Ollama (e.g. "llama3.2:latest").
    pub name: String,
    /// Optional human-friendly name for UI display; falls back to `name`.
    pub display_name: Option<String>,
    /// Maximum context window size, in tokens.
    pub max_tokens: u64,
    /// How long Ollama should keep the model loaded after a request.
    pub keep_alive: Option<KeepAlive>,
    /// Whether the model supports tool calling; `None` when unknown.
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image inputs; `None` when unknown.
    pub supports_vision: Option<bool>,
    /// Whether the model can emit "thinking" output; `None` when unknown.
    pub supports_thinking: Option<bool>,
}
22
/// Fallback context-window size (in tokens) for a model whose real context
/// length is not known. The model name is currently unused.
fn get_max_tokens(_name: &str) -> u64 {
    4096
}
27
28impl Model {
29 pub fn new(
30 name: &str,
31 display_name: Option<&str>,
32 max_tokens: Option<u64>,
33 supports_tools: Option<bool>,
34 supports_vision: Option<bool>,
35 supports_thinking: Option<bool>,
36 ) -> Self {
37 Self {
38 name: name.to_owned(),
39 display_name: display_name
40 .map(ToString::to_string)
41 .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
42 max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
43 keep_alive: Some(KeepAlive::indefinite()),
44 supports_tools,
45 supports_vision,
46 supports_thinking,
47 }
48 }
49
50 pub fn id(&self) -> &str {
51 &self.name
52 }
53
54 pub fn display_name(&self) -> &str {
55 self.display_name.as_ref().unwrap_or(&self.name)
56 }
57
58 pub fn max_token_count(&self) -> u64 {
59 self.max_tokens
60 }
61}
62
/// A single message in an Ollama chat conversation, tagged by `role` in JSON.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// A model-generated message, possibly carrying tool calls and thinking text.
    Assistant {
        content: String,
        /// Tool invocations requested by the model, if any.
        tool_calls: Option<Vec<OllamaToolCall>>,
        /// Base64-encoded images; omitted from JSON when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
        /// "Thinking" text from models that support it.
        thinking: Option<String>,
    },
    /// A user-authored message, optionally with base64-encoded images.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        images: Option<Vec<String>>,
    },
    /// A system prompt.
    System {
        content: String,
    },
    /// The result of a tool invocation, reported back to the model.
    Tool {
        tool_name: String,
        content: String,
    },
}
86
/// A tool invocation requested by the model.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaToolCall {
    /// Identifier correlating this call with its eventual tool result.
    pub id: String,
    /// The function to invoke and its arguments.
    pub function: OllamaFunctionCall,
}
92
/// The function name and JSON arguments of a tool call.
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
    pub name: String,
    /// Arbitrary JSON arguments as produced by the model.
    pub arguments: Value,
}
98
/// Declaration of a function the model is allowed to call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema describing the function's parameters, if any.
    pub parameters: Option<Value>,
}
105
/// A tool made available to the model; currently only function tools exist.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
    Function { function: OllamaFunctionTool },
}
111
/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    /// When true, the server streams newline-delimited JSON deltas.
    pub stream: bool,
    /// How long the model should stay loaded after this request.
    pub keep_alive: KeepAlive,
    pub options: Option<ChatOptions>,
    pub tools: Vec<OllamaTool>,
    /// Enables "thinking" output on models that support it.
    pub think: Option<bool>,
}
122
/// Per-request sampling and runtime options. Every field is omitted from the
/// serialized JSON when `None`, letting Ollama fall back to the model's
/// Modelfile defaults.
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
    /// Context window size, in tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_ctx: Option<u64>,
    /// Maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_predict: Option<isize>,
    /// Stop sequences; omitted when `None` so model defaults apply.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
137
/// One line of Ollama's streaming chat response (or an entire non-streaming
/// response).
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
    pub model: String,
    pub created_at: String,
    /// The incremental message content carried by this delta.
    pub message: ChatMessage,
    /// Why generation finished (e.g. "stop"); absent on intermediate deltas.
    pub done_reason: Option<String>,
    /// True on the final delta of a stream.
    pub done: bool,
    /// Prompt tokens evaluated; typically present only when `done` is true.
    pub prompt_eval_count: Option<u64>,
    /// Tokens generated; typically present only when `done` is true.
    pub eval_count: Option<u64>,
}
148
/// Response from `/api/tags`: the models installed locally.
#[derive(Serialize, Deserialize)]
pub struct LocalModelsResponse {
    pub models: Vec<LocalModelListing>,
}
153
/// A single locally installed model as reported by `/api/tags`.
#[derive(Serialize, Deserialize)]
pub struct LocalModelListing {
    pub name: String,
    pub modified_at: String,
    /// On-disk size, in bytes.
    pub size: u64,
    pub digest: String,
    pub details: ModelDetails,
}
162
/// Detailed model record including its Modelfile and prompt template.
#[derive(Serialize, Deserialize)]
pub struct LocalModel {
    pub modelfile: String,
    pub parameters: String,
    pub template: String,
    pub details: ModelDetails,
}
170
/// Format/architecture metadata shared by model listings and details.
#[derive(Serialize, Deserialize)]
pub struct ModelDetails {
    /// Model file format (e.g. "gguf").
    pub format: String,
    pub family: String,
    pub families: Option<Vec<String>>,
    /// Human-readable parameter count (e.g. "3.2B").
    pub parameter_size: String,
    /// Quantization label (e.g. "Q4_K_M").
    pub quantization_level: String,
}
179
/// Subset of the `/api/show` response used to detect model capabilities.
#[derive(Debug)]
pub struct ModelShow {
    /// Capability tags reported by Ollama (e.g. "completion", "tools").
    pub capabilities: Vec<String>,
    /// Context window size; a user-set `num_ctx` parameter takes precedence
    /// over the architecture's default context length.
    pub context_length: Option<u64>,
    /// Model architecture from `model_info` (e.g. "llama").
    pub architecture: Option<String>,
}
186
187impl<'de> Deserialize<'de> for ModelShow {
188 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
189 where
190 D: serde::Deserializer<'de>,
191 {
192 use serde::de::{self, MapAccess, Visitor};
193 use std::fmt;
194
195 struct ModelShowVisitor;
196
197 impl<'de> Visitor<'de> for ModelShowVisitor {
198 type Value = ModelShow;
199
200 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
201 formatter.write_str("a ModelShow object")
202 }
203
204 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
205 where
206 A: MapAccess<'de>,
207 {
208 let mut capabilities: Vec<String> = Vec::new();
209 let mut architecture: Option<String> = None;
210 let mut context_length: Option<u64> = None;
211 let mut num_ctx: Option<u64> = None;
212
213 while let Some(key) = map.next_key::<String>()? {
214 match key.as_str() {
215 "capabilities" => {
216 capabilities = map.next_value()?;
217 }
218 "parameters" => {
219 let params_str: String = map.next_value()?;
220 for line in params_str.lines() {
221 if let Some(start) = line.find("num_ctx") {
222 let value_part = &line[start + 7..];
223 if let Ok(value) = value_part.trim().parse::<u64>() {
224 num_ctx = Some(value);
225 break;
226 }
227 }
228 }
229 }
230 "model_info" => {
231 let model_info: Value = map.next_value()?;
232 if let Value::Object(obj) = model_info {
233 architecture = obj
234 .get("general.architecture")
235 .and_then(|v| v.as_str())
236 .map(String::from);
237
238 if let Some(arch) = &architecture {
239 context_length = obj
240 .get(&format!("{}.context_length", arch))
241 .and_then(|v| v.as_u64());
242 }
243 }
244 }
245 _ => {
246 let _: de::IgnoredAny = map.next_value()?;
247 }
248 }
249 }
250
251 let context_length = num_ctx.or(context_length);
252 Ok(ModelShow {
253 capabilities,
254 context_length,
255 architecture,
256 })
257 }
258 }
259
260 deserializer.deserialize_map(ModelShowVisitor)
261 }
262}
263
264impl ModelShow {
265 pub fn supports_tools(&self) -> bool {
266 // .contains expects &String, which would require an additional allocation
267 self.capabilities.iter().any(|v| v == "tools")
268 }
269
270 pub fn supports_vision(&self) -> bool {
271 self.capabilities.iter().any(|v| v == "vision")
272 }
273
274 pub fn supports_thinking(&self) -> bool {
275 self.capabilities.iter().any(|v| v == "thinking")
276 }
277}
278
/// Sends a chat request to Ollama's `/api/chat` endpoint and returns a stream
/// of newline-delimited JSON deltas, one [`ChatResponseDelta`] per line.
///
/// On a non-success HTTP status, the response body is read and folded into
/// the returned error.
pub async fn stream_chat_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
    request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
    let uri = format!("{api_url}/api/chat");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        // Only attach Authorization when an API key was provided.
        .when_some(api_key, |builder, api_key| {
            builder.header("Authorization", format!("Bearer {api_key}"))
        })
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());

        // The server streams one JSON object per line; parse each lazily as
        // the stream is consumed.
        Ok(reader
            .lines()
            .map(|line| match line {
                Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
                Err(e) => Err(e.into()),
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        anyhow::bail!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        );
    }
}
316
317pub async fn get_models(
318 client: &dyn HttpClient,
319 api_url: &str,
320 api_key: Option<&str>,
321) -> Result<Vec<LocalModelListing>> {
322 let uri = format!("{api_url}/api/tags");
323 let request = HttpRequest::builder()
324 .method(Method::GET)
325 .uri(uri)
326 .header("Accept", "application/json")
327 .when_some(api_key, |builder, api_key| {
328 builder.header("Authorization", format!("Bearer {api_key}"))
329 })
330 .body(AsyncBody::default())?;
331
332 let mut response = client.send(request).await?;
333
334 let mut body = String::new();
335 response.body_mut().read_to_string(&mut body).await?;
336
337 anyhow::ensure!(
338 response.status().is_success(),
339 "Failed to connect to Ollama API: {} {}",
340 response.status(),
341 body,
342 );
343 let response: LocalModelsResponse =
344 serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
345 Ok(response.models)
346}
347
348/// Fetch details of a model, used to determine model capabilities
349pub async fn show_model(
350 client: &dyn HttpClient,
351 api_url: &str,
352 api_key: Option<&str>,
353 model: &str,
354) -> Result<ModelShow> {
355 let uri = format!("{api_url}/api/show");
356 let request = HttpRequest::builder()
357 .method(Method::POST)
358 .uri(uri)
359 .header("Content-Type", "application/json")
360 .when_some(api_key, |builder, api_key| {
361 builder.header("Authorization", format!("Bearer {api_key}"))
362 })
363 .body(AsyncBody::from(
364 serde_json::json!({ "model": model }).to_string(),
365 ))?;
366
367 let mut response = client.send(request).await?;
368 let mut body = String::new();
369 response.body_mut().read_to_string(&mut body).await?;
370
371 anyhow::ensure!(
372 response.status().is_success(),
373 "Failed to connect to Ollama API: {} {}",
374 response.status(),
375 body,
376 );
377 let details: ModelShow = serde_json::from_str(body.as_str())?;
378 Ok(details)
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    // A complete, non-streaming chat response deserializes cleanly.
    #[test]
    fn parse_completion() {
        let response = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {
                "role": "assistant",
                "content": "Hello! How are you today?"
            },
            "done": true,
            "total_duration": 5191566416u64,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000u64
        });
        let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
    }

    // Both an intermediate delta and the final "done" delta of a stream parse.
    #[test]
    fn parse_streaming_completion() {
        let partial = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T08:52:19.385406455-07:00",
            "message": {
                "role": "assistant",
                "content": "The",
                "images": null
            },
            "done": false
        });

        let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();

        let last = serde_json::json!({
            "model": "llama3.2",
            "created_at": "2023-08-04T19:22:45.499127Z",
            "message": {
                "role": "assistant",
                "content": ""
            },
            "done": true,
            "total_duration": 4883583458u64,
            "load_duration": 1334875,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 342546000,
            "eval_count": 282,
            "eval_duration": 4535599000u64
        });

        let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
    }

    // An assistant message carrying tool_calls deserializes into the
    // Assistant variant with a non-empty tool-call list.
    #[test]
    fn parse_tool_call() {
        let response = serde_json::json!({
            "model": "llama3.2:3b",
            "created_at": "2025-04-28T20:02:02.140489Z",
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_llama3.2:3b_145155",
                        "function": {
                            "name": "weather",
                            "arguments": {
                                "city": "london",
                            }
                        }
                    }
                ]
            },
            "done_reason": "stop",
            "done": true,
            "total_duration": 2758629166u64,
            "load_duration": 1770059875,
            "prompt_eval_count": 147,
            "prompt_eval_duration": 684637583,
            "eval_count": 16,
            "eval_duration": 302561917,
        });

        let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
        match result.message {
            ChatMessage::Assistant {
                content,
                tool_calls,
                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
                assert!(tool_calls.is_some_and(|v| !v.is_empty()));
                assert!(thinking.is_none());
            }
            _ => panic!("Deserialized wrong role"),
        }
    }

    // /api/show response without a "parameters" key: capabilities,
    // architecture, and the default context length come from model_info.
    #[test]
    fn parse_show_model() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();
        assert!(result.supports_tools());
        assert!(result.capabilities.contains(&"tools".to_string()));
        assert!(result.capabilities.contains(&"completion".to_string()));

        assert_eq!(result.architecture, Some("llama".to_string()));
        assert_eq!(result.context_length, Some(131072));
    }

    // A user-set num_ctx in "parameters" overrides the architecture default.
    #[test]
    fn parse_show_model_with_num_ctx_preference() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "parameters": "num_ctx 32768\npresence_penalty 1.5\ntemperature 1\ntop_k 20\ntop_p 0.95",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();

        assert_eq!(result.context_length, Some(32768));
    }

    // When "parameters" exists but has no num_ctx line, the context length
    // falls back to model_info's architecture default.
    #[test]
    fn parse_show_model_without_num_ctx_in_parameters_fallback() {
        let response = serde_json::json!({
            "license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
            "parameters": "presence_penalty 1.5\ntemperature 1\ntop_k 20\ntop_p 0.95",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "3.2B",
                "quantization_level": "Q4_K_M"
            },
            "model_info": {
                "general.architecture": "llama",
                "general.basename": "Llama-3.2",
                "general.file_type": 15,
                "general.finetune": "Instruct",
                "general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
                "general.parameter_count": 3212749888u64,
                "general.quantization_version": 2,
                "general.size_label": "3B",
                "general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
                "general.type": "model",
                "llama.attention.head_count": 24,
                "llama.attention.head_count_kv": 8,
                "llama.attention.key_length": 128,
                "llama.attention.layer_norm_rms_epsilon": 0.00001,
                "llama.attention.value_length": 128,
                "llama.block_count": 28,
                "llama.context_length": 131072,
                "llama.embedding_length": 3072,
                "llama.feed_forward_length": 8192,
                "llama.rope.dimension_count": 128,
                "llama.rope.freq_base": 500000,
                "llama.vocab_size": 128256,
                "tokenizer.ggml.bos_token_id": 128000,
                "tokenizer.ggml.eos_token_id": 128009,
                "tokenizer.ggml.merges": null,
                "tokenizer.ggml.model": "gpt2",
                "tokenizer.ggml.pre": "llama-bpe",
                "tokenizer.ggml.token_type": null,
                "tokenizer.ggml.tokens": null
            },
            "tensors": [
                { "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
                { "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
            ],
            "capabilities": ["completion", "tools"],
            "modified_at": "2025-04-29T21:24:41.445877632+03:00"
        });

        let result: ModelShow = serde_json::from_value(response).unwrap();

        assert_eq!(result.context_length, Some(131072));
    }

    // User messages with images serialize the images field.
    #[test]
    fn serialize_chat_request_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see in this image?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("images"));
        assert!(serialized.contains(base64_image));
    }

    // With no images, the images key is omitted entirely.
    #[test]
    fn serialize_chat_request_without_images() {
        let request = ChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello, world!".to_string(),
                images: None,
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("images"));
    }

    // Round-trips the serialized request to check the exact JSON shape of
    // the images array.
    #[test]
    fn test_json_format_with_images() {
        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

        let request = ChatRequest {
            model: "llava".to_string(),
            messages: vec![ChatMessage::User {
                content: "What do you see?".to_string(),
                images: Some(vec![base64_image.to_string()]),
            }],
            stream: false,
            keep_alive: KeepAlive::default(),
            options: None,
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
        assert_eq!(message_images.len(), 1);
        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
    }

    // ChatOptions fields must be omitted (not serialized as null) when None.
    #[test]
    fn test_chat_options_serialization() {
        // When stop is None, it should not appear in JSON at all
        // This allows Ollama to use the model's default stop tokens
        let options_no_stop = ChatOptions {
            num_ctx: Some(4096),
            stop: None,
            temperature: Some(0.7),
            ..Default::default()
        };
        let serialized = serde_json::to_string(&options_no_stop).unwrap();
        assert!(
            !serialized.contains("stop"),
            "stop should not be in JSON when None"
        );
        assert!(serialized.contains("num_ctx"));
        assert!(serialized.contains("temperature"));

        // When stop has values, they should be serialized
        let options_with_stop = ChatOptions {
            stop: Some(vec!["<|eot_id|>".to_string()]),
            ..Default::default()
        };
        let serialized = serde_json::to_string(&options_with_stop).unwrap();
        assert!(serialized.contains("stop"));
        assert!(serialized.contains("<|eot_id|>"));

        // All None options should result in empty object
        let options_all_none = ChatOptions::default();
        let serialized = serde_json::to_string(&options_all_none).unwrap();
        assert_eq!(serialized, "{}");
    }

    // Explicit stop tokens are serialized inside options.stop.
    #[test]
    fn test_chat_request_with_stop_tokens() {
        let request = ChatRequest {
            model: "rnj-1:8b".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello".to_string(),
                images: None,
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: Some(ChatOptions {
                stop: Some(vec!["<|eot_id|>".to_string(), "<|end|>".to_string()]),
                ..Default::default()
            }),
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();

        let stop = parsed["options"]["stop"].as_array().unwrap();
        assert_eq!(stop.len(), 2);
        assert_eq!(stop[0].as_str().unwrap(), "<|eot_id|>");
        assert_eq!(stop[1].as_str().unwrap(), "<|end|>");
    }

    #[test]
    fn test_chat_request_without_stop_tokens_omits_field() {
        // This tests the fix for issue #47798
        // When no stop tokens are provided, the field should be omitted
        // so Ollama uses the model's default stop tokens from Modelfile
        let request = ChatRequest {
            model: "rnj-1:8b".to_string(),
            messages: vec![ChatMessage::User {
                content: "Hello".to_string(),
                images: None,
            }],
            stream: true,
            keep_alive: KeepAlive::default(),
            options: Some(ChatOptions {
                num_ctx: Some(4096),
                stop: None, // No stop tokens - should be omitted from JSON
                ..Default::default()
            }),
            think: None,
            tools: vec![],
        };

        let serialized = serde_json::to_string(&request).unwrap();

        // The key check: "stop" should not appear in the serialized JSON
        assert!(
            !serialized.contains("\"stop\""),
            "stop field should be omitted when None, got: {}",
            serialized
        );
    }
}