google_ai.rs

  1use std::mem;
  2
  3use anyhow::{Result, anyhow, bail};
  4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6pub use language_model_core::ModelMode as GoogleModelMode;
  7use serde::{Deserialize, Deserializer, Serialize, Serializer};
  8pub mod completion;
  9
 10pub const API_URL: &str = "https://generativelanguage.googleapis.com";
 11
 12pub async fn stream_generate_content(
 13    client: &dyn HttpClient,
 14    api_url: &str,
 15    api_key: &str,
 16    mut request: GenerateContentRequest,
 17) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
 18    let api_key = api_key.trim();
 19    validate_generate_content_request(&request)?;
 20
 21    // The `model` field is emptied as it is provided as a path parameter.
 22    let model_id = mem::take(&mut request.model.model_id);
 23
 24    let uri =
 25        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
 26
 27    let request_builder = HttpRequest::builder()
 28        .method(Method::POST)
 29        .uri(uri)
 30        .header("Content-Type", "application/json");
 31
 32    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
 33    let mut response = client.send(request).await?;
 34    if response.status().is_success() {
 35        let reader = BufReader::new(response.into_body());
 36        Ok(reader
 37            .lines()
 38            .filter_map(|line| async move {
 39                match line {
 40                    Ok(line) => {
 41                        if let Some(line) = line.strip_prefix("data: ") {
 42                            match serde_json::from_str(line) {
 43                                Ok(response) => Some(Ok(response)),
 44                                Err(error) => Some(Err(anyhow!(format!(
 45                                    "Error parsing JSON: {error:?}\n{line:?}"
 46                                )))),
 47                            }
 48                        } else {
 49                            None
 50                        }
 51                    }
 52                    Err(error) => Some(Err(anyhow!(error))),
 53                }
 54            })
 55            .boxed())
 56    } else {
 57        let mut text = String::new();
 58        response.body_mut().read_to_string(&mut text).await?;
 59        Err(anyhow!(
 60            "error during streamGenerateContent, status code: {:?}, body: {}",
 61            response.status(),
 62            text
 63        ))
 64    }
 65}
 66
 67pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
 68    if request.model.is_empty() {
 69        bail!("Model must be specified");
 70    }
 71
 72    if request.contents.is_empty() {
 73        bail!("Request must contain at least one content item");
 74    }
 75
 76    if let Some(user_content) = request
 77        .contents
 78        .iter()
 79        .find(|content| content.role == Role::User)
 80        && user_content.parts.is_empty()
 81    {
 82        bail!("User content must contain at least one part");
 83    }
 84
 85    Ok(())
 86}
 87
 88#[derive(Debug, Serialize, Deserialize)]
 89pub enum Task {
 90    #[serde(rename = "generateContent")]
 91    GenerateContent,
 92    #[serde(rename = "streamGenerateContent")]
 93    StreamGenerateContent,
 94    #[serde(rename = "embedContent")]
 95    EmbedContent,
 96    #[serde(rename = "batchEmbedContents")]
 97    BatchEmbedContents,
 98}
 99
