google_ai.rs

  1use std::mem;
  2
  3use anyhow::{Result, anyhow, bail};
  4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6use serde::{Deserialize, Deserializer, Serialize, Serializer};
  7pub use settings::ModelMode as GoogleModelMode;
  8
/// Base URL of Google's Generative Language API.
///
/// The request functions in this file take the base URL as their `api_url` argument and
/// append the `/v1beta/models/...` path themselves; presumably callers pass this constant
/// unless a custom endpoint is configured (confirm at the call sites).
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
 10
 11pub async fn stream_generate_content(
 12    client: &dyn HttpClient,
 13    api_url: &str,
 14    api_key: &str,
 15    mut request: GenerateContentRequest,
 16) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
 17    let api_key = api_key.trim();
 18    validate_generate_content_request(&request)?;
 19
 20    // The `model` field is emptied as it is provided as a path parameter.
 21    let model_id = mem::take(&mut request.model.model_id);
 22
 23    let uri =
 24        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
 25
 26    let request_builder = HttpRequest::builder()
 27        .method(Method::POST)
 28        .uri(uri)
 29        .header("Content-Type", "application/json");
 30
 31    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
 32    let mut response = client.send(request).await?;
 33    if response.status().is_success() {
 34        let reader = BufReader::new(response.into_body());
 35        Ok(reader
 36            .lines()
 37            .filter_map(|line| async move {
 38                match line {
 39                    Ok(line) => {
 40                        if let Some(line) = line.strip_prefix("data: ") {
 41                            match serde_json::from_str(line) {
 42                                Ok(response) => Some(Ok(response)),
 43                                Err(error) => Some(Err(anyhow!(format!(
 44                                    "Error parsing JSON: {error:?}\n{line:?}"
 45                                )))),
 46                            }
 47                        } else {
 48                            None
 49                        }
 50                    }
 51                    Err(error) => Some(Err(anyhow!(error))),
 52                }
 53            })
 54            .boxed())
 55    } else {
 56        let mut text = String::new();
 57        response.body_mut().read_to_string(&mut text).await?;
 58        Err(anyhow!(
 59            "error during streamGenerateContent, status code: {:?}, body: {}",
 60            response.status(),
 61            text
 62        ))
 63    }
 64}
 65
 66pub async fn count_tokens(
 67    client: &dyn HttpClient,
 68    api_url: &str,
 69    api_key: &str,
 70    request: CountTokensRequest,
 71) -> Result<CountTokensResponse> {
 72    validate_generate_content_request(&request.generate_content_request)?;
 73
 74    let uri = format!(
 75        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
 76        model_id = &request.generate_content_request.model.model_id,
 77    );
 78
 79    let request = serde_json::to_string(&request)?;
 80    let request_builder = HttpRequest::builder()
 81        .method(Method::POST)
 82        .uri(&uri)
 83        .header("Content-Type", "application/json");
 84    let http_request = request_builder.body(AsyncBody::from(request))?;
 85
 86    let mut response = client.send(http_request).await?;
 87    let mut text = String::new();
 88    response.body_mut().read_to_string(&mut text).await?;
 89    anyhow::ensure!(
 90        response.status().is_success(),
 91        "error during countTokens, status code: {:?}, body: {}",
 92        response.status(),
 93        text
 94    );
 95    Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
 96}
 97
 98pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
 99    if request.model.is_empty() {
100        bail!("Model must be specified");
101    }
102
103    if request.contents.is_empty() {
104        bail!("Request must contain at least one content item");
105    }
106
107    if let Some(user_content) = request
108        .contents
109        .iter()
110        .find(|content| content.role == Role::User)
111        && user_content.parts.is_empty()
112    {
113        bail!("User content must contain at least one part");
114    }
115
116    Ok(())
117}
118
119#[derive(Debug, Serialize, Deserialize)]
120pub enum Task {
121    #[serde(rename = "generateContent")]
122    GenerateContent,
123    #[serde(rename = "streamGenerateContent")]
124    StreamGenerateContent,
125    #[serde(rename = "countTokens")]
126    CountTokens,
127    #[serde(rename = "embedContent")]
128    EmbedContent,
129    #[serde(rename = "batchEmbedContents")]
130    BatchEmbedContents,
131}
132
/// JSON body of a content-generation request.
///
/// Keys are camelCase; every optional field that is `None` is left out of the serialized
/// JSON.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    /// Target model. Omitted from the JSON when empty (`stream_generate_content` empties
    /// it because the model goes into the URL path) and defaulted when absent on input.
    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
    pub model: ModelName,
    /// Conversation entries; validation rejects an empty list.
    pub contents: Vec<Content>,
    /// Optional system instruction parts.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<SystemInstruction>,
    /// Optional generation parameters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    /// Optional per-category safety thresholds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    /// Optional tool (function) declarations.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Optional function-calling configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
}
150
/// One parsed response object; `stream_generate_content` yields one of these per
/// `data: ` line of the event stream.
///
/// Keys are camelCase. Every field is optional and omitted from serialized JSON when
/// `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    /// Generated candidates, if any were included.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    /// Feedback about the prompt, if included.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,
    /// Token accounting, if included.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
}
161
/// A single candidate inside a [`GenerateContentResponse`].
///
/// Keys are camelCase. `content` is the only required field; the optional ones are
/// omitted from serialized JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
    /// Position of this candidate, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
    /// Candidate content. Required: deserialization fails if the key is missing.
    pub content: Content,
    /// Finish reason, kept as the raw string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Message accompanying the finish reason, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_message: Option<String>,
    /// Safety ratings attached to this candidate, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_ratings: Option<Vec<SafetyRating>>,
    /// Citation information attached to this candidate, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
}
177
/// A message: an ordered list of parts attributed to a [`Role`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    /// Parts of the message; defaults to an empty list when the key is absent on input.
    #[serde(default)]
    pub parts: Vec<Part>,
    /// Author of the message.
    pub role: Role,
}
185
/// System instruction for a [`GenerateContentRequest`]: a list of parts without a role.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
    /// Parts that make up the instruction.
    pub parts: Vec<Part>,
}
191
/// Author of a [`Content`] entry, (de)serialized as `"user"` or `"model"`.
#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
    User,
    Model,
}
198
/// One piece of a [`Content`] entry.
///
/// The enum is untagged: no variant name appears in the JSON, and on input serde tries
/// the variants in declaration order and keeps the first one that deserializes. The
/// order below is therefore significant — for example, an object containing a string
/// `text` key resolves to [`Part::TextPart`] before later variants are considered.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
    FunctionCallPart(FunctionCallPart),
    FunctionResponsePart(FunctionResponsePart),
    ThoughtPart(ThoughtPart),
}
208
/// A part holding plain text under the `text` key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
    /// The text content.
    pub text: String,
}
214
/// A part holding an inline blob under the `inlineData` key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
    /// The embedded blob.
    pub inline_data: GenerativeContentBlob,
}
220
/// Blob payload of an [`InlineDataPart`], serialized with the keys `mimeType` and `data`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
    /// MIME type of the payload.
    pub mime_type: String,
    /// Payload as a string; presumably base64-encoded — confirm against the API reference.
    pub data: String,
}
227
/// A part holding a function call under the `functionCall` key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
    /// The function name and its arguments.
    pub function_call: FunctionCall,
}
233
/// A part holding a function result under the `functionResponse` key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
    /// The function name and its response payload.
    pub function_response: FunctionResponse,
}
239
/// A part carrying a thought marker, serialized with the keys `thought` and
/// `thoughtSignature`.
///
/// Both fields are required, so JSON lacking either key does not deserialize as this
/// type.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThoughtPart {
    /// Thought flag as provided in the JSON.
    pub thought: bool,
    /// Signature string accompanying the thought; treated as opaque by this file.
    pub thought_signature: String,
}
246
/// One citation entry of [`CitationMetadata`].
///
/// Keys are camelCase; every field is optional and omitted from serialized JSON when
/// `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    /// Start of the cited span (unit and inclusivity are not established here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_index: Option<usize>,
    /// End of the cited span (unit and inclusivity are not established here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_index: Option<usize>,
    /// URI of the cited source.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    /// License of the cited source.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
}
259
/// Citation information of a candidate, serialized under the key `citationSources`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    /// The individual citations; the key is required on input.
    pub citation_sources: Vec<CitationSource>,
}
265
266#[derive(Debug, Serialize, Deserialize)]
267#[serde(rename_all = "camelCase")]
268pub struct PromptFeedback {
269    #[serde(skip_serializing_if = "Option::is_none")]
270    pub block_reason: Option<String>,
271    pub safety_ratings: Option<Vec<SafetyRating>>,
272    #[serde(skip_serializing_if = "Option::is_none")]
273    pub block_reason_message: Option<String>,
274}
275
/// Token counts reported with a response.
///
/// Keys are camelCase; every counter is optional, omitted from serialized JSON when
/// `None`, and `Default` yields all `None`. How the counters relate to one another is
/// not established in this file — see the API reference.
#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    /// `promptTokenCount`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_token_count: Option<u64>,
    /// `cachedContentTokenCount`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<u64>,
    /// `candidatesTokenCount`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<u64>,
    /// `toolUsePromptTokenCount`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_token_count: Option<u64>,
    /// `thoughtsTokenCount`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<u64>,
    /// `totalTokenCount`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_token_count: Option<u64>,
}
292
/// Thinking settings nested in [`GenerationConfig`], serialized under the key
/// `thinkingBudget`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    /// Thinking budget; presumably a token count — confirm against the API reference.
    pub thinking_budget: u32,
}
298
/// Generation parameters for a [`GenerateContentRequest`].
///
/// Keys are camelCase; every field is optional and omitted from serialized JSON when
/// `None`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    /// `candidateCount`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<usize>,
    /// `stopSequences`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    /// `maxOutputTokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    /// `temperature`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// `topP`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// `topK`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
    /// `thinkingConfig`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}
