google_ai.rs

  1use std::mem;
  2
  3use anyhow::{Result, anyhow, bail};
  4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6pub use language_model_core::ModelMode as GoogleModelMode;
  7use serde::{Deserialize, Deserializer, Serialize, Serializer};
  8pub mod completion;
  9
 10pub const API_URL: &str = "https://generativelanguage.googleapis.com";
 11
 12pub async fn stream_generate_content(
 13    client: &dyn HttpClient,
 14    api_url: &str,
 15    api_key: &str,
 16    mut request: GenerateContentRequest,
 17) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
 18    let api_key = api_key.trim();
 19    validate_generate_content_request(&request)?;
 20
 21    // The `model` field is emptied as it is provided as a path parameter.
 22    let model_id = mem::take(&mut request.model.model_id);
 23
 24    let uri =
 25        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
 26
 27    let request_builder = HttpRequest::builder()
 28        .method(Method::POST)
 29        .uri(uri)
 30        .header("Content-Type", "application/json");
 31
 32    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
 33    let mut response = client.send(request).await?;
 34    if response.status().is_success() {
 35        let reader = BufReader::new(response.into_body());
 36        Ok(reader
 37            .lines()
 38            .filter_map(|line| async move {
 39                match line {
 40                    Ok(line) => {
 41                        if let Some(line) = line.strip_prefix("data: ") {
 42                            match serde_json::from_str(line) {
 43                                Ok(response) => Some(Ok(response)),
 44                                Err(error) => Some(Err(anyhow!(format!(
 45                                    "Error parsing JSON: {error:?}\n{line:?}"
 46                                )))),
 47                            }
 48                        } else {
 49                            None
 50                        }
 51                    }
 52                    Err(error) => Some(Err(anyhow!(error))),
 53                }
 54            })
 55            .boxed())
 56    } else {
 57        let mut text = String::new();
 58        response.body_mut().read_to_string(&mut text).await?;
 59        Err(anyhow!(
 60            "error during streamGenerateContent, status code: {:?}, body: {}",
 61            response.status(),
 62            text
 63        ))
 64    }
 65}
 66
 67pub async fn count_tokens(
 68    client: &dyn HttpClient,
 69    api_url: &str,
 70    api_key: &str,
 71    request: CountTokensRequest,
 72) -> Result<CountTokensResponse> {
 73    validate_generate_content_request(&request.generate_content_request)?;
 74
 75    let uri = format!(
 76        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
 77        model_id = &request.generate_content_request.model.model_id,
 78    );
 79
 80    let request = serde_json::to_string(&request)?;
 81    let request_builder = HttpRequest::builder()
 82        .method(Method::POST)
 83        .uri(&uri)
 84        .header("Content-Type", "application/json");
 85    let http_request = request_builder.body(AsyncBody::from(request))?;
 86
 87    let mut response = client.send(http_request).await?;
 88    let mut text = String::new();
 89    response.body_mut().read_to_string(&mut text).await?;
 90    anyhow::ensure!(
 91        response.status().is_success(),
 92        "error during countTokens, status code: {:?}, body: {}",
 93        response.status(),
 94        text
 95    );
 96    Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
 97}
 98
 99pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
