1use std::mem;
2
3use anyhow::{Result, anyhow, bail};
4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7pub use settings::ModelMode as GoogleModelMode;
8
/// Base URL (scheme and host only) of Google's Generative Language API.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
10
11pub async fn stream_generate_content(
12 client: &dyn HttpClient,
13 api_url: &str,
14 api_key: &str,
15 mut request: GenerateContentRequest,
16) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
17 let api_key = api_key.trim();
18 validate_generate_content_request(&request)?;
19
20 // The `model` field is emptied as it is provided as a path parameter.
21 let model_id = mem::take(&mut request.model.model_id);
22
23 let uri =
24 format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
25
26 let request_builder = HttpRequest::builder()
27 .method(Method::POST)
28 .uri(uri)
29 .header("Content-Type", "application/json");
30
31 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
32 let mut response = client.send(request).await?;
33 if response.status().is_success() {
34 let reader = BufReader::new(response.into_body());
35 Ok(reader
36 .lines()
37 .filter_map(|line| async move {
38 match line {
39 Ok(line) => {
40 if let Some(line) = line.strip_prefix("data: ") {
41 match serde_json::from_str(line) {
42 Ok(response) => Some(Ok(response)),
43 Err(error) => Some(Err(anyhow!(format!(
44 "Error parsing JSON: {error:?}\n{line:?}"
45 )))),
46 }
47 } else {
48 None
49 }
50 }
51 Err(error) => Some(Err(anyhow!(error))),
52 }
53 })
54 .boxed())
55 } else {
56 let mut text = String::new();
57 response.body_mut().read_to_string(&mut text).await?;
58 Err(anyhow!(
59 "error during streamGenerateContent, status code: {:?}, body: {}",
60 response.status(),
61 text
62 ))
63 }
64}
65
66pub async fn count_tokens(
67 client: &dyn HttpClient,
68 api_url: &str,
69 api_key: &str,
70 request: CountTokensRequest,
71) -> Result<CountTokensResponse> {
72 validate_generate_content_request(&request.generate_content_request)?;
73
74 let uri = format!(
75 "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
76 model_id = &request.generate_content_request.model.model_id,
77 );
78
79 let request = serde_json::to_string(&request)?;
80 let request_builder = HttpRequest::builder()
81 .method(Method::POST)
82 .uri(&uri)
83 .header("Content-Type", "application/json");
84 let http_request = request_builder.body(AsyncBody::from(request))?;
85
86 let mut response = client.send(http_request).await?;
87 let mut text = String::new();
88 response.body_mut().read_to_string(&mut text).await?;
89 anyhow::ensure!(
90 response.status().is_success(),
91 "error during countTokens, status code: {:?}, body: {}",
92 response.status(),
93 text
94 );
95 Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
96}
97
98pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
99 if request.model.is_empty() {
100 bail!("Model must be specified");
101 }
102
103 if request.contents.is_empty() {
104 bail!("Request must contain at least one content item");
105 }
106
107 if let Some(user_content) = request
108 .contents
109 .iter()
110 .find(|content| content.role == Role::User)
111 && user_content.parts.is_empty()
112 {
113 bail!("User content must contain at least one part");
114 }
115
116 Ok(())
117}
118
119#[derive(Debug, Serialize, Deserialize)]
120pub enum Task {
121 #[serde(rename = "generateContent")]
122 GenerateContent,
123 #[serde(rename = "streamGenerateContent")]
124 StreamGenerateContent,
125 #[serde(rename = "countTokens")]
126 CountTokens,
127 #[serde(rename = "embedContent")]
128 EmbedContent,
129 #[serde(rename = "batchEmbedContents")]
130 BatchEmbedContents,
131}
132
133#[derive(Debug, Serialize, Deserialize)]
134#[serde(rename_all = "camelCase")]
135pub struct GenerateContentRequest {
136 #[serde(default, skip_serializing_if = "ModelName::is_empty")]
137 pub model: ModelName,
138 pub contents: Vec<Content>,
139 #[serde(skip_serializing_if = "Option::is_none")]
140 pub system_instruction: Option<SystemInstruction>,
141 #[serde(skip_serializing_if = "Option::is_none")]
142 pub generation_config: Option<GenerationConfig>,
143 #[serde(skip_serializing_if = "Option::is_none")]
144 pub safety_settings: Option<Vec<SafetySetting>>,
145 #[serde(skip_serializing_if = "Option::is_none")]
146 pub tools: Option<Vec<Tool>>,
147 #[serde(skip_serializing_if = "Option::is_none")]
148 pub tool_config: Option<ToolConfig>,
149}
150
151#[derive(Debug, Serialize, Deserialize)]
152#[serde(rename_all = "camelCase")]
153pub struct GenerateContentResponse {
154 #[serde(skip_serializing_if = "Option::is_none")]
155 pub candidates: Option<Vec<GenerateContentCandidate>>,
156 #[serde(skip_serializing_if = "Option::is_none")]
157 pub prompt_feedback: Option<PromptFeedback>,
158 #[serde(skip_serializing_if = "Option::is_none")]
159 pub usage_metadata: Option<UsageMetadata>,
160}
161
162#[derive(Debug, Serialize, Deserialize)]
163#[serde(rename_all = "camelCase")]
164pub struct GenerateContentCandidate {
165 #[serde(skip_serializing_if = "Option::is_none")]
166 pub index: Option<usize>,
167 pub content: Content,
168 #[serde(skip_serializing_if = "Option::is_none")]
169 pub finish_reason: Option<String>,
170 #[serde(skip_serializing_if = "Option::is_none")]
171 pub finish_message: Option<String>,
172 #[serde(skip_serializing_if = "Option::is_none")]
173 pub safety_ratings: Option<Vec<SafetyRating>>,
174 #[serde(skip_serializing_if = "Option::is_none")]
175 pub citation_metadata: Option<CitationMetadata>,
176}
177
178#[derive(Debug, Serialize, Deserialize)]
179#[serde(rename_all = "camelCase")]
180pub struct Content {
181 #[serde(default)]
182 pub parts: Vec<Part>,
183 pub role: Role,
184}
185
186#[derive(Debug, Serialize, Deserialize)]
187#[serde(rename_all = "camelCase")]
188pub struct SystemInstruction {
189 pub parts: Vec<Part>,
190}
191
192#[derive(Debug, PartialEq, Deserialize, Serialize)]
193#[serde(rename_all = "camelCase")]
194pub enum Role {
195 User,
196 Model,
197}
198
199#[derive(Debug, Serialize, Deserialize)]
200#[serde(untagged)]
201pub enum Part {
202 TextPart(TextPart),
203 InlineDataPart(InlineDataPart),
204 FunctionCallPart(FunctionCallPart),
205 FunctionResponsePart(FunctionResponsePart),
206 ThoughtPart(ThoughtPart),
207}
208
209#[derive(Debug, Serialize, Deserialize)]
210#[serde(rename_all = "camelCase")]
211pub struct TextPart {
212 pub text: String,
213}
214
215#[derive(Debug, Serialize, Deserialize)]
216#[serde(rename_all = "camelCase")]
217pub struct InlineDataPart {
218 pub inline_data: GenerativeContentBlob,
219}
220
221#[derive(Debug, Serialize, Deserialize)]
222#[serde(rename_all = "camelCase")]
223pub struct GenerativeContentBlob {
224 pub mime_type: String,
225 pub data: String,
226}
227
228#[derive(Debug, Serialize, Deserialize)]
229#[serde(rename_all = "camelCase")]
230pub struct FunctionCallPart {
231 pub function_call: FunctionCall,
232 /// Thought signature returned by the model for function calls.
233 /// Only present on the first function call in parallel call scenarios.
234 #[serde(skip_serializing_if = "Option::is_none")]
235 pub thought_signature: Option<String>,
236}
237
238#[derive(Debug, Serialize, Deserialize)]
239#[serde(rename_all = "camelCase")]
240pub struct FunctionResponsePart {
241 pub function_response: FunctionResponse,
242}
243
244#[derive(Debug, Serialize, Deserialize)]
245#[serde(rename_all = "camelCase")]
246pub struct ThoughtPart {
247 pub thought: bool,
248 pub thought_signature: String,
249}
250
251#[derive(Debug, Serialize, Deserialize)]
252#[serde(rename_all = "camelCase")]
253pub struct CitationSource {
254 #[serde(skip_serializing_if = "Option::is_none")]
255 pub start_index: Option<usize>,
256 #[serde(skip_serializing_if = "Option::is_none")]
257 pub end_index: Option<usize>,
258 #[serde(skip_serializing_if = "Option::is_none")]
259 pub uri: Option<String>,
260 #[serde(skip_serializing_if = "Option::is_none")]
261 pub license: Option<String>,
262}
263
264#[derive(Debug, Serialize, Deserialize)]
265#[serde(rename_all = "camelCase")]
266pub struct CitationMetadata {
267 pub citation_sources: Vec<CitationSource>,
268}
269
270#[derive(Debug, Serialize, Deserialize)]
271#[serde(rename_all = "camelCase")]
272pub struct PromptFeedback {
273 #[serde(skip_serializing_if = "Option::is_none")]
274 pub block_reason: Option<String>,
275 pub safety_ratings: Option<Vec<SafetyRating>>,
276 #[serde(skip_serializing_if = "Option::is_none")]
277 pub block_reason_message: Option<String>,
278}
279
280#[derive(Debug, Serialize, Deserialize, Default)]
281#[serde(rename_all = "camelCase")]
282pub struct UsageMetadata {
283 #[serde(skip_serializing_if = "Option::is_none")]
284 pub prompt_token_count: Option<u64>,
285 #[serde(skip_serializing_if = "Option::is_none")]
286 pub cached_content_token_count: Option<u64>,
287 #[serde(skip_serializing_if = "Option::is_none")]
288 pub candidates_token_count: Option<u64>,
289 #[serde(skip_serializing_if = "Option::is_none")]
290 pub tool_use_prompt_token_count: Option<u64>,
291 #[serde(skip_serializing_if = "Option::is_none")]
292 pub thoughts_token_count: Option<u64>,
293 #[serde(skip_serializing_if = "Option::is_none")]
294 pub total_token_count: Option<u64>,
295}
296
297#[derive(Debug, Serialize, Deserialize)]
298#[serde(rename_all = "camelCase")]
299pub struct ThinkingConfig {
300 pub thinking_budget: u32,
301}
302
303#[derive(Debug, Deserialize, Serialize)]
304#[serde(rename_all = "camelCase")]
305pub struct GenerationConfig {
306 #[serde(skip_serializing_if = "Option::is_none")]
307 pub candidate_count: Option<usize>,
308 #[serde(skip_serializing_if = "Option::is_none")]
309 pub stop_sequences: Option<Vec<String>>,
310 #[serde(skip_serializing_if = "Option::is_none")]
311 pub max_output_tokens: Option<usize>,
312 #[serde(skip_serializing_if = "Option::is_none")]
313 pub temperature: Option<f64>,
314 #[serde(skip_serializing_if = "Option::is_none")]
315 pub top_p: Option<f64>,
316 #[serde(skip_serializing_if = "Option::is_none")]
317 pub top_k: Option<usize>,
318 #[serde(skip_serializing_if = "Option::is_none")]
319 pub thinking_config: Option<ThinkingConfig>,
320}
321
322#[derive(Debug, Serialize, Deserialize)]
323#[serde(rename_all = "camelCase")]
324pub struct SafetySetting {
325 pub category: HarmCategory,
326 pub threshold: HarmBlockThreshold,
327}
328
329#[derive(Debug, Serialize, Deserialize)]
330pub enum HarmCategory {
331 #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
332 Unspecified,
333 #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
334 Derogatory,
335 #[serde(rename = "HARM_CATEGORY_TOXICITY")]
336 Toxicity,
337 #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
338 Violence,
339 #[serde(rename = "HARM_CATEGORY_SEXUAL")]
340 Sexual,
341 #[serde(rename = "HARM_CATEGORY_MEDICAL")]
342 Medical,
343 #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
344 Dangerous,
345 #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
346 Harassment,
347 #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
348 HateSpeech,
349 #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
350 SexuallyExplicit,
351 #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
352 DangerousContent,
353}
354
355#[derive(Debug, Serialize, Deserialize)]
356#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
357pub enum HarmBlockThreshold {
358 #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
359 Unspecified,
360 BlockLowAndAbove,
361 BlockMediumAndAbove,
362 BlockOnlyHigh,
363 BlockNone,
364}
365
366#[derive(Debug, Serialize, Deserialize)]
367#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
368pub enum HarmProbability {
369 #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
370 Unspecified,
371 Negligible,
372 Low,
373 Medium,
374 High,
375}
376
377#[derive(Debug, Serialize, Deserialize)]
378#[serde(rename_all = "camelCase")]
379pub struct SafetyRating {
380 pub category: HarmCategory,
381 pub probability: HarmProbability,
382}
383
384#[derive(Debug, Serialize, Deserialize)]
385#[serde(rename_all = "camelCase")]
386pub struct CountTokensRequest {
387 pub generate_content_request: GenerateContentRequest,
388}
389
390#[derive(Debug, Serialize, Deserialize)]
391#[serde(rename_all = "camelCase")]
392pub struct CountTokensResponse {
393 pub total_tokens: u64,
394}
395
396#[derive(Debug, Serialize, Deserialize)]
397pub struct FunctionCall {
398 pub name: String,
399 pub args: serde_json::Value,
400}
401
402#[derive(Debug, Serialize, Deserialize)]
403pub struct FunctionResponse {
404 pub name: String,
405 pub response: serde_json::Value,
406}
407
408#[derive(Debug, Serialize, Deserialize)]
409#[serde(rename_all = "camelCase")]
410pub struct Tool {
411 pub function_declarations: Vec<FunctionDeclaration>,
412}
413
414#[derive(Debug, Serialize, Deserialize)]
415#[serde(rename_all = "camelCase")]
416pub struct ToolConfig {
417 pub function_calling_config: FunctionCallingConfig,
418}
419
420#[derive(Debug, Serialize, Deserialize)]
421#[serde(rename_all = "camelCase")]
422pub struct FunctionCallingConfig {
423 pub mode: FunctionCallingMode,
424 #[serde(skip_serializing_if = "Option::is_none")]
425 pub allowed_function_names: Option<Vec<String>>,
426}
427
428#[derive(Debug, Serialize, Deserialize)]
429#[serde(rename_all = "lowercase")]
430pub enum FunctionCallingMode {
431 Auto,
432 Any,
433 None,
434}
435
436#[derive(Debug, Serialize, Deserialize)]
437pub struct FunctionDeclaration {
438 pub name: String,
439 pub description: String,
440 pub parameters: serde_json::Value,
441}
442
/// A model identifier stored without the `models/` resource prefix.
///
/// An empty `model_id` means no model is set.
#[derive(Debug, Default)]
pub struct ModelName {
    pub model_id: String,
}