100#[derive(Debug, Serialize, Deserialize)]
101#[serde(rename_all = "camelCase")]
102pub struct GenerateContentRequest {
103    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
104    pub model: ModelName,
105    pub contents: Vec<Content>,
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub system_instruction: Option<SystemInstruction>,
108    #[serde(skip_serializing_if = "Option::is_none")]
109    pub generation_config: Option<GenerationConfig>,
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub safety_settings: Option<Vec<SafetySetting>>,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub tools: Option<Vec<Tool>>,
114    #[serde(skip_serializing_if = "Option::is_none")]
115    pub tool_config: Option<ToolConfig>,
116}
117
118#[derive(Debug, Serialize, Deserialize)]
119#[serde(rename_all = "camelCase")]
120pub struct GenerateContentResponse {
121    #[serde(skip_serializing_if = "Option::is_none")]
122    pub candidates: Option<Vec<GenerateContentCandidate>>,
123    #[serde(skip_serializing_if = "Option::is_none")]
124    pub prompt_feedback: Option<PromptFeedback>,
125    #[serde(skip_serializing_if = "Option::is_none")]
126    pub usage_metadata: Option<UsageMetadata>,
127}
128
129#[derive(Debug, Serialize, Deserialize)]
130#[serde(rename_all = "camelCase")]
131pub struct GenerateContentCandidate {
132    #[serde(skip_serializing_if = "Option::is_none")]
133    pub index: Option<usize>,
134    pub content: Content,
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub finish_reason: Option<String>,
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub finish_message: Option<String>,
139    #[serde(skip_serializing_if = "Option::is_none")]
140    pub safety_ratings: Option<Vec<SafetyRating>>,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub citation_metadata: Option<CitationMetadata>,
143}
144
145#[derive(Debug, Serialize, Deserialize)]
146#[serde(rename_all = "camelCase")]
147pub struct Content {
148    #[serde(default)]
149    pub parts: Vec<Part>,
150    pub role: Role,
151}
152
153#[derive(Debug, Serialize, Deserialize)]
154#[serde(rename_all = "camelCase")]
155pub struct SystemInstruction {
156    pub parts: Vec<Part>,
157}
158
159#[derive(Debug, PartialEq, Deserialize, Serialize)]
160#[serde(rename_all = "camelCase")]
161pub enum Role {
162    User,
163    Model,
164}
165
166#[derive(Debug, Serialize, Deserialize)]
167#[serde(untagged)]
168pub enum Part {
169    TextPart(TextPart),
170    InlineDataPart(InlineDataPart),
171    FunctionCallPart(FunctionCallPart),
172    FunctionResponsePart(FunctionResponsePart),
173    ThoughtPart(ThoughtPart),
174}
175
176#[derive(Debug, Serialize, Deserialize)]
177#[serde(rename_all = "camelCase")]
178pub struct TextPart {
179    pub text: String,
180}
181
182#[derive(Debug, Serialize, Deserialize)]
183#[serde(rename_all = "camelCase")]
184pub struct InlineDataPart {
185    pub inline_data: GenerativeContentBlob,
186}
187
188#[derive(Debug, Serialize, Deserialize)]
189#[serde(rename_all = "camelCase")]
190pub struct GenerativeContentBlob {
191    pub mime_type: String,
192    pub data: String,
193}
194
195#[derive(Debug, Serialize, Deserialize)]
196#[serde(rename_all = "camelCase")]
197pub struct FunctionCallPart {
198    pub function_call: FunctionCall,
199    /// Thought signature returned by the model for function calls.
200    /// Only present on the first function call in parallel call scenarios.
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub thought_signature: Option<String>,
203}
204
205#[derive(Debug, Serialize, Deserialize)]
206#[serde(rename_all = "camelCase")]
207pub struct FunctionResponsePart {
208    pub function_response: FunctionResponse,
209}
210
211#[derive(Debug, Serialize, Deserialize)]
212#[serde(rename_all = "camelCase")]
213pub struct ThoughtPart {
214    pub thought: bool,
215    pub thought_signature: String,
216}
217
218#[derive(Debug, Serialize, Deserialize)]
219#[serde(rename_all = "camelCase")]
220pub struct CitationSource {
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub start_index: Option<usize>,
223    #[serde(skip_serializing_if = "Option::is_none")]
224    pub end_index: Option<usize>,
225    #[serde(skip_serializing_if = "Option::is_none")]
226    pub uri: Option<String>,
227    #[serde(skip_serializing_if = "Option::is_none")]
228    pub license: Option<String>,
229}
230
231#[derive(Debug, Serialize, Deserialize)]
232#[serde(rename_all = "camelCase")]
233pub struct CitationMetadata {
234    pub citation_sources: Vec<CitationSource>,
235}
236
237#[derive(Debug, Serialize, Deserialize)]
238#[serde(rename_all = "camelCase")]
239pub struct PromptFeedback {
240    #[serde(skip_serializing_if = "Option::is_none")]
241    pub block_reason: Option<String>,
242    pub safety_ratings: Option<Vec<SafetyRating>>,
243    #[serde(skip_serializing_if = "Option::is_none")]
244    pub block_reason_message: Option<String>,
245}
246
247#[derive(Debug, Serialize, Deserialize, Default)]
248#[serde(rename_all = "camelCase")]
249pub struct UsageMetadata {
250    #[serde(skip_serializing_if = "Option::is_none")]
251    pub prompt_token_count: Option<u64>,
252    #[serde(skip_serializing_if = "Option::is_none")]
253    pub cached_content_token_count: Option<u64>,
254    #[serde(skip_serializing_if = "Option::is_none")]
255    pub candidates_token_count: Option<u64>,
256    #[serde(skip_serializing_if = "Option::is_none")]
257    pub tool_use_prompt_token_count: Option<u64>,
258    #[serde(skip_serializing_if = "Option::is_none")]
259    pub thoughts_token_count: Option<u64>,
260    #[serde(skip_serializing_if = "Option::is_none")]
261    pub total_token_count: Option<u64>,
262}
263
264#[derive(Debug, Serialize, Deserialize)]
265#[serde(rename_all = "camelCase")]
266pub struct ThinkingConfig {
267    pub thinking_budget: u32,
268}
269
270#[derive(Debug, Deserialize, Serialize)]
271#[serde(rename_all = "camelCase")]
272pub struct GenerationConfig {
273    #[serde(skip_serializing_if = "Option::is_none")]
274    pub candidate_count: Option<usize>,
275    #[serde(skip_serializing_if = "Option::is_none")]
276    pub stop_sequences: Option<Vec<String>>,
277    #[serde(skip_serializing_if = "Option::is_none")]
278    pub max_output_tokens: Option<usize>,
279    #[serde(skip_serializing_if = "Option::is_none")]
280    pub temperature: Option<f64>,
281    #[serde(skip_serializing_if = "Option::is_none")]
282    pub top_p: Option<f64>,
283    #[serde(skip_serializing_if = "Option::is_none")]
284    pub top_k: Option<usize>,
285    #[serde(skip_serializing_if = "Option::is_none")]
286    pub thinking_config: Option<ThinkingConfig>,
287}
288
289#[derive(Debug, Serialize, Deserialize)]
290#[serde(rename_all = "camelCase")]
291pub struct SafetySetting {
292    pub category: HarmCategory,
293    pub threshold: HarmBlockThreshold,
294}
295
296#[derive(Debug, Serialize, Deserialize)]
297pub enum HarmCategory {
298    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
299    Unspecified,
300    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
301    Derogatory,
302    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
303    Toxicity,
304    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
305    Violence,
306    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
307    Sexual,
308    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
309    Medical,
310    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
311    Dangerous,
312    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
313    Harassment,
314    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
315    HateSpeech,
316    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
317    SexuallyExplicit,
318    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
319    DangerousContent,
320}
321
322#[derive(Debug, Serialize, Deserialize)]
323#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
324pub enum HarmBlockThreshold {
325    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
326    Unspecified,
327    BlockLowAndAbove,
328    BlockMediumAndAbove,
329    BlockOnlyHigh,
330    BlockNone,
331}
332
333#[derive(Debug, Serialize, Deserialize)]
334#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
335pub enum HarmProbability {
336    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
337    Unspecified,
338    Negligible,
339    Low,
340    Medium,
341    High,
342}
343
344#[derive(Debug, Serialize, Deserialize)]
345#[serde(rename_all = "camelCase")]
346pub struct SafetyRating {
347    pub category: HarmCategory,
348    pub probability: HarmProbability,
349}
350
351#[derive(Debug, Serialize, Deserialize)]
352pub struct FunctionCall {
353    pub name: String,
354    pub args: serde_json::Value,
355}
356
357#[derive(Debug, Serialize, Deserialize)]
358pub struct FunctionResponse {
359    pub name: String,
360    pub response: serde_json::Value,
361}
362
363#[derive(Debug, Serialize, Deserialize)]
364#[serde(rename_all = "camelCase")]
365pub struct Tool {
366    pub function_declarations: Vec<FunctionDeclaration>,
367}
368
369#[derive(Debug, Serialize, Deserialize)]
370#[serde(rename_all = "camelCase")]
371pub struct ToolConfig {
372    pub function_calling_config: FunctionCallingConfig,
373}
374
375#[derive(Debug, Serialize, Deserialize)]
376#[serde(rename_all = "camelCase")]
377pub struct FunctionCallingConfig {
378    pub mode: FunctionCallingMode,
379    #[serde(skip_serializing_if = "Option::is_none")]
380    pub allowed_function_names: Option<Vec<String>>,
381}
382
383#[derive(Debug, Serialize, Deserialize)]
384#[serde(rename_all = "lowercase")]
385pub enum FunctionCallingMode {
386    Auto,
387    Any,
388    None,
389}
390
391#[derive(Debug, Serialize, Deserialize)]
392pub struct FunctionDeclaration {
393    pub name: String,
394    pub description: String,
395    pub parameters: serde_json::Value,
396}
397
/// A model identifier stored without the `models/` prefix used on the wire.
///
/// `Default` yields an empty name.
#[derive(Debug, Default)]
pub struct ModelName {
    /// Bare model id, i.e. the part after `models/`.
    pub model_id: String,
}

