google_ai.rs

  1use std::mem;
  2
  3use anyhow::{Result, anyhow, bail};
  4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6use serde::{Deserialize, Deserializer, Serialize, Serializer};
  7pub use settings::ModelMode as GoogleModelMode;
  8
  9pub const API_URL: &str = "https://generativelanguage.googleapis.com";
 10
 11pub async fn stream_generate_content(
 12    client: &dyn HttpClient,
 13    api_url: &str,
 14    api_key: &str,
 15    mut request: GenerateContentRequest,
 16) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
 17    let api_key = api_key.trim();
 18    validate_generate_content_request(&request)?;
 19
 20    // The `model` field is emptied as it is provided as a path parameter.
 21    let model_id = mem::take(&mut request.model.model_id);
 22
 23    let uri =
 24        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
 25
 26    let request_builder = HttpRequest::builder()
 27        .method(Method::POST)
 28        .uri(uri)
 29        .header("Content-Type", "application/json");
 30
 31    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
 32    let mut response = client.send(request).await?;
 33    if response.status().is_success() {
 34        let reader = BufReader::new(response.into_body());
 35        Ok(reader
 36            .lines()
 37            .filter_map(|line| async move {
 38                match line {
 39                    Ok(line) => {
 40                        if let Some(line) = line.strip_prefix("data: ") {
 41                            match serde_json::from_str(line) {
 42                                Ok(response) => Some(Ok(response)),
 43                                Err(error) => Some(Err(anyhow!(format!(
 44                                    "Error parsing JSON: {error:?}\n{line:?}"
 45                                )))),
 46                            }
 47                        } else {
 48                            None
 49                        }
 50                    }
 51                    Err(error) => Some(Err(anyhow!(error))),
 52                }
 53            })
 54            .boxed())
 55    } else {
 56        let mut text = String::new();
 57        response.body_mut().read_to_string(&mut text).await?;
 58        Err(anyhow!(
 59            "error during streamGenerateContent, status code: {:?}, body: {}",
 60            response.status(),
 61            text
 62        ))
 63    }
 64}
 65
 66pub async fn count_tokens(
 67    client: &dyn HttpClient,
 68    api_url: &str,
 69    api_key: &str,
 70    request: CountTokensRequest,
 71) -> Result<CountTokensResponse> {
 72    validate_generate_content_request(&request.generate_content_request)?;
 73
 74    let uri = format!(
 75        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
 76        model_id = &request.generate_content_request.model.model_id,
 77    );
 78
 79    let request = serde_json::to_string(&request)?;
 80    let request_builder = HttpRequest::builder()
 81        .method(Method::POST)
 82        .uri(&uri)
 83        .header("Content-Type", "application/json");
 84    let http_request = request_builder.body(AsyncBody::from(request))?;
 85
 86    let mut response = client.send(http_request).await?;
 87    let mut text = String::new();
 88    response.body_mut().read_to_string(&mut text).await?;
 89    anyhow::ensure!(
 90        response.status().is_success(),
 91        "error during countTokens, status code: {:?}, body: {}",
 92        response.status(),
 93        text
 94    );
 95    Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
 96}
 97
 98pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
 99    if request.model.is_empty() {
100        bail!("Model must be specified");
101    }
102
103    if request.contents.is_empty() {
104        bail!("Request must contain at least one content item");
105    }
106
107    if let Some(user_content) = request
108        .contents
109        .iter()
110        .find(|content| content.role == Role::User)
111        && user_content.parts.is_empty()
112    {
113        bail!("User content must contain at least one part");
114    }
115
116    Ok(())
117}
118
119#[derive(Debug, Serialize, Deserialize)]
120pub enum Task {
121    #[serde(rename = "generateContent")]
122    GenerateContent,
123    #[serde(rename = "streamGenerateContent")]
124    StreamGenerateContent,
125    #[serde(rename = "countTokens")]
126    CountTokens,
127    #[serde(rename = "embedContent")]
128    EmbedContent,
129    #[serde(rename = "batchEmbedContents")]
130    BatchEmbedContents,
131}
132
133#[derive(Debug, Serialize, Deserialize)]
134#[serde(rename_all = "camelCase")]
135pub struct GenerateContentRequest {
136    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
137    pub model: ModelName,
138    pub contents: Vec<Content>,
139    #[serde(skip_serializing_if = "Option::is_none")]
140    pub system_instruction: Option<SystemInstruction>,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub generation_config: Option<GenerationConfig>,
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub safety_settings: Option<Vec<SafetySetting>>,
145    #[serde(skip_serializing_if = "Option::is_none")]
146    pub tools: Option<Vec<Tool>>,
147    #[serde(skip_serializing_if = "Option::is_none")]
148    pub tool_config: Option<ToolConfig>,
149}
150
151#[derive(Debug, Serialize, Deserialize)]
152#[serde(rename_all = "camelCase")]
153pub struct GenerateContentResponse {
154    #[serde(skip_serializing_if = "Option::is_none")]
155    pub candidates: Option<Vec<GenerateContentCandidate>>,
156    #[serde(skip_serializing_if = "Option::is_none")]
157    pub prompt_feedback: Option<PromptFeedback>,
158    #[serde(skip_serializing_if = "Option::is_none")]
159    pub usage_metadata: Option<UsageMetadata>,
160}
161
162#[derive(Debug, Serialize, Deserialize)]
163#[serde(rename_all = "camelCase")]
164pub struct GenerateContentCandidate {
165    #[serde(skip_serializing_if = "Option::is_none")]
166    pub index: Option<usize>,
167    pub content: Content,
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub finish_reason: Option<String>,
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub finish_message: Option<String>,
172    #[serde(skip_serializing_if = "Option::is_none")]
173    pub safety_ratings: Option<Vec<SafetyRating>>,
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub citation_metadata: Option<CitationMetadata>,
176}
177
178#[derive(Debug, Serialize, Deserialize)]
179#[serde(rename_all = "camelCase")]
180pub struct Content {
181    #[serde(default)]
182    pub parts: Vec<Part>,
183    pub role: Role,
184}
185
186#[derive(Debug, Serialize, Deserialize)]
187#[serde(rename_all = "camelCase")]
188pub struct SystemInstruction {
189    pub parts: Vec<Part>,
190}
191
192#[derive(Debug, PartialEq, Deserialize, Serialize)]
193#[serde(rename_all = "camelCase")]
194pub enum Role {
195    User,
196    Model,
197}
198
199#[derive(Debug, Serialize, Deserialize)]
200#[serde(untagged)]
201pub enum Part {
202    TextPart(TextPart),
203    InlineDataPart(InlineDataPart),
204    FunctionCallPart(FunctionCallPart),
205    FunctionResponsePart(FunctionResponsePart),
206    ThoughtPart(ThoughtPart),
207}
208
209#[derive(Debug, Serialize, Deserialize)]
210#[serde(rename_all = "camelCase")]
211pub struct TextPart {
212    pub text: String,
213}
214
215#[derive(Debug, Serialize, Deserialize)]
216#[serde(rename_all = "camelCase")]
217pub struct InlineDataPart {
218    pub inline_data: GenerativeContentBlob,
219}
220
221#[derive(Debug, Serialize, Deserialize)]
222#[serde(rename_all = "camelCase")]
223pub struct GenerativeContentBlob {
224    pub mime_type: String,
225    pub data: String,
226}
227
228#[derive(Debug, Serialize, Deserialize)]
229#[serde(rename_all = "camelCase")]
230pub struct FunctionCallPart {
231    pub function_call: FunctionCall,
232    /// Thought signature returned by the model for function calls.
233    /// Only present on the first function call in parallel call scenarios.
234    #[serde(skip_serializing_if = "Option::is_none")]
235    pub thought_signature: Option<String>,
236}
237
238#[derive(Debug, Serialize, Deserialize)]
239#[serde(rename_all = "camelCase")]
240pub struct FunctionResponsePart {
241    pub function_response: FunctionResponse,
242}
243
244#[derive(Debug, Serialize, Deserialize)]
245#[serde(rename_all = "camelCase")]
246pub struct ThoughtPart {
247    pub thought: bool,
248    pub thought_signature: String,
249}
250
251#[derive(Debug, Serialize, Deserialize)]
252#[serde(rename_all = "camelCase")]
253pub struct CitationSource {
254    #[serde(skip_serializing_if = "Option::is_none")]
255    pub start_index: Option<usize>,
256    #[serde(skip_serializing_if = "Option::is_none")]
257    pub end_index: Option<usize>,
258    #[serde(skip_serializing_if = "Option::is_none")]
259    pub uri: Option<String>,
260    #[serde(skip_serializing_if = "Option::is_none")]
261    pub license: Option<String>,
262}
263
264#[derive(Debug, Serialize, Deserialize)]
265#[serde(rename_all = "camelCase")]
266pub struct CitationMetadata {
267    pub citation_sources: Vec<CitationSource>,
268}
269
270#[derive(Debug, Serialize, Deserialize)]
271#[serde(rename_all = "camelCase")]
272pub struct PromptFeedback {
273    #[serde(skip_serializing_if = "Option::is_none")]
274    pub block_reason: Option<String>,
275    pub safety_ratings: Option<Vec<SafetyRating>>,
276    #[serde(skip_serializing_if = "Option::is_none")]
277    pub block_reason_message: Option<String>,
278}
279
280#[derive(Debug, Serialize, Deserialize, Default)]
281#[serde(rename_all = "camelCase")]
282pub struct UsageMetadata {
283    #[serde(skip_serializing_if = "Option::is_none")]
284    pub prompt_token_count: Option<u64>,
285    #[serde(skip_serializing_if = "Option::is_none")]
286    pub cached_content_token_count: Option<u64>,
287    #[serde(skip_serializing_if = "Option::is_none")]
288    pub candidates_token_count: Option<u64>,
289    #[serde(skip_serializing_if = "Option::is_none")]
290    pub tool_use_prompt_token_count: Option<u64>,
291    #[serde(skip_serializing_if = "Option::is_none")]
292    pub thoughts_token_count: Option<u64>,
293    #[serde(skip_serializing_if = "Option::is_none")]
294    pub total_token_count: Option<u64>,
295}
296
297#[derive(Debug, Serialize, Deserialize)]
298#[serde(rename_all = "camelCase")]
299pub struct ThinkingConfig {
300    pub thinking_budget: u32,
301}
302
303#[derive(Debug, Deserialize, Serialize)]
304#[serde(rename_all = "camelCase")]
305pub struct GenerationConfig {
306    #[serde(skip_serializing_if = "Option::is_none")]
307    pub candidate_count: Option<usize>,
308    #[serde(skip_serializing_if = "Option::is_none")]
309    pub stop_sequences: Option<Vec<String>>,
310    #[serde(skip_serializing_if = "Option::is_none")]
311    pub max_output_tokens: Option<usize>,
312    #[serde(skip_serializing_if = "Option::is_none")]
313    pub temperature: Option<f64>,
314    #[serde(skip_serializing_if = "Option::is_none")]
315    pub top_p: Option<f64>,
316    #[serde(skip_serializing_if = "Option::is_none")]
317    pub top_k: Option<usize>,
318    #[serde(skip_serializing_if = "Option::is_none")]
319    pub thinking_config: Option<ThinkingConfig>,
320}
321
322#[derive(Debug, Serialize, Deserialize)]
323#[serde(rename_all = "camelCase")]
324pub struct SafetySetting {
325    pub category: HarmCategory,
326    pub threshold: HarmBlockThreshold,
327}
328
329#[derive(Debug, Serialize, Deserialize)]
330pub enum HarmCategory {
331    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
332    Unspecified,
333    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
334    Derogatory,
335    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
336    Toxicity,
337    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
338    Violence,
339    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
340    Sexual,
341    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
342    Medical,
343    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
344    Dangerous,
345    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
346    Harassment,
347    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
348    HateSpeech,
349    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
350    SexuallyExplicit,
351    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
352    DangerousContent,
353}
354
355#[derive(Debug, Serialize, Deserialize)]
356#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
357pub enum HarmBlockThreshold {
358    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
359    Unspecified,
360    BlockLowAndAbove,
361    BlockMediumAndAbove,
362    BlockOnlyHigh,
363    BlockNone,
364}
365
366#[derive(Debug, Serialize, Deserialize)]
367#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
368pub enum HarmProbability {
369    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
370    Unspecified,
371    Negligible,
372    Low,
373    Medium,
374    High,
375}
376
377#[derive(Debug, Serialize, Deserialize)]
378#[serde(rename_all = "camelCase")]
379pub struct SafetyRating {
380    pub category: HarmCategory,
381    pub probability: HarmProbability,
382}
383
384#[derive(Debug, Serialize, Deserialize)]
385#[serde(rename_all = "camelCase")]
386pub struct CountTokensRequest {
387    pub generate_content_request: GenerateContentRequest,
388}
389
390#[derive(Debug, Serialize, Deserialize)]
391#[serde(rename_all = "camelCase")]
392pub struct CountTokensResponse {
393    pub total_tokens: u64,
394}
395
396#[derive(Debug, Serialize, Deserialize)]
397pub struct FunctionCall {
398    pub name: String,
399    pub args: serde_json::Value,
400}
401
402#[derive(Debug, Serialize, Deserialize)]
403pub struct FunctionResponse {
404    pub name: String,
405    pub response: serde_json::Value,
406}
407
408#[derive(Debug, Serialize, Deserialize)]
409#[serde(rename_all = "camelCase")]
410pub struct Tool {
411    pub function_declarations: Vec<FunctionDeclaration>,
412}
413
414#[derive(Debug, Serialize, Deserialize)]
415#[serde(rename_all = "camelCase")]
416pub struct ToolConfig {
417    pub function_calling_config: FunctionCallingConfig,
418}
419
420#[derive(Debug, Serialize, Deserialize)]
421#[serde(rename_all = "camelCase")]
422pub struct FunctionCallingConfig {
423    pub mode: FunctionCallingMode,
424    #[serde(skip_serializing_if = "Option::is_none")]
425    pub allowed_function_names: Option<Vec<String>>,
426}
427
428#[derive(Debug, Serialize, Deserialize)]
429#[serde(rename_all = "lowercase")]
430pub enum FunctionCallingMode {
431    Auto,
432    Any,
433    None,
434}
435
436#[derive(Debug, Serialize, Deserialize)]
437pub struct FunctionDeclaration {
438    pub name: String,
439    pub description: String,
440    pub parameters: serde_json::Value,
441}
442
443#[derive(Debug, Default)]
444pub struct ModelName {
445    pub model_id: String,
446}
447
448impl ModelName {
449    pub fn is_empty(&self) -> bool {
450        self.model_id.is_empty()
451    }
452}
453
454const MODEL_NAME_PREFIX: &str = "models/";
455
456impl Serialize for ModelName {
457    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
458    where
459        S: Serializer,
460    {
461        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
462    }
463}
464
465impl<'de> Deserialize<'de> for ModelName {
466    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
467    where
468        D: Deserializer<'de>,
469    {
470        let string = String::deserialize(deserializer)?;
471        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
472            Ok(Self {
473                model_id: id.to_string(),
474            })
475        } else {
476            Err(serde::de::Error::custom(format!(
477                "Expected model name to begin with {}, got: {}",
478                MODEL_NAME_PREFIX, string
479            )))
480        }
481    }
482}
483
484#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
485#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
486pub enum Model {
487    #[serde(rename = "gemini-1.5-pro")]
488    Gemini15Pro,
489    #[serde(rename = "gemini-1.5-flash-8b")]
490    Gemini15Flash8b,
491    #[serde(rename = "gemini-1.5-flash")]
492    Gemini15Flash,
493    #[serde(
494        rename = "gemini-2.0-flash-lite",
495        alias = "gemini-2.0-flash-lite-preview"
496    )]
497    Gemini20FlashLite,
498    #[serde(rename = "gemini-2.0-flash")]
499    Gemini20Flash,
500    #[serde(
501        rename = "gemini-2.5-flash-lite-preview",
502        alias = "gemini-2.5-flash-lite-preview-06-17"
503    )]
504    Gemini25FlashLitePreview,
505    #[serde(
506        rename = "gemini-2.5-flash",
507        alias = "gemini-2.0-flash-thinking-exp",
508        alias = "gemini-2.5-flash-preview-04-17",
509        alias = "gemini-2.5-flash-preview-05-20",
510        alias = "gemini-2.5-flash-preview-latest"
511    )]
512    #[default]
513    Gemini25Flash,
514    #[serde(
515        rename = "gemini-2.5-pro",
516        alias = "gemini-2.0-pro-exp",
517        alias = "gemini-2.5-pro-preview-latest",
518        alias = "gemini-2.5-pro-exp-03-25",
519        alias = "gemini-2.5-pro-preview-03-25",
520        alias = "gemini-2.5-pro-preview-05-06",
521        alias = "gemini-2.5-pro-preview-06-05"
522    )]
523    Gemini25Pro,
524    #[serde(rename = "gemini-3-pro-preview")]
525    Gemini3ProPreview,
526    #[serde(rename = "custom")]
527    Custom {
528        name: String,
529        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
530        display_name: Option<String>,
531        max_tokens: u64,
532        #[serde(default)]
533        mode: GoogleModelMode,
534    },
535}
536
537impl Model {
538    pub fn default_fast() -> Self {
539        Self::Gemini20FlashLite
540    }
541
542    pub fn id(&self) -> &str {
543        match self {
544            Self::Gemini15Pro => "gemini-1.5-pro",
545            Self::Gemini15Flash8b => "gemini-1.5-flash-8b",
546            Self::Gemini15Flash => "gemini-1.5-flash",
547            Self::Gemini20FlashLite => "gemini-2.0-flash-lite",
548            Self::Gemini20Flash => "gemini-2.0-flash",
549            Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview",
550            Self::Gemini25Flash => "gemini-2.5-flash",
551            Self::Gemini25Pro => "gemini-2.5-pro",
552            Self::Gemini3ProPreview => "gemini-3-pro-preview",
553            Self::Custom { name, .. } => name,
554        }
555    }
556    pub fn request_id(&self) -> &str {
557        match self {
558            Self::Gemini15Pro => "gemini-1.5-pro",
559            Self::Gemini15Flash8b => "gemini-1.5-flash-8b",
560            Self::Gemini15Flash => "gemini-1.5-flash",
561            Self::Gemini20FlashLite => "gemini-2.0-flash-lite",
562            Self::Gemini20Flash => "gemini-2.0-flash",
563            Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview-06-17",
564            Self::Gemini25Flash => "gemini-2.5-flash",
565            Self::Gemini25Pro => "gemini-2.5-pro",
566            Self::Gemini3ProPreview => "gemini-3-pro-preview",
567            Self::Custom { name, .. } => name,
568        }
569    }
570
571    pub fn display_name(&self) -> &str {
572        match self {
573            Self::Gemini15Pro => "Gemini 1.5 Pro",
574            Self::Gemini15Flash8b => "Gemini 1.5 Flash-8b",
575            Self::Gemini15Flash => "Gemini 1.5 Flash",
576            Self::Gemini20FlashLite => "Gemini 2.0 Flash-Lite",
577            Self::Gemini20Flash => "Gemini 2.0 Flash",
578            Self::Gemini25FlashLitePreview => "Gemini 2.5 Flash-Lite Preview",
579            Self::Gemini25Flash => "Gemini 2.5 Flash",
580            Self::Gemini25Pro => "Gemini 2.5 Pro",
581            Self::Gemini3ProPreview => "Gemini 3 Pro",
582            Self::Custom {
583                name, display_name, ..
584            } => display_name.as_ref().unwrap_or(name),
585        }
586    }
587
588    pub fn max_token_count(&self) -> u64 {
589        match self {
590            Self::Gemini15Pro => 2_097_152,
591            Self::Gemini15Flash8b => 1_048_576,
592            Self::Gemini15Flash => 1_048_576,
593            Self::Gemini20FlashLite => 1_048_576,
594            Self::Gemini20Flash => 1_048_576,
595            Self::Gemini25FlashLitePreview => 1_000_000,
596            Self::Gemini25Flash => 1_048_576,
597            Self::Gemini25Pro => 1_048_576,
598            Self::Gemini3ProPreview => 1_048_576,
599            Self::Custom { max_tokens, .. } => *max_tokens,
600        }
601    }
602
603    pub fn max_output_tokens(&self) -> Option<u64> {
604        match self {
605            Model::Gemini15Pro => Some(8_192),
606            Model::Gemini15Flash8b => Some(8_192),
607            Model::Gemini15Flash => Some(8_192),
608            Model::Gemini20FlashLite => Some(8_192),
609            Model::Gemini20Flash => Some(8_192),
610            Model::Gemini25FlashLitePreview => Some(64_000),
611            Model::Gemini25Flash => Some(65_536),
612            Model::Gemini25Pro => Some(65_536),
613            Model::Gemini3ProPreview => Some(65_536),
614            Model::Custom { .. } => None,
615        }
616    }
617
618    pub fn supports_tools(&self) -> bool {
619        true
620    }
621
622    pub fn supports_images(&self) -> bool {
623        true
624    }
625
626    pub fn mode(&self) -> GoogleModelMode {
627        match self {
628            Self::Gemini15Pro
629            | Self::Gemini15Flash8b
630            | Self::Gemini15Flash
631            | Self::Gemini20FlashLite
632            | Self::Gemini20Flash => GoogleModelMode::Default,
633            Self::Gemini25FlashLitePreview
634            | Self::Gemini25Flash
635            | Self::Gemini25Pro
636            | Self::Gemini3ProPreview => {
637                GoogleModelMode::Thinking {
638                    // By default these models are set to "auto", so we preserve that behavior
639                    // but indicate they are capable of thinking mode
640                    budget_tokens: None,
641                }
642            }
643            Self::Custom { mode, .. } => *mode,
644        }
645    }
646}
647
648impl std::fmt::Display for Model {
649    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
650        write!(f, "{}", self.id())
651    }
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A `FunctionCallPart` calling `test_function` with `{"arg": "value"}`.
    fn sample_part(thought_signature: Option<&str>) -> FunctionCallPart {
        FunctionCallPart {
            function_call: FunctionCall {
                name: "test_function".to_string(),
                args: json!({"arg": "value"}),
            },
            thought_signature: thought_signature.map(str::to_string),
        }
    }

    #[test]
    fn test_function_call_part_with_signature_serializes_correctly() {
        let value = serde_json::to_value(sample_part(Some("test_signature"))).unwrap();

        assert_eq!(value["functionCall"]["name"], "test_function");
        assert_eq!(value["functionCall"]["args"]["arg"], "value");
        assert_eq!(value["thoughtSignature"], "test_signature");
    }

    #[test]
    fn test_function_call_part_without_signature_omits_field() {
        let value = serde_json::to_value(sample_part(None)).unwrap();

        assert_eq!(value["functionCall"]["name"], "test_function");
        assert_eq!(value["functionCall"]["args"]["arg"], "value");
        // A `None` signature must be skipped, not written as `null`.
        assert!(value.get("thoughtSignature").is_none());
    }

    #[test]
    fn test_function_call_part_deserializes_with_signature() {
        let input = json!({
            "functionCall": {"name": "test_function", "args": {"arg": "value"}},
            "thoughtSignature": "test_signature"
        });

        let part: FunctionCallPart = serde_json::from_value(input).unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert_eq!(part.thought_signature.as_deref(), Some("test_signature"));
    }

    #[test]
    fn test_function_call_part_deserializes_without_signature() {
        let input = json!({
            "functionCall": {"name": "test_function", "args": {"arg": "value"}}
        });

        let part: FunctionCallPart = serde_json::from_value(input).unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert_eq!(part.thought_signature, None);
    }

    #[test]
    fn test_function_call_part_round_trip() {
        let original = FunctionCallPart {
            function_call: FunctionCall {
                name: "test_function".to_string(),
                args: json!({"arg": "value", "nested": {"key": "val"}}),
            },
            thought_signature: Some("round_trip_signature".to_string()),
        };

        let restored: FunctionCallPart =
            serde_json::from_value(serde_json::to_value(&original).unwrap()).unwrap();

        assert_eq!(restored.function_call.name, original.function_call.name);
        assert_eq!(restored.function_call.args, original.function_call.args);
        assert_eq!(restored.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_function_call_part_with_empty_signature_serializes() {
        let value = serde_json::to_value(sample_part(Some(""))).unwrap();

        // An empty string is still emitted; normalization happens at a higher level.
        assert_eq!(value["thoughtSignature"], "");
    }
}