impl ModelName {
    /// Returns `true` when no model id is set.
    pub fn is_empty(&self) -> bool {
        let Self { model_id } = self;
        model_id.is_empty()
    }
}
453
454const MODEL_NAME_PREFIX: &str = "models/";
455
456impl Serialize for ModelName {
457 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
458 where
459 S: Serializer,
460 {
461 serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
462 }
463}
464
465impl<'de> Deserialize<'de> for ModelName {
466 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
467 where
468 D: Deserializer<'de>,
469 {
470 let string = String::deserialize(deserializer)?;
471 if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
472 Ok(Self {
473 model_id: id.to_string(),
474 })
475 } else {
476 Err(serde::de::Error::custom(format!(
477 "Expected model name to begin with {}, got: {}",
478 MODEL_NAME_PREFIX, string
479 )))
480 }
481 }
482}
483
484#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
485#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
486pub enum Model {
487 #[serde(rename = "gemini-1.5-pro")]
488 Gemini15Pro,
489 #[serde(rename = "gemini-1.5-flash-8b")]
490 Gemini15Flash8b,
491 #[serde(rename = "gemini-1.5-flash")]
492 Gemini15Flash,
493 #[serde(
494 rename = "gemini-2.0-flash-lite",
495 alias = "gemini-2.0-flash-lite-preview"
496 )]
497 Gemini20FlashLite,
498 #[serde(rename = "gemini-2.0-flash")]
499 Gemini20Flash,
500 #[serde(
501 rename = "gemini-2.5-flash-lite-preview",
502 alias = "gemini-2.5-flash-lite-preview-06-17"
503 )]
504 Gemini25FlashLitePreview,
505 #[serde(
506 rename = "gemini-2.5-flash",
507 alias = "gemini-2.0-flash-thinking-exp",
508 alias = "gemini-2.5-flash-preview-04-17",
509 alias = "gemini-2.5-flash-preview-05-20",
510 alias = "gemini-2.5-flash-preview-latest"
511 )]
512 #[default]
513 Gemini25Flash,
514 #[serde(
515 rename = "gemini-2.5-pro",
516 alias = "gemini-2.0-pro-exp",
517 alias = "gemini-2.5-pro-preview-latest",
518 alias = "gemini-2.5-pro-exp-03-25",
519 alias = "gemini-2.5-pro-preview-03-25",
520 alias = "gemini-2.5-pro-preview-05-06",
521 alias = "gemini-2.5-pro-preview-06-05"
522 )]
523 Gemini25Pro,
524 #[serde(rename = "gemini-3-pro-preview")]
525 Gemini3ProPreview,
526 #[serde(rename = "custom")]
527 Custom {
528 name: String,
529 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
530 display_name: Option<String>,
531 max_tokens: u64,
532 #[serde(default)]
533 mode: GoogleModelMode,
534 },
535}
536
537impl Model {
538 pub fn default_fast() -> Self {
539 Self::Gemini20FlashLite
540 }
541
542 pub fn id(&self) -> &str {
543 match self {
544 Self::Gemini15Pro => "gemini-1.5-pro",
545 Self::Gemini15Flash8b => "gemini-1.5-flash-8b",
546 Self::Gemini15Flash => "gemini-1.5-flash",
547 Self::Gemini20FlashLite => "gemini-2.0-flash-lite",
548 Self::Gemini20Flash => "gemini-2.0-flash",
549 Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview",
550 Self::Gemini25Flash => "gemini-2.5-flash",
551 Self::Gemini25Pro => "gemini-2.5-pro",
552 Self::Gemini3ProPreview => "gemini-3-pro-preview",
553 Self::Custom { name, .. } => name,
554 }
555 }
556 pub fn request_id(&self) -> &str {
557 match self {
558 Self::Gemini15Pro => "gemini-1.5-pro",
559 Self::Gemini15Flash8b => "gemini-1.5-flash-8b",
560 Self::Gemini15Flash => "gemini-1.5-flash",
561 Self::Gemini20FlashLite => "gemini-2.0-flash-lite",
562 Self::Gemini20Flash => "gemini-2.0-flash",
563 Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview-06-17",
564 Self::Gemini25Flash => "gemini-2.5-flash",
565 Self::Gemini25Pro => "gemini-2.5-pro",
566 Self::Gemini3ProPreview => "gemini-3-pro-preview",
567 Self::Custom { name, .. } => name,
568 }
569 }
570
571 pub fn display_name(&self) -> &str {
572 match self {
573 Self::Gemini15Pro => "Gemini 1.5 Pro",
574 Self::Gemini15Flash8b => "Gemini 1.5 Flash-8b",
575 Self::Gemini15Flash => "Gemini 1.5 Flash",
576 Self::Gemini20FlashLite => "Gemini 2.0 Flash-Lite",
577 Self::Gemini20Flash => "Gemini 2.0 Flash",
578 Self::Gemini25FlashLitePreview => "Gemini 2.5 Flash-Lite Preview",
579 Self::Gemini25Flash => "Gemini 2.5 Flash",
580 Self::Gemini25Pro => "Gemini 2.5 Pro",
581 Self::Gemini3ProPreview => "Gemini 3 Pro",
582 Self::Custom {
583 name, display_name, ..
584 } => display_name.as_ref().unwrap_or(name),
585 }
586 }
587
588 pub fn max_token_count(&self) -> u64 {
589 match self {
590 Self::Gemini15Pro => 2_097_152,
591 Self::Gemini15Flash8b => 1_048_576,
592 Self::Gemini15Flash => 1_048_576,
593 Self::Gemini20FlashLite => 1_048_576,
594 Self::Gemini20Flash => 1_048_576,
595 Self::Gemini25FlashLitePreview => 1_000_000,
596 Self::Gemini25Flash => 1_048_576,
597 Self::Gemini25Pro => 1_048_576,
598 Self::Gemini3ProPreview => 1_048_576,
599 Self::Custom { max_tokens, .. } => *max_tokens,
600 }
601 }
602
603 pub fn max_output_tokens(&self) -> Option<u64> {
604 match self {
605 Model::Gemini15Pro => Some(8_192),
606 Model::Gemini15Flash8b => Some(8_192),
607 Model::Gemini15Flash => Some(8_192),
608 Model::Gemini20FlashLite => Some(8_192),
609 Model::Gemini20Flash => Some(8_192),
610 Model::Gemini25FlashLitePreview => Some(64_000),
611 Model::Gemini25Flash => Some(65_536),
612 Model::Gemini25Pro => Some(65_536),
613 Model::Gemini3ProPreview => Some(65_536),
614 Model::Custom { .. } => None,
615 }
616 }
617
618 pub fn supports_tools(&self) -> bool {
619 true
620 }
621
622 pub fn supports_images(&self) -> bool {
623 true
624 }
625
626 pub fn mode(&self) -> GoogleModelMode {
627 match self {
628 Self::Gemini15Pro
629 | Self::Gemini15Flash8b
630 | Self::Gemini15Flash
631 | Self::Gemini20FlashLite
632 | Self::Gemini20Flash => GoogleModelMode::Default,
633 Self::Gemini25FlashLitePreview
634 | Self::Gemini25Flash
635 | Self::Gemini25Pro
636 | Self::Gemini3ProPreview => {
637 GoogleModelMode::Thinking {
638 // By default these models are set to "auto", so we preserve that behavior
639 // but indicate they are capable of thinking mode
640 budget_tokens: None,
641 }
642 }
643 Self::Custom { mode, .. } => *mode,
644 }
645 }
646}
647
648impl std::fmt::Display for Model {
649 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
650 write!(f, "{}", self.id())
651 }
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    /// Builds a `FunctionCallPart` that calls `test_function` with `args` and the given
    /// optional thought signature.
    fn make_part(args: Value, thought_signature: Option<&str>) -> FunctionCallPart {
        FunctionCallPart {
            function_call: FunctionCall {
                name: "test_function".to_string(),
                args,
            },
            thought_signature: thought_signature.map(String::from),
        }
    }

    #[test]
    fn test_function_call_part_with_signature_serializes_correctly() {
        let part = make_part(json!({"arg": "value"}), Some("test_signature"));
        let value = serde_json::to_value(&part).unwrap();

        assert_eq!(value["functionCall"]["name"], "test_function");
        assert_eq!(value["functionCall"]["args"]["arg"], "value");
        assert_eq!(value["thoughtSignature"], "test_signature");
    }

    #[test]
    fn test_function_call_part_without_signature_omits_field() {
        let part = make_part(json!({"arg": "value"}), None);
        let value = serde_json::to_value(&part).unwrap();

        assert_eq!(value["functionCall"]["name"], "test_function");
        assert_eq!(value["functionCall"]["args"]["arg"], "value");
        // A `None` signature must not appear in the JSON at all.
        assert!(value.get("thoughtSignature").is_none());
    }

    #[test]
    fn test_function_call_part_deserializes_with_signature() {
        let part: FunctionCallPart = serde_json::from_value(json!({
            "functionCall": { "name": "test_function", "args": {"arg": "value"} },
            "thoughtSignature": "test_signature"
        }))
        .unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert_eq!(part.thought_signature.as_deref(), Some("test_signature"));
    }

    #[test]
    fn test_function_call_part_deserializes_without_signature() {
        let part: FunctionCallPart = serde_json::from_value(json!({
            "functionCall": { "name": "test_function", "args": {"arg": "value"} }
        }))
        .unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert!(part.thought_signature.is_none());
    }

    #[test]
    fn test_function_call_part_round_trip() {
        let original = make_part(
            json!({"arg": "value", "nested": {"key": "val"}}),
            Some("round_trip_signature"),
        );
        let round_tripped: FunctionCallPart =
            serde_json::from_value(serde_json::to_value(&original).unwrap()).unwrap();

        assert_eq!(round_tripped.function_call.name, original.function_call.name);
        assert_eq!(round_tripped.function_call.args, original.function_call.args);
        assert_eq!(round_tripped.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_function_call_part_with_empty_signature_serializes() {
        let part = make_part(json!({"arg": "value"}), Some(""));
        let value = serde_json::to_value(&part).unwrap();

        // Empty string should still be serialized (normalization happens at a higher level)
        assert_eq!(value["thoughtSignature"], "");
    }
}