100    if request.model.is_empty() {
101        bail!("Model must be specified");
102    }
103
104    if request.contents.is_empty() {
105        bail!("Request must contain at least one content item");
106    }
107
108    if let Some(user_content) = request
109        .contents
110        .iter()
111        .find(|content| content.role == Role::User)
112        && user_content.parts.is_empty()
113    {
114        bail!("User content must contain at least one part");
115    }
116
117    Ok(())
118}
119
120#[derive(Debug, Serialize, Deserialize)]
121pub enum Task {
122    #[serde(rename = "generateContent")]
123    GenerateContent,
124    #[serde(rename = "streamGenerateContent")]
125    StreamGenerateContent,
126    #[serde(rename = "countTokens")]
127    CountTokens,
128    #[serde(rename = "embedContent")]
129    EmbedContent,
130    #[serde(rename = "batchEmbedContents")]
131    BatchEmbedContents,
132}
133
134#[derive(Debug, Serialize, Deserialize)]
135#[serde(rename_all = "camelCase")]
136pub struct GenerateContentRequest {
137    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
138    pub model: ModelName,
139    pub contents: Vec<Content>,
140    #[serde(skip_serializing_if = "Option::is_none")]
141    pub system_instruction: Option<SystemInstruction>,
142    #[serde(skip_serializing_if = "Option::is_none")]
143    pub generation_config: Option<GenerationConfig>,
144    #[serde(skip_serializing_if = "Option::is_none")]
145    pub safety_settings: Option<Vec<SafetySetting>>,
146    #[serde(skip_serializing_if = "Option::is_none")]
147    pub tools: Option<Vec<Tool>>,
148    #[serde(skip_serializing_if = "Option::is_none")]
149    pub tool_config: Option<ToolConfig>,
150}
151
152#[derive(Debug, Serialize, Deserialize)]
153#[serde(rename_all = "camelCase")]
154pub struct GenerateContentResponse {
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub candidates: Option<Vec<GenerateContentCandidate>>,
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub prompt_feedback: Option<PromptFeedback>,
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub usage_metadata: Option<UsageMetadata>,
161}
162
163#[derive(Debug, Serialize, Deserialize)]
164#[serde(rename_all = "camelCase")]
165pub struct GenerateContentCandidate {
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub index: Option<usize>,
168    pub content: Content,
169    #[serde(skip_serializing_if = "Option::is_none")]
170    pub finish_reason: Option<String>,
171    #[serde(skip_serializing_if = "Option::is_none")]
172    pub finish_message: Option<String>,
173    #[serde(skip_serializing_if = "Option::is_none")]
174    pub safety_ratings: Option<Vec<SafetyRating>>,
175    #[serde(skip_serializing_if = "Option::is_none")]
176    pub citation_metadata: Option<CitationMetadata>,
177}
178
179#[derive(Debug, Serialize, Deserialize)]
180#[serde(rename_all = "camelCase")]
181pub struct Content {
182    #[serde(default)]
183    pub parts: Vec<Part>,
184    pub role: Role,
185}
186
187#[derive(Debug, Serialize, Deserialize)]
188#[serde(rename_all = "camelCase")]
189pub struct SystemInstruction {
190    pub parts: Vec<Part>,
191}
192
193#[derive(Debug, PartialEq, Deserialize, Serialize)]
194#[serde(rename_all = "camelCase")]
195pub enum Role {
196    User,
197    Model,
198}
199
200#[derive(Debug, Serialize, Deserialize)]
201#[serde(untagged)]
202pub enum Part {
203    TextPart(TextPart),
204    InlineDataPart(InlineDataPart),
205    FunctionCallPart(FunctionCallPart),
206    FunctionResponsePart(FunctionResponsePart),
207    ThoughtPart(ThoughtPart),
208}
209
210#[derive(Debug, Serialize, Deserialize)]
211#[serde(rename_all = "camelCase")]
212pub struct TextPart {
213    pub text: String,
214}
215
216#[derive(Debug, Serialize, Deserialize)]
217#[serde(rename_all = "camelCase")]
218pub struct InlineDataPart {
219    pub inline_data: GenerativeContentBlob,
220}
221
222#[derive(Debug, Serialize, Deserialize)]
223#[serde(rename_all = "camelCase")]
224pub struct GenerativeContentBlob {
225    pub mime_type: String,
226    pub data: String,
227}
228
229#[derive(Debug, Serialize, Deserialize)]
230#[serde(rename_all = "camelCase")]
231pub struct FunctionCallPart {
232    pub function_call: FunctionCall,
233    /// Thought signature returned by the model for function calls.
234    /// Only present on the first function call in parallel call scenarios.
235    #[serde(skip_serializing_if = "Option::is_none")]
236    pub thought_signature: Option<String>,
237}
238
239#[derive(Debug, Serialize, Deserialize)]
240#[serde(rename_all = "camelCase")]
241pub struct FunctionResponsePart {
242    pub function_response: FunctionResponse,
243}
244
245#[derive(Debug, Serialize, Deserialize)]
246#[serde(rename_all = "camelCase")]
247pub struct ThoughtPart {
248    pub thought: bool,
249    pub thought_signature: String,
250}
251
252#[derive(Debug, Serialize, Deserialize)]
253#[serde(rename_all = "camelCase")]
254pub struct CitationSource {
255    #[serde(skip_serializing_if = "Option::is_none")]
256    pub start_index: Option<usize>,
257    #[serde(skip_serializing_if = "Option::is_none")]
258    pub end_index: Option<usize>,
259    #[serde(skip_serializing_if = "Option::is_none")]
260    pub uri: Option<String>,
261    #[serde(skip_serializing_if = "Option::is_none")]
262    pub license: Option<String>,
263}
264
265#[derive(Debug, Serialize, Deserialize)]
266#[serde(rename_all = "camelCase")]
267pub struct CitationMetadata {
268    pub citation_sources: Vec<CitationSource>,
269}
270
271#[derive(Debug, Serialize, Deserialize)]
272#[serde(rename_all = "camelCase")]
273pub struct PromptFeedback {
274    #[serde(skip_serializing_if = "Option::is_none")]
275    pub block_reason: Option<String>,
276    pub safety_ratings: Option<Vec<SafetyRating>>,
277    #[serde(skip_serializing_if = "Option::is_none")]
278    pub block_reason_message: Option<String>,
279}
280
281#[derive(Debug, Serialize, Deserialize, Default)]
282#[serde(rename_all = "camelCase")]
283pub struct UsageMetadata {
284    #[serde(skip_serializing_if = "Option::is_none")]
285    pub prompt_token_count: Option<u64>,
286    #[serde(skip_serializing_if = "Option::is_none")]
287    pub cached_content_token_count: Option<u64>,
288    #[serde(skip_serializing_if = "Option::is_none")]
289    pub candidates_token_count: Option<u64>,
290    #[serde(skip_serializing_if = "Option::is_none")]
291    pub tool_use_prompt_token_count: Option<u64>,
292    #[serde(skip_serializing_if = "Option::is_none")]
293    pub thoughts_token_count: Option<u64>,
294    #[serde(skip_serializing_if = "Option::is_none")]
295    pub total_token_count: Option<u64>,
296}
297
298#[derive(Debug, Serialize, Deserialize)]
299#[serde(rename_all = "camelCase")]
300pub struct ThinkingConfig {
301    pub thinking_budget: u32,
302}
303
304#[derive(Debug, Deserialize, Serialize)]
305#[serde(rename_all = "camelCase")]
306pub struct GenerationConfig {
307    #[serde(skip_serializing_if = "Option::is_none")]
308    pub candidate_count: Option<usize>,
309    #[serde(skip_serializing_if = "Option::is_none")]
310    pub stop_sequences: Option<Vec<String>>,
311    #[serde(skip_serializing_if = "Option::is_none")]
312    pub max_output_tokens: Option<usize>,
313    #[serde(skip_serializing_if = "Option::is_none")]
314    pub temperature: Option<f64>,
315    #[serde(skip_serializing_if = "Option::is_none")]
316    pub top_p: Option<f64>,
317    #[serde(skip_serializing_if = "Option::is_none")]
318    pub top_k: Option<usize>,
319    #[serde(skip_serializing_if = "Option::is_none")]
320    pub thinking_config: Option<ThinkingConfig>,
321}
322
323#[derive(Debug, Serialize, Deserialize)]
324#[serde(rename_all = "camelCase")]
325pub struct SafetySetting {
326    pub category: HarmCategory,
327    pub threshold: HarmBlockThreshold,
328}
329
330#[derive(Debug, Serialize, Deserialize)]
331pub enum HarmCategory {
332    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
333    Unspecified,
334    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
335    Derogatory,
336    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
337    Toxicity,
338    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
339    Violence,
340    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
341    Sexual,
342    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
343    Medical,
344    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
345    Dangerous,
346    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
347    Harassment,
348    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
349    HateSpeech,
350    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
351    SexuallyExplicit,
352    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
353    DangerousContent,
354}
355
356#[derive(Debug, Serialize, Deserialize)]
357#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
358pub enum HarmBlockThreshold {
359    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
360    Unspecified,
361    BlockLowAndAbove,
362    BlockMediumAndAbove,
363    BlockOnlyHigh,
364    BlockNone,
365}
366
367#[derive(Debug, Serialize, Deserialize)]
368#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
369pub enum HarmProbability {
370    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
371    Unspecified,
372    Negligible,
373    Low,
374    Medium,
375    High,
376}
377
378#[derive(Debug, Serialize, Deserialize)]
379#[serde(rename_all = "camelCase")]
380pub struct SafetyRating {
381    pub category: HarmCategory,
382    pub probability: HarmProbability,
383}
384
385#[derive(Debug, Serialize, Deserialize)]
386#[serde(rename_all = "camelCase")]
387pub struct CountTokensRequest {
388    pub generate_content_request: GenerateContentRequest,
389}
390
391#[derive(Debug, Serialize, Deserialize)]
392#[serde(rename_all = "camelCase")]
393pub struct CountTokensResponse {
394    pub total_tokens: u64,
395}
396
397#[derive(Debug, Serialize, Deserialize)]
398pub struct FunctionCall {
399    pub name: String,
400    pub args: serde_json::Value,
401}
402
403#[derive(Debug, Serialize, Deserialize)]
404pub struct FunctionResponse {
405    pub name: String,
406    pub response: serde_json::Value,
407}
408
409#[derive(Debug, Serialize, Deserialize)]
410#[serde(rename_all = "camelCase")]
411pub struct Tool {
412    pub function_declarations: Vec<FunctionDeclaration>,
413}
414
415#[derive(Debug, Serialize, Deserialize)]
416#[serde(rename_all = "camelCase")]
417pub struct ToolConfig {
418    pub function_calling_config: FunctionCallingConfig,
419}
420
421#[derive(Debug, Serialize, Deserialize)]
422#[serde(rename_all = "camelCase")]
423pub struct FunctionCallingConfig {
424    pub mode: FunctionCallingMode,
425    #[serde(skip_serializing_if = "Option::is_none")]
426    pub allowed_function_names: Option<Vec<String>>,
427}
428
429#[derive(Debug, Serialize, Deserialize)]
430#[serde(rename_all = "lowercase")]
431pub enum FunctionCallingMode {
432    Auto,
433    Any,
434    None,
435}
436
437#[derive(Debug, Serialize, Deserialize)]
438pub struct FunctionDeclaration {
439    pub name: String,
440    pub description: String,
441    pub parameters: serde_json::Value,
442}
443
444#[derive(Debug, Default)]
445pub struct ModelName {
446    pub model_id: String,
447}
448
449impl ModelName {
450    pub fn is_empty(&self) -> bool {
451        self.model_id.is_empty()
452    }
453}
454
455const MODEL_NAME_PREFIX: &str = "models/";
456
457impl Serialize for ModelName {
458    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
459    where
460        S: Serializer,
461    {
462        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
463    }
464}
465
466impl<'de> Deserialize<'de> for ModelName {
467    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
468    where
469        D: Deserializer<'de>,
470    {
471        let string = String::deserialize(deserializer)?;
472        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
473            Ok(Self {
474                model_id: id.to_string(),
475            })
476        } else {
477            Err(serde::de::Error::custom(format!(
478                "Expected model name to begin with {}, got: {}",
479                MODEL_NAME_PREFIX, string
480            )))
481        }
482    }
483}
484
485#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
486#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
487pub enum Model {
488    #[serde(
489        rename = "gemini-2.5-flash-lite",
490        alias = "gemini-2.5-flash-lite-preview-06-17",
491        alias = "gemini-2.0-flash-lite-preview"
492    )]
493    Gemini25FlashLite,
494    #[serde(
495        rename = "gemini-2.5-flash",
496        alias = "gemini-2.0-flash-thinking-exp",
497        alias = "gemini-2.5-flash-preview-04-17",
498        alias = "gemini-2.5-flash-preview-05-20",
499        alias = "gemini-2.5-flash-preview-latest",
500        alias = "gemini-2.0-flash"
501    )]
502    #[default]
503    Gemini25Flash,
504    #[serde(
505        rename = "gemini-2.5-pro",
506        alias = "gemini-2.0-pro-exp",
507        alias = "gemini-2.5-pro-preview-latest",
508        alias = "gemini-2.5-pro-exp-03-25",
509        alias = "gemini-2.5-pro-preview-03-25",
510        alias = "gemini-2.5-pro-preview-05-06",
511        alias = "gemini-2.5-pro-preview-06-05"
512    )]
513    Gemini25Pro,
514    #[serde(rename = "gemini-3-flash-preview")]
515    Gemini3Flash,
516    #[serde(rename = "gemini-3.1-pro-preview", alias = "gemini-3-pro-preview")]
517    Gemini31Pro,
518    #[serde(rename = "custom")]
519    Custom {
520        name: String,
521        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
522        display_name: Option<String>,
523        max_tokens: u64,
524        #[serde(default)]
525        mode: GoogleModelMode,
526    },
527}
528
529impl Model {
530    pub fn default_fast() -> Self {
531        Self::Gemini25FlashLite
532    }
533
534    pub fn id(&self) -> &str {
535        match self {
536            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
537            Self::Gemini25Flash => "gemini-2.5-flash",
538            Self::Gemini25Pro => "gemini-2.5-pro",
539            Self::Gemini3Flash => "gemini-3-flash-preview",
540            Self::Gemini31Pro => "gemini-3.1-pro-preview",
541            Self::Custom { name, .. } => name,
542        }
543    }
544    pub fn request_id(&self) -> &str {
545        match self {
546            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
547            Self::Gemini25Flash => "gemini-2.5-flash",
548            Self::Gemini25Pro => "gemini-2.5-pro",
549            Self::Gemini3Flash => "gemini-3-flash-preview",
550            Self::Gemini31Pro => "gemini-3.1-pro-preview",
551            Self::Custom { name, .. } => name,
552        }
553    }
554
555    pub fn display_name(&self) -> &str {
556        match self {
557            Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite",
558            Self::Gemini25Flash => "Gemini 2.5 Flash",
559            Self::Gemini25Pro => "Gemini 2.5 Pro",
560            Self::Gemini3Flash => "Gemini 3 Flash",
561            Self::Gemini31Pro => "Gemini 3.1 Pro",
562            Self::Custom {
563                name, display_name, ..
564            } => display_name.as_ref().unwrap_or(name),
565        }
566    }
567
568    pub fn max_token_count(&self) -> u64 {
569        match self {
570            Self::Gemini25FlashLite
571            | Self::Gemini25Flash
572            | Self::Gemini25Pro
573            | Self::Gemini3Flash
574            | Self::Gemini31Pro => 1_048_576,
575            Self::Custom { max_tokens, .. } => *max_tokens,
576        }
577    }
578
579    pub fn max_output_tokens(&self) -> Option<u64> {
580        match self {
581            Model::Gemini25FlashLite
582            | Model::Gemini25Flash
583            | Model::Gemini25Pro
584            | Model::Gemini3Flash
585            | Model::Gemini31Pro => Some(65_536),
586            Model::Custom { .. } => None,
587        }
588    }
589
590    pub fn supports_tools(&self) -> bool {
591        true
592    }
593
594    pub fn supports_images(&self) -> bool {
595        true
596    }
597
598    pub fn mode(&self) -> GoogleModelMode {
599        match self {
600            Self::Gemini25FlashLite | Self::Gemini25Flash | Self::Gemini25Pro => {
601                GoogleModelMode::Thinking {
602                    // By default these models are set to "auto", so we preserve that behavior
603                    // but indicate they are capable of thinking mode
604                    budget_tokens: None,
605                }
606            }
607            Self::Gemini3Flash => GoogleModelMode::Default,
608            Self::Gemini31Pro => GoogleModelMode::Thinking {
609                budget_tokens: None,
610            },
611            Self::Custom { mode, .. } => *mode,
612        }
613    }
614}
615
616impl std::fmt::Display for Model {
617    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
618        write!(f, "{}", self.id())
619    }
620}
621
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a call to `test_function` with `{"arg": "value"}` and the given signature.
    fn sample_call(thought_signature: Option<&str>) -> FunctionCallPart {
        FunctionCallPart {
            function_call: FunctionCall {
                name: "test_function".to_string(),
                args: json!({ "arg": "value" }),
            },
            thought_signature: thought_signature.map(str::to_owned),
        }
    }

    #[test]
    fn test_function_call_part_with_signature_serializes_correctly() {
        let value = serde_json::to_value(sample_call(Some("test_signature"))).unwrap();

        assert_eq!(value["functionCall"]["name"], "test_function");
        assert_eq!(value["functionCall"]["args"]["arg"], "value");
        assert_eq!(value["thoughtSignature"], "test_signature");
    }

    #[test]
    fn test_function_call_part_without_signature_omits_field() {
        let value = serde_json::to_value(sample_call(None)).unwrap();

        assert_eq!(value["functionCall"]["name"], "test_function");
        assert_eq!(value["functionCall"]["args"]["arg"], "value");
        // `None` must be skipped entirely rather than serialized as `null`.
        assert!(value.get("thoughtSignature").is_none());
    }

    #[test]
    fn test_function_call_part_deserializes_with_signature() {
        let input = json!({
            "functionCall": { "name": "test_function", "args": { "arg": "value" } },
            "thoughtSignature": "test_signature"
        });

        let part: FunctionCallPart = serde_json::from_value(input).unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert_eq!(part.thought_signature.as_deref(), Some("test_signature"));
    }

    #[test]
    fn test_function_call_part_deserializes_without_signature() {
        let input = json!({
            "functionCall": { "name": "test_function", "args": { "arg": "value" } }
        });

        let part: FunctionCallPart = serde_json::from_value(input).unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert_eq!(part.thought_signature, None);
    }

    #[test]
    fn test_function_call_part_round_trip() {
        let original = FunctionCallPart {
            function_call: FunctionCall {
                name: "test_function".to_string(),
                args: json!({ "arg": "value", "nested": { "key": "val" } }),
            },
            thought_signature: Some("round_trip_signature".to_string()),
        };

        let restored: FunctionCallPart =
            serde_json::from_value(serde_json::to_value(&original).unwrap()).unwrap();

        assert_eq!(restored.function_call.name, original.function_call.name);
        assert_eq!(restored.function_call.args, original.function_call.args);
        assert_eq!(restored.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_function_call_part_with_empty_signature_serializes() {
        let value = serde_json::to_value(sample_call(Some(""))).unwrap();

        // An empty signature is still emitted; normalization happens at a higher level.
        assert_eq!(value["thoughtSignature"], "");
    }
}
726}