317
/// Pairs a harm category with the blocking threshold to apply to it.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    /// Category the setting applies to.
    pub category: HarmCategory,
    /// Threshold for that category.
    pub threshold: HarmBlockThreshold,
}
324
/// Harm categories, each (de)serialized as the `HARM_CATEGORY_*` string given in its
/// `rename` attribute.
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}
350
351#[derive(Debug, Serialize, Deserialize)]
352#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
353pub enum HarmBlockThreshold {
354    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
355    Unspecified,
356    BlockLowAndAbove,
357    BlockMediumAndAbove,
358    BlockOnlyHigh,
359    BlockNone,
360}
361
362#[derive(Debug, Serialize, Deserialize)]
363#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
364pub enum HarmProbability {
365    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
366    Unspecified,
367    Negligible,
368    Low,
369    Medium,
370    High,
371}
372
/// A harm category together with the probability reported for it.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    /// Category being rated.
    pub category: HarmCategory,
    /// Probability reported for that category.
    pub probability: HarmProbability,
}
379
/// JSON body sent by `count_tokens`: the generation request whose tokens should be
/// counted, wrapped under the key `generateContentRequest`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
    /// The wrapped request; `count_tokens` also reads the model id for the URL from it.
    pub generate_content_request: GenerateContentRequest,
}
385
/// Parsed body of a successful `count_tokens` call.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
    /// Value of the required `totalTokens` key.
    pub total_tokens: u64,
}
391
/// Payload of a [`FunctionCallPart`]: which function is called and with what arguments.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the function.
    pub name: String,
    /// Arguments as arbitrary JSON; no structure is enforced by this type.
    pub args: serde_json::Value,
}
397
/// Payload of a [`FunctionResponsePart`]: the result reported for a named function.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
    /// Name of the function.
    pub name: String,
    /// Result as arbitrary JSON; no structure is enforced by this type.
    pub response: serde_json::Value,
}
403
/// A set of function declarations, serialized under the key `functionDeclarations`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    /// The declared functions.
    pub function_declarations: Vec<FunctionDeclaration>,
}
409
/// Tool configuration of a [`GenerateContentRequest`], serialized under the key
/// `functionCallingConfig`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    /// Function-calling settings.
    pub function_calling_config: FunctionCallingConfig,
}
415
/// Function-calling settings nested in [`ToolConfig`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    /// Selected calling mode.
    pub mode: FunctionCallingMode,
    /// Optional list of function names (`allowedFunctionNames`), omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}
