google_ai.rs

  1use std::mem;
  2
  3use anyhow::{Result, anyhow, bail};
  4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6use serde::{Deserialize, Deserializer, Serialize, Serializer};
  7pub use settings::ModelMode as GoogleModelMode;
  8
  9pub const API_URL: &str = "https://generativelanguage.googleapis.com";
 10
 11pub async fn stream_generate_content(
 12    client: &dyn HttpClient,
 13    api_url: &str,
 14    api_key: &str,
 15    mut request: GenerateContentRequest,
 16) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
 17    let api_key = api_key.trim();
 18    validate_generate_content_request(&request)?;
 19
 20    // The `model` field is emptied as it is provided as a path parameter.
 21    let model_id = mem::take(&mut request.model.model_id);
 22
 23    let uri =
 24        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
 25
 26    let request_builder = HttpRequest::builder()
 27        .method(Method::POST)
 28        .uri(uri)
 29        .header("Content-Type", "application/json");
 30
 31    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
 32    let mut response = client.send(request).await?;
 33    if response.status().is_success() {
 34        let reader = BufReader::new(response.into_body());
 35        Ok(reader
 36            .lines()
 37            .filter_map(|line| async move {
 38                match line {
 39                    Ok(line) => {
 40                        if let Some(line) = line.strip_prefix("data: ") {
 41                            match serde_json::from_str(line) {
 42                                Ok(response) => Some(Ok(response)),
 43                                Err(error) => Some(Err(anyhow!(format!(
 44                                    "Error parsing JSON: {error:?}\n{line:?}"
 45                                )))),
 46                            }
 47                        } else {
 48                            None
 49                        }
 50                    }
 51                    Err(error) => Some(Err(anyhow!(error))),
 52                }
 53            })
 54            .boxed())
 55    } else {
 56        let mut text = String::new();
 57        response.body_mut().read_to_string(&mut text).await?;
 58        Err(anyhow!(
 59            "error during streamGenerateContent, status code: {:?}, body: {}",
 60            response.status(),
 61            text
 62        ))
 63    }
 64}
 65
 66pub async fn count_tokens(
 67    client: &dyn HttpClient,
 68    api_url: &str,
 69    api_key: &str,
 70    request: CountTokensRequest,
 71) -> Result<CountTokensResponse> {
 72    validate_generate_content_request(&request.generate_content_request)?;
 73
 74    let uri = format!(
 75        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
 76        model_id = &request.generate_content_request.model.model_id,
 77    );
 78
 79    let request = serde_json::to_string(&request)?;
 80    let request_builder = HttpRequest::builder()
 81        .method(Method::POST)
 82        .uri(&uri)
 83        .header("Content-Type", "application/json");
 84    let http_request = request_builder.body(AsyncBody::from(request))?;
 85
 86    let mut response = client.send(http_request).await?;
 87    let mut text = String::new();
 88    response.body_mut().read_to_string(&mut text).await?;
 89    anyhow::ensure!(
 90        response.status().is_success(),
 91        "error during countTokens, status code: {:?}, body: {}",
 92        response.status(),
 93        text
 94    );
 95    Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
 96}
 97
 98pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
 99    if request.model.is_empty() {
100        bail!("Model must be specified");
101    }
102
103    if request.contents.is_empty() {
104        bail!("Request must contain at least one content item");
105    }
106
107    if let Some(user_content) = request
108        .contents
109        .iter()
110        .find(|content| content.role == Role::User)
111        && user_content.parts.is_empty()
112    {
113        bail!("User content must contain at least one part");
114    }
115
116    Ok(())
117}
118
119#[derive(Debug, Serialize, Deserialize)]
120pub enum Task {
121    #[serde(rename = "generateContent")]
122    GenerateContent,
123    #[serde(rename = "streamGenerateContent")]
124    StreamGenerateContent,
125    #[serde(rename = "countTokens")]
126    CountTokens,
127    #[serde(rename = "embedContent")]
128    EmbedContent,
129    #[serde(rename = "batchEmbedContents")]
130    BatchEmbedContents,
131}
132
133#[derive(Debug, Serialize, Deserialize)]
134#[serde(rename_all = "camelCase")]
135pub struct GenerateContentRequest {
136    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
137    pub model: ModelName,
138    pub contents: Vec<Content>,
139    #[serde(skip_serializing_if = "Option::is_none")]
140    pub system_instruction: Option<SystemInstruction>,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub generation_config: Option<GenerationConfig>,
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub safety_settings: Option<Vec<SafetySetting>>,
145    #[serde(skip_serializing_if = "Option::is_none")]
146    pub tools: Option<Vec<Tool>>,
147    #[serde(skip_serializing_if = "Option::is_none")]
148    pub tool_config: Option<ToolConfig>,
149}
150
151#[derive(Debug, Serialize, Deserialize)]
152#[serde(rename_all = "camelCase")]
153pub struct GenerateContentResponse {
154    #[serde(skip_serializing_if = "Option::is_none")]
155    pub candidates: Option<Vec<GenerateContentCandidate>>,
156    #[serde(skip_serializing_if = "Option::is_none")]
157    pub prompt_feedback: Option<PromptFeedback>,
158    #[serde(skip_serializing_if = "Option::is_none")]
159    pub usage_metadata: Option<UsageMetadata>,
160}
161
162#[derive(Debug, Serialize, Deserialize)]
163#[serde(rename_all = "camelCase")]
164pub struct GenerateContentCandidate {
165    #[serde(skip_serializing_if = "Option::is_none")]
166    pub index: Option<usize>,
167    pub content: Content,
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub finish_reason: Option<String>,
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub finish_message: Option<String>,
172    #[serde(skip_serializing_if = "Option::is_none")]
173    pub safety_ratings: Option<Vec<SafetyRating>>,
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub citation_metadata: Option<CitationMetadata>,
176}
177
178#[derive(Debug, Serialize, Deserialize)]
179#[serde(rename_all = "camelCase")]
180pub struct Content {
181    #[serde(default)]
182    pub parts: Vec<Part>,
183    pub role: Role,
184}
185
186#[derive(Debug, Serialize, Deserialize)]
187#[serde(rename_all = "camelCase")]
188pub struct SystemInstruction {
189    pub parts: Vec<Part>,
190}
191
192#[derive(Debug, PartialEq, Deserialize, Serialize)]
193#[serde(rename_all = "camelCase")]
194pub enum Role {
195    User,
196    Model,
197}
198
199#[derive(Debug, Serialize, Deserialize)]
200#[serde(untagged)]
201pub enum Part {
202    TextPart(TextPart),
203    InlineDataPart(InlineDataPart),
204    FunctionCallPart(FunctionCallPart),
205    FunctionResponsePart(FunctionResponsePart),
206    ThoughtPart(ThoughtPart),
207}
208
209#[derive(Debug, Serialize, Deserialize)]
210#[serde(rename_all = "camelCase")]
211pub struct TextPart {
212    pub text: String,
213}
214
215#[derive(Debug, Serialize, Deserialize)]
216#[serde(rename_all = "camelCase")]
217pub struct InlineDataPart {
218    pub inline_data: GenerativeContentBlob,
219}
220
221#[derive(Debug, Serialize, Deserialize)]
222#[serde(rename_all = "camelCase")]
223pub struct GenerativeContentBlob {
224    pub mime_type: String,
225    pub data: String,
226}
227
228#[derive(Debug, Serialize, Deserialize)]
229#[serde(rename_all = "camelCase")]
230pub struct FunctionCallPart {
231    pub function_call: FunctionCall,
232    /// Thought signature returned by the model for function calls.
233    /// Only present on the first function call in parallel call scenarios.
234    #[serde(skip_serializing_if = "Option::is_none")]
235    pub thought_signature: Option<String>,
236}
237
238#[derive(Debug, Serialize, Deserialize)]
239#[serde(rename_all = "camelCase")]
240pub struct FunctionResponsePart {
241    pub function_response: FunctionResponse,
242}
243
244#[derive(Debug, Serialize, Deserialize)]
245#[serde(rename_all = "camelCase")]
246pub struct ThoughtPart {
247    pub thought: bool,
248    pub thought_signature: String,
249}
250
251#[derive(Debug, Serialize, Deserialize)]
252#[serde(rename_all = "camelCase")]
253pub struct CitationSource {
254    #[serde(skip_serializing_if = "Option::is_none")]
255    pub start_index: Option<usize>,
256    #[serde(skip_serializing_if = "Option::is_none")]
257    pub end_index: Option<usize>,
258    #[serde(skip_serializing_if = "Option::is_none")]
259    pub uri: Option<String>,
260    #[serde(skip_serializing_if = "Option::is_none")]
261    pub license: Option<String>,
262}
263
264#[derive(Debug, Serialize, Deserialize)]
265#[serde(rename_all = "camelCase")]
266pub struct CitationMetadata {
267    pub citation_sources: Vec<CitationSource>,
268}
269
270#[derive(Debug, Serialize, Deserialize)]
271#[serde(rename_all = "camelCase")]
272pub struct PromptFeedback {
273    #[serde(skip_serializing_if = "Option::is_none")]
274    pub block_reason: Option<String>,
275    pub safety_ratings: Option<Vec<SafetyRating>>,
276    #[serde(skip_serializing_if = "Option::is_none")]
277    pub block_reason_message: Option<String>,
278}
279
280#[derive(Debug, Serialize, Deserialize, Default)]
281#[serde(rename_all = "camelCase")]
282pub struct UsageMetadata {
283    #[serde(skip_serializing_if = "Option::is_none")]
284    pub prompt_token_count: Option<u64>,
285    #[serde(skip_serializing_if = "Option::is_none")]
286    pub cached_content_token_count: Option<u64>,
287    #[serde(skip_serializing_if = "Option::is_none")]
288    pub candidates_token_count: Option<u64>,
289    #[serde(skip_serializing_if = "Option::is_none")]
290    pub tool_use_prompt_token_count: Option<u64>,
291    #[serde(skip_serializing_if = "Option::is_none")]
292    pub thoughts_token_count: Option<u64>,
293    #[serde(skip_serializing_if = "Option::is_none")]
294    pub total_token_count: Option<u64>,
295}
296
297#[derive(Debug, Serialize, Deserialize)]
298#[serde(rename_all = "camelCase")]
299pub struct ThinkingConfig {
300    pub thinking_budget: u32,
301}
302
303#[derive(Debug, Deserialize, Serialize)]
304#[serde(rename_all = "camelCase")]
305pub struct GenerationConfig {
306    #[serde(skip_serializing_if = "Option::is_none")]
307    pub candidate_count: Option<usize>,
308    #[serde(skip_serializing_if = "Option::is_none")]
309    pub stop_sequences: Option<Vec<String>>,
310    #[serde(skip_serializing_if = "Option::is_none")]
311    pub max_output_tokens: Option<usize>,
312    #[serde(skip_serializing_if = "Option::is_none")]
313    pub temperature: Option<f64>,
314    #[serde(skip_serializing_if = "Option::is_none")]
315    pub top_p: Option<f64>,
316    #[serde(skip_serializing_if = "Option::is_none")]
317    pub top_k: Option<usize>,
318    #[serde(skip_serializing_if = "Option::is_none")]
319    pub thinking_config: Option<ThinkingConfig>,
320}
321
322#[derive(Debug, Serialize, Deserialize)]
323#[serde(rename_all = "camelCase")]
324pub struct SafetySetting {
325    pub category: HarmCategory,
326    pub threshold: HarmBlockThreshold,
327}
328
329#[derive(Debug, Serialize, Deserialize)]
330pub enum HarmCategory {
331    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
332    Unspecified,
333    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
334    Derogatory,
335    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
336    Toxicity,
337    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
338    Violence,
339    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
340    Sexual,
341    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
342    Medical,
343    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
344    Dangerous,
345    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
346    Harassment,
347    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
348    HateSpeech,
349    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
350    SexuallyExplicit,
351    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
352    DangerousContent,
353}
354
355#[derive(Debug, Serialize, Deserialize)]
356#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
357pub enum HarmBlockThreshold {
358    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
359    Unspecified,
360    BlockLowAndAbove,
361    BlockMediumAndAbove,
362    BlockOnlyHigh,
363    BlockNone,
364}
365
366#[derive(Debug, Serialize, Deserialize)]
367#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
368pub enum HarmProbability {
369    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
370    Unspecified,
371    Negligible,
372    Low,
373    Medium,
374    High,
375}
376
377#[derive(Debug, Serialize, Deserialize)]
378#[serde(rename_all = "camelCase")]
379pub struct SafetyRating {
380    pub category: HarmCategory,
381    pub probability: HarmProbability,
382}
383
384#[derive(Debug, Serialize, Deserialize)]
385#[serde(rename_all = "camelCase")]
386pub struct CountTokensRequest {
387    pub generate_content_request: GenerateContentRequest,
388}
389
390#[derive(Debug, Serialize, Deserialize)]
391#[serde(rename_all = "camelCase")]
392pub struct CountTokensResponse {
393    pub total_tokens: u64,
394}
395
396#[derive(Debug, Serialize, Deserialize)]
397pub struct FunctionCall {
398    pub name: String,
399    pub args: serde_json::Value,
400}
401
402#[derive(Debug, Serialize, Deserialize)]
403pub struct FunctionResponse {
404    pub name: String,
405    pub response: serde_json::Value,
406}
407
408#[derive(Debug, Serialize, Deserialize)]
409#[serde(rename_all = "camelCase")]
410pub struct Tool {
411    pub function_declarations: Vec<FunctionDeclaration>,
412}
413
414#[derive(Debug, Serialize, Deserialize)]
415#[serde(rename_all = "camelCase")]
416pub struct ToolConfig {
417    pub function_calling_config: FunctionCallingConfig,
418}
419
420#[derive(Debug, Serialize, Deserialize)]
421#[serde(rename_all = "camelCase")]
422pub struct FunctionCallingConfig {
423    pub mode: FunctionCallingMode,
424    #[serde(skip_serializing_if = "Option::is_none")]
425    pub allowed_function_names: Option<Vec<String>>,
426}
427
428#[derive(Debug, Serialize, Deserialize)]
429#[serde(rename_all = "lowercase")]
430pub enum FunctionCallingMode {
431    Auto,
432    Any,
433    None,
434}
435
436#[derive(Debug, Serialize, Deserialize)]
437pub struct FunctionDeclaration {
438    pub name: String,
439    pub description: String,
440    pub parameters: serde_json::Value,
441}
442
/// A model identifier without the `models/` resource prefix
/// (e.g. `gemini-2.5-pro`). The default value has an empty id.
#[derive(Default, Debug)]
pub struct ModelName {
    pub model_id: String,
}