impl ModelName {
    /// Returns `true` when no model id is set.
    ///
    /// Also used as the `skip_serializing_if` predicate for the request's `model` field.
    pub fn is_empty(&self) -> bool {
        let Self { model_id } = self;
        model_id.is_empty()
    }
}
408
409const MODEL_NAME_PREFIX: &str = "models/";
410
411impl Serialize for ModelName {
412    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
413    where
414        S: Serializer,
415    {
416        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
417    }
418}
419
420impl<'de> Deserialize<'de> for ModelName {
421    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
422    where
423        D: Deserializer<'de>,
424    {
425        let string = String::deserialize(deserializer)?;
426        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
427            Ok(Self {
428                model_id: id.to_string(),
429            })
430        } else {
431            Err(serde::de::Error::custom(format!(
432                "Expected model name to begin with {}, got: {}",
433                MODEL_NAME_PREFIX, string
434            )))
435        }
436    }
437}
438
439#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
440#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
441pub enum Model {
442    #[serde(
443        rename = "gemini-2.5-flash-lite",
444        alias = "gemini-2.5-flash-lite-preview-06-17",
445        alias = "gemini-2.0-flash-lite-preview"
446    )]
447    Gemini25FlashLite,
448    #[serde(
449        rename = "gemini-2.5-flash",
450        alias = "gemini-2.0-flash-thinking-exp",
451        alias = "gemini-2.5-flash-preview-04-17",
452        alias = "gemini-2.5-flash-preview-05-20",
453        alias = "gemini-2.5-flash-preview-latest",
454        alias = "gemini-2.0-flash"
455    )]
456    #[default]
457    Gemini25Flash,
458    #[serde(
459        rename = "gemini-2.5-pro",
460        alias = "gemini-2.0-pro-exp",
461        alias = "gemini-2.5-pro-preview-latest",
462        alias = "gemini-2.5-pro-exp-03-25",
463        alias = "gemini-2.5-pro-preview-03-25",
464        alias = "gemini-2.5-pro-preview-05-06",
465        alias = "gemini-2.5-pro-preview-06-05"
466    )]
467    Gemini25Pro,
468    #[serde(rename = "gemini-3-flash-preview")]
469    Gemini3Flash,
470    #[serde(rename = "gemini-3.1-pro-preview", alias = "gemini-3-pro-preview")]
471    Gemini31Pro,
472    #[serde(rename = "custom")]
473    Custom {
474        name: String,
475        /// The name displayed in the UI, such as in the agent panel model dropdown menu.
476        display_name: Option<String>,
477        max_tokens: u64,
478        #[serde(default)]
479        mode: GoogleModelMode,
480    },
481}
482
483impl Model {
484    pub fn default_fast() -> Self {
485        Self::Gemini25FlashLite
486    }
487
488    pub fn id(&self) -> &str {
489        match self {
490            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
491            Self::Gemini25Flash => "gemini-2.5-flash",
492            Self::Gemini25Pro => "gemini-2.5-pro",
493            Self::Gemini3Flash => "gemini-3-flash-preview",
494            Self::Gemini31Pro => "gemini-3.1-pro-preview",
495            Self::Custom { name, .. } => name,
496        }
497    }
498    pub fn request_id(&self) -> &str {
499        match self {
500            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
501            Self::Gemini25Flash => "gemini-2.5-flash",
502            Self::Gemini25Pro => "gemini-2.5-pro",
503            Self::Gemini3Flash => "gemini-3-flash-preview",
504            Self::Gemini31Pro => "gemini-3.1-pro-preview",
505            Self::Custom { name, .. } => name,
506        }
507    }
508
509    pub fn display_name(&self) -> &str {
510        match self {
511            Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite",
512            Self::Gemini25Flash => "Gemini 2.5 Flash",
513            Self::Gemini25Pro => "Gemini 2.5 Pro",
514            Self::Gemini3Flash => "Gemini 3 Flash",
515            Self::Gemini31Pro => "Gemini 3.1 Pro",
516            Self::Custom {
517                name, display_name, ..
518            } => display_name.as_ref().unwrap_or(name),
519        }
520    }
521
522    pub fn max_token_count(&self) -> u64 {
523        match self {
524            Self::Gemini25FlashLite
525            | Self::Gemini25Flash
526            | Self::Gemini25Pro
527            | Self::Gemini3Flash
528            | Self::Gemini31Pro => 1_048_576,
529            Self::Custom { max_tokens, .. } => *max_tokens,
530        }
531    }
532
533    pub fn max_output_tokens(&self) -> Option<u64> {
534        match self {
535            Model::Gemini25FlashLite
536            | Model::Gemini25Flash
537            | Model::Gemini25Pro
538            | Model::Gemini3Flash
539            | Model::Gemini31Pro => Some(65_536),
540            Model::Custom { .. } => None,
541        }
542    }
543
544    pub fn supports_tools(&self) -> bool {
545        true
546    }
547
548    pub fn supports_images(&self) -> bool {
549        true
550    }
551
552    pub fn mode(&self) -> GoogleModelMode {
553        match self {
554            Self::Gemini25FlashLite | Self::Gemini25Flash | Self::Gemini25Pro => {
555                GoogleModelMode::Thinking {
556                    // By default these models are set to "auto", so we preserve that behavior
557                    // but indicate they are capable of thinking mode
558                    budget_tokens: None,
559                }
560            }
561            Self::Gemini3Flash => GoogleModelMode::Default,
562            Self::Gemini31Pro => GoogleModelMode::Thinking {
563                budget_tokens: None,
564            },
565            Self::Custom { mode, .. } => *mode,
566        }
567    }
568}
569
570impl std::fmt::Display for Model {
571    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
572        write!(f, "{}", self.id())
573    }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `FunctionCallPart` that calls `test_function` with the given arguments
    /// and optional thought signature.
    fn function_call_part(
        args: serde_json::Value,
        thought_signature: Option<&str>,
    ) -> FunctionCallPart {
        FunctionCallPart {
            function_call: FunctionCall {
                name: "test_function".to_string(),
                args,
            },
            thought_signature: thought_signature.map(String::from),
        }
    }

    #[test]
    fn test_function_call_part_with_signature_serializes_correctly() {
        let part = function_call_part(json!({"arg": "value"}), Some("test_signature"));

        let serialized = serde_json::to_value(&part).unwrap();

        assert_eq!(serialized["functionCall"]["name"], "test_function");
        assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
        assert_eq!(serialized["thoughtSignature"], "test_signature");
    }

    #[test]
    fn test_function_call_part_without_signature_omits_field() {
        let part = function_call_part(json!({"arg": "value"}), None);

        let serialized = serde_json::to_value(&part).unwrap();

        assert_eq!(serialized["functionCall"]["name"], "test_function");
        assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
        // The key must be absent altogether when the signature is `None`.
        assert!(serialized.get("thoughtSignature").is_none());
    }

    #[test]
    fn test_function_call_part_deserializes_with_signature() {
        let json = json!({
            "functionCall": {
                "name": "test_function",
                "args": {"arg": "value"}
            },
            "thoughtSignature": "test_signature"
        });

        let part: FunctionCallPart = serde_json::from_value(json).unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert_eq!(part.thought_signature, Some("test_signature".to_string()));
    }

    #[test]
    fn test_function_call_part_deserializes_without_signature() {
        let json = json!({
            "functionCall": {
                "name": "test_function",
                "args": {"arg": "value"}
            }
        });

        let part: FunctionCallPart = serde_json::from_value(json).unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert_eq!(part.thought_signature, None);
    }

    #[test]
    fn test_function_call_part_round_trip() {
        let original = function_call_part(
            json!({"arg": "value", "nested": {"key": "val"}}),
            Some("round_trip_signature"),
        );

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.function_call.name, original.function_call.name);
        assert_eq!(deserialized.function_call.args, original.function_call.args);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_function_call_part_with_empty_signature_serializes() {
        let part = function_call_part(json!({"arg": "value"}), Some(""));

        let serialized = serde_json::to_value(&part).unwrap();

        // Empty string should still be serialized (normalization happens at a higher level)
        assert_eq!(serialized["thoughtSignature"], "");
    }
}