423
/// Function-calling mode, (de)serialized in lowercase as `"auto"`, `"any"`, or `"none"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FunctionCallingMode {
    Auto,
    Any,
    None,
}
431
/// Declaration of one callable function inside a [`Tool`]. All three keys are required.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    /// Name of the function.
    pub name: String,
    /// Description of the function.
    pub description: String,
    /// Parameter description as arbitrary JSON; presumably a JSON-schema-like object —
    /// confirm against the API reference.
    pub parameters: serde_json::Value,
}
438
/// Name of a model, stored as the bare id without the `models/` wire prefix.
///
/// `Default` yields an empty id, which [`ModelName::is_empty`] reports as empty.
#[derive(Default, Debug)]
pub struct ModelName {
    /// Bare model id, e.g. the part after `models/` in the serialized form.
    pub model_id: String,
}

impl ModelName {
    /// Returns `true` when no model id has been set.
    pub fn is_empty(&self) -> bool {
        let Self { model_id } = self;
        model_id.is_empty()
    }
}
449
/// Prefix that `ModelName` prepends to the id when serializing and requires (then strips)
/// when deserializing.
const MODEL_NAME_PREFIX: &str = "models/";
451
452impl Serialize for ModelName {
453    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
454    where
455        S: Serializer,
456    {
457        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
458    }
459}
460
461impl<'de> Deserialize<'de> for ModelName {
462    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
463    where
464        D: Deserializer<'de>,
465    {
466        let string = String::deserialize(deserializer)?;
467        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
468            Ok(Self {
469                model_id: id.to_string(),
470            })
471        } else {
472            Err(serde::de::Error::custom(format!(
473                "Expected model name to begin with {}, got: {}",
474                MODEL_NAME_PREFIX, string
475            )))
476        }
477    }
478}
479
// Catalogue of known Gemini models plus a user-defined `Custom` entry.
//
// Each built-in variant serializes as its `rename` string; every `alias` is additionally
// accepted when deserializing and maps onto the same variant. `Gemini25Flash` is the
// `Default` variant.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
    #[serde(rename = "gemini-1.5-pro")]
    Gemini15Pro,
    #[serde(rename = "gemini-1.5-flash-8b")]
    Gemini15Flash8b,
    #[serde(rename = "gemini-1.5-flash")]
    Gemini15Flash,
    #[serde(
        rename = "gemini-2.0-flash-lite",
        alias = "gemini-2.0-flash-lite-preview"
    )]
    Gemini20FlashLite,
    #[serde(rename = "gemini-2.0-flash")]
    Gemini20Flash,
    #[serde(
        rename = "gemini-2.5-flash-lite-preview",
        alias = "gemini-2.5-flash-lite-preview-06-17"
    )]
    Gemini25FlashLitePreview,
    // Older preview/experimental names are kept as aliases so they still deserialize.
    #[serde(
        rename = "gemini-2.5-flash",
        alias = "gemini-2.0-flash-thinking-exp",
        alias = "gemini-2.5-flash-preview-04-17",
        alias = "gemini-2.5-flash-preview-05-20",
        alias = "gemini-2.5-flash-preview-latest"
    )]
    #[default]
    Gemini25Flash,
    // Older preview/experimental names are kept as aliases so they still deserialize.
    #[serde(
        rename = "gemini-2.5-pro",
        alias = "gemini-2.0-pro-exp",
        alias = "gemini-2.5-pro-preview-latest",
        alias = "gemini-2.5-pro-exp-03-25",
        alias = "gemini-2.5-pro-preview-03-25",
        alias = "gemini-2.5-pro-preview-05-06",
        alias = "gemini-2.5-pro-preview-06-05"
    )]
    Gemini25Pro,
    // User-defined model: `name` doubles as its id, `max_tokens` is what
    // `max_token_count` reports, and `mode` falls back to its `Default` when absent.
    #[serde(rename = "custom")]
    Custom {
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        max_tokens: u64,
        #[serde(default)]
        mode: GoogleModelMode,
    },
}
530
531impl Model {
532    pub fn default_fast() -> Self {
533        Self::Gemini20FlashLite
534    }
535
536    pub fn id(&self) -> &str {
537        match self {
538            Self::Gemini15Pro => "gemini-1.5-pro",
539            Self::Gemini15Flash8b => "gemini-1.5-flash-8b",
540            Self::Gemini15Flash => "gemini-1.5-flash",
541            Self::Gemini20FlashLite => "gemini-2.0-flash-lite",
542            Self::Gemini20Flash => "gemini-2.0-flash",
543            Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview",
544            Self::Gemini25Flash => "gemini-2.5-flash",
545            Self::Gemini25Pro => "gemini-2.5-pro",
546            Self::Custom { name, .. } => name,
547        }
548    }
549    pub fn request_id(&self) -> &str {
550        match self {
551            Self::Gemini15Pro => "gemini-1.5-pro",
552            Self::Gemini15Flash8b => "gemini-1.5-flash-8b",
553            Self::Gemini15Flash => "gemini-1.5-flash",
554            Self::Gemini20FlashLite => "gemini-2.0-flash-lite",
555            Self::Gemini20Flash => "gemini-2.0-flash",
556            Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview-06-17",
557            Self::Gemini25Flash => "gemini-2.5-flash",
558            Self::Gemini25Pro => "gemini-2.5-pro",
559            Self::Custom { name, .. } => name,
560        }
561    }
562
563    pub fn display_name(&self) -> &str {
564        match self {
565            Self::Gemini15Pro => "Gemini 1.5 Pro",
566            Self::Gemini15Flash8b => "Gemini 1.5 Flash-8b",
567            Self::Gemini15Flash => "Gemini 1.5 Flash",
568            Self::Gemini20FlashLite => "Gemini 2.0 Flash-Lite",
569            Self::Gemini20Flash => "Gemini 2.0 Flash",
570            Self::Gemini25FlashLitePreview => "Gemini 2.5 Flash-Lite Preview",
571            Self::Gemini25Flash => "Gemini 2.5 Flash",
572            Self::Gemini25Pro => "Gemini 2.5 Pro",
573            Self::Custom {
574                name, display_name, ..
575            } => display_name.as_ref().unwrap_or(name),
576        }
577    }
578
579    pub fn max_token_count(&self) -> u64 {
580        match self {
581            Self::Gemini15Pro => 2_097_152,
582            Self::Gemini15Flash8b => 1_048_576,
583            Self::Gemini15Flash => 1_048_576,
584            Self::Gemini20FlashLite => 1_048_576,
585            Self::Gemini20Flash => 1_048_576,
586            Self::Gemini25FlashLitePreview => 1_000_000,
587            Self::Gemini25Flash => 1_048_576,
588            Self::Gemini25Pro => 1_048_576,
589            Self::Custom { max_tokens, .. } => *max_tokens,
590        }
591    }
592
593    pub fn max_output_tokens(&self) -> Option<u64> {
594        match self {
595            Model::Gemini15Pro => Some(8_192),
596            Model::Gemini15Flash8b => Some(8_192),
597            Model::Gemini15Flash => Some(8_192),
598            Model::Gemini20FlashLite => Some(8_192),
599            Model::Gemini20Flash => Some(8_192),
600            Model::Gemini25FlashLitePreview => Some(64_000),
601            Model::Gemini25Flash => Some(65_536),
602            Model::Gemini25Pro => Some(65_536),
603            Model::Custom { .. } => None,
604        }
605    }
606
607    pub fn supports_tools(&self) -> bool {
608        true
609    }
610
611    pub fn supports_images(&self) -> bool {
612        true
613    }
614
615    pub fn mode(&self) -> GoogleModelMode {
616        match self {
617            Self::Gemini15Pro
618            | Self::Gemini15Flash8b
619            | Self::Gemini15Flash
620            | Self::Gemini20FlashLite
621            | Self::Gemini20Flash => GoogleModelMode::Default,
622            Self::Gemini25FlashLitePreview | Self::Gemini25Flash | Self::Gemini25Pro => {
623                GoogleModelMode::Thinking {
624                    // By default these models are set to "auto", so we preserve that behavior
625                    // but indicate they are capable of thinking mode
626                    budget_tokens: None,
627                }
628            }
629            Self::Custom { mode, .. } => *mode,
630        }
631    }
632}
633
634impl std::fmt::Display for Model {
635    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
636        write!(f, "{}", self.id())
637    }
638}