impl ModelName {
    /// Returns `true` when no model id has been set.
    pub fn is_empty(&self) -> bool {
        str::is_empty(&self.model_id)
    }
}
453
454const MODEL_NAME_PREFIX: &str = "models/";
455
456impl Serialize for ModelName {
457    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
458    where
459        S: Serializer,
460    {
461        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
462    }
463}
464
465impl<'de> Deserialize<'de> for ModelName {
466    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
467    where
468        D: Deserializer<'de>,
469    {
470        let string = String::deserialize(deserializer)?;
471        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
472            Ok(Self {
473                model_id: id.to_string(),
474            })
475        } else {
476            Err(serde::de::Error::custom(format!(
477                "Expected model name to begin with {}, got: {}",
478                MODEL_NAME_PREFIX, string
479            )))
480        }
481    }
482}
483
484#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
485#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
486pub enum Model {
487    #[serde(
488        rename = "gemini-2.5-flash-lite",
489        alias = "gemini-2.5-flash-lite-preview-06-17",
490        alias = "gemini-2.0-flash-lite-preview"
491    )]
492    Gemini25FlashLite,
493    #[serde(
494        rename = "gemini-2.5-flash",
495        alias = "gemini-2.0-flash-thinking-exp",
496        alias = "gemini-2.5-flash-preview-04-17",
497        alias = "gemini-2.5-flash-preview-05-20",
498        alias = "gemini-2.5-flash-preview-latest",
499        alias = "gemini-2.0-flash"
500    )]
501    #[default]
502    Gemini25Flash,
503    #[serde(
504        rename = "gemini-2.5-pro",
505        alias = "gemini-2.0-pro-exp",
506        alias = "gemini-2.5-pro-preview-latest",
507        alias = "gemini-2.5-pro-exp-03-25",
508        alias = "gemini-2.5-pro-preview-03-25",
509        alias = "gemini-2.5-pro-preview-05-06",
510        alias = "gemini-2.5-pro-preview-06-05"
511    )]
512    Gemini25Pro,
513    #[serde(rename = "gemini-3-pro-preview")]
514    Gemini3Pro,
515    #[serde(rename = "gemini-3-flash-preview")]
516    Gemini3Flash,
517    #[serde(rename = "custom")]
518    Custom {
519        name: String,
520        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
521        display_name: Option<String>,
522        max_tokens: u64,
523        #[serde(default)]
524        mode: GoogleModelMode,
525    },
526}
527
528impl Model {
529    pub fn default_fast() -> Self {
530        Self::Gemini25FlashLite
531    }
532
533    pub fn id(&self) -> &str {
534        match self {
535            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
536            Self::Gemini25Flash => "gemini-2.5-flash",
537            Self::Gemini25Pro => "gemini-2.5-pro",
538            Self::Gemini3Pro => "gemini-3-pro-preview",
539            Self::Gemini3Flash => "gemini-3-flash-preview",
540            Self::Custom { name, .. } => name,
541        }
542    }
543    pub fn request_id(&self) -> &str {
544        match self {
545            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
546            Self::Gemini25Flash => "gemini-2.5-flash",
547            Self::Gemini25Pro => "gemini-2.5-pro",
548            Self::Gemini3Pro => "gemini-3-pro-preview",
549            Self::Gemini3Flash => "gemini-3-flash-preview",
550            Self::Custom { name, .. } => name,
551        }
552    }
553
554    pub fn display_name(&self) -> &str {
555        match self {
556            Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite",
557            Self::Gemini25Flash => "Gemini 2.5 Flash",
558            Self::Gemini25Pro => "Gemini 2.5 Pro",
559            Self::Gemini3Pro => "Gemini 3 Pro",
560            Self::Gemini3Flash => "Gemini 3 Flash",
561            Self::Custom {
562                name, display_name, ..
563            } => display_name.as_ref().unwrap_or(name),
564        }
565    }
566
567    pub fn max_token_count(&self) -> u64 {
568        match self {
569            Self::Gemini25FlashLite => 1_048_576,
570            Self::Gemini25Flash => 1_048_576,
571            Self::Gemini25Pro => 1_048_576,
572            Self::Gemini3Pro => 1_048_576,
573            Self::Gemini3Flash => 1_048_576,
574            Self::Custom { max_tokens, .. } => *max_tokens,
575        }
576    }
577
578    pub fn max_output_tokens(&self) -> Option<u64> {
579        match self {
580            Model::Gemini25FlashLite => Some(65_536),
581            Model::Gemini25Flash => Some(65_536),
582            Model::Gemini25Pro => Some(65_536),
583            Model::Gemini3Pro => Some(65_536),
584            Model::Gemini3Flash => Some(65_536),
585            Model::Custom { .. } => None,
586        }
587    }
588
589    pub fn supports_tools(&self) -> bool {
590        true
591    }
592
593    pub fn supports_images(&self) -> bool {
594        true
595    }
596
597    pub fn mode(&self) -> GoogleModelMode {
598        match self {
599            Self::Gemini25FlashLite
600            | Self::Gemini25Flash
601            | Self::Gemini25Pro
602            | Self::Gemini3Pro => {
603                GoogleModelMode::Thinking {
604                    // By default these models are set to "auto", so we preserve that behavior
605                    // but indicate they are capable of thinking mode
606                    budget_tokens: None,
607                }
608            }
609            Self::Gemini3Flash => GoogleModelMode::Default,
610            Self::Custom { mode, .. } => *mode,
611        }
612    }
613}
614
615impl std::fmt::Display for Model {
616    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
617        write!(f, "{}", self.id())
618    }
619}
620
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    /// Builds a `FunctionCallPart` for `test_function` with the given arguments
    /// and optional thought signature.
    fn call_part(args: Value, thought_signature: Option<&str>) -> FunctionCallPart {
        FunctionCallPart {
            function_call: FunctionCall {
                name: "test_function".to_string(),
                args,
            },
            thought_signature: thought_signature.map(ToString::to_string),
        }
    }

    #[test]
    fn test_function_call_part_with_signature_serializes_correctly() {
        let part = call_part(json!({"arg": "value"}), Some("test_signature"));
        let value = serde_json::to_value(&part).unwrap();

        assert_eq!(value["functionCall"]["name"], "test_function");
        assert_eq!(value["functionCall"]["args"]["arg"], "value");
        assert_eq!(value["thoughtSignature"], "test_signature");
    }

    #[test]
    fn test_function_call_part_without_signature_omits_field() {
        let part = call_part(json!({"arg": "value"}), None);
        let value = serde_json::to_value(&part).unwrap();

        assert_eq!(value["functionCall"]["name"], "test_function");
        assert_eq!(value["functionCall"]["args"]["arg"], "value");
        // `skip_serializing_if` must drop the key entirely rather than emit `null`.
        assert!(value.get("thoughtSignature").is_none());
    }

    #[test]
    fn test_function_call_part_deserializes_with_signature() {
        let part: FunctionCallPart = serde_json::from_value(json!({
            "functionCall": {"name": "test_function", "args": {"arg": "value"}},
            "thoughtSignature": "test_signature"
        }))
        .unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert_eq!(part.thought_signature.as_deref(), Some("test_signature"));
    }

    #[test]
    fn test_function_call_part_deserializes_without_signature() {
        let part: FunctionCallPart = serde_json::from_value(json!({
            "functionCall": {"name": "test_function", "args": {"arg": "value"}}
        }))
        .unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert!(part.thought_signature.is_none());
    }

    #[test]
    fn test_function_call_part_round_trip() {
        let original = call_part(
            json!({"arg": "value", "nested": {"key": "val"}}),
            Some("round_trip_signature"),
        );

        let restored: FunctionCallPart =
            serde_json::from_value(serde_json::to_value(&original).unwrap()).unwrap();

        assert_eq!(restored.function_call.name, original.function_call.name);
        assert_eq!(restored.function_call.args, original.function_call.args);
        assert_eq!(restored.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_function_call_part_with_empty_signature_serializes() {
        let part = call_part(json!({"arg": "value"}), Some(""));
        let value = serde_json::to_value(&part).unwrap();

        // Empty string should still be serialized (normalization happens at a higher level)
        assert_eq!(value["thoughtSignature"], "");
    }
}