1use std::mem;
2
3use anyhow::{Result, anyhow, bail};
4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7pub use settings::ModelMode as GoogleModelMode;
8
/// Base URL of Google's Generative Language (Gemini) API.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
10
11pub async fn stream_generate_content(
12 client: &dyn HttpClient,
13 api_url: &str,
14 api_key: &str,
15 mut request: GenerateContentRequest,
16) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
17 let api_key = api_key.trim();
18 validate_generate_content_request(&request)?;
19
20 // The `model` field is emptied as it is provided as a path parameter.
21 let model_id = mem::take(&mut request.model.model_id);
22
23 let uri =
24 format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
25
26 let request_builder = HttpRequest::builder()
27 .method(Method::POST)
28 .uri(uri)
29 .header("Content-Type", "application/json");
30
31 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
32 let mut response = client.send(request).await?;
33 if response.status().is_success() {
34 let reader = BufReader::new(response.into_body());
35 Ok(reader
36 .lines()
37 .filter_map(|line| async move {
38 match line {
39 Ok(line) => {
40 if let Some(line) = line.strip_prefix("data: ") {
41 match serde_json::from_str(line) {
42 Ok(response) => Some(Ok(response)),
43 Err(error) => Some(Err(anyhow!(format!(
44 "Error parsing JSON: {error:?}\n{line:?}"
45 )))),
46 }
47 } else {
48 None
49 }
50 }
51 Err(error) => Some(Err(anyhow!(error))),
52 }
53 })
54 .boxed())
55 } else {
56 let mut text = String::new();
57 response.body_mut().read_to_string(&mut text).await?;
58 Err(anyhow!(
59 "error during streamGenerateContent, status code: {:?}, body: {}",
60 response.status(),
61 text
62 ))
63 }
64}
65
66pub async fn count_tokens(
67 client: &dyn HttpClient,
68 api_url: &str,
69 api_key: &str,
70 request: CountTokensRequest,
71) -> Result<CountTokensResponse> {
72 validate_generate_content_request(&request.generate_content_request)?;
73
74 let uri = format!(
75 "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
76 model_id = &request.generate_content_request.model.model_id,
77 );
78
79 let request = serde_json::to_string(&request)?;
80 let request_builder = HttpRequest::builder()
81 .method(Method::POST)
82 .uri(&uri)
83 .header("Content-Type", "application/json");
84 let http_request = request_builder.body(AsyncBody::from(request))?;
85
86 let mut response = client.send(http_request).await?;
87 let mut text = String::new();
88 response.body_mut().read_to_string(&mut text).await?;
89 anyhow::ensure!(
90 response.status().is_success(),
91 "error during countTokens, status code: {:?}, body: {}",
92 response.status(),
93 text
94 );
95 Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
96}
97
98pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
99 if request.model.is_empty() {
100 bail!("Model must be specified");
101 }
102
103 if request.contents.is_empty() {
104 bail!("Request must contain at least one content item");
105 }
106
107 if let Some(user_content) = request
108 .contents
109 .iter()
110 .find(|content| content.role == Role::User)
111 && user_content.parts.is_empty()
112 {
113 bail!("User content must contain at least one part");
114 }
115
116 Ok(())
117}
118
119#[derive(Debug, Serialize, Deserialize)]
120pub enum Task {
121 #[serde(rename = "generateContent")]
122 GenerateContent,
123 #[serde(rename = "streamGenerateContent")]
124 StreamGenerateContent,
125 #[serde(rename = "countTokens")]
126 CountTokens,
127 #[serde(rename = "embedContent")]
128 EmbedContent,
129 #[serde(rename = "batchEmbedContents")]
130 BatchEmbedContents,
131}
132
/// Request body for the `generateContent` family of endpoints.
///
/// Field names are serialized in camelCase. Optional fields are omitted from the JSON when
/// `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    /// Target model. Omitted from the JSON when empty, and defaults to empty when absent.
    /// [`stream_generate_content`] empties it because the id is sent in the URL path instead.
    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
    pub model: ModelName,
    /// Content items, each tagged with a [`Role`].
    pub contents: Vec<Content>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<SystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
}
150
/// A response object from the generate endpoints.
///
/// When streaming, each SSE `data` payload is parsed into one of these. Every field is optional
/// and omitted on serialization when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
}
161
/// One candidate completion within a [`GenerateContentResponse`].
///
/// `content` is required when deserializing; every other field is optional.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
    pub content: Content,
    /// Kept as a raw string rather than an enum.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
}
177
/// A single content item authored by a [`Role`], made of one or more [`Part`]s.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    /// Defaults to an empty list when the field is missing from the JSON.
    #[serde(default)]
    pub parts: Vec<Part>,
    pub role: Role,
}
185
/// System instruction parts, sent as `systemInstruction` in a [`GenerateContentRequest`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
    pub parts: Vec<Part>,
}
191
/// Author of a [`Content`] item; serialized as `"user"` or `"model"`.
#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
    User,
    Model,
}
198
/// One piece of a [`Content`] item.
///
/// The enum is `untagged`:
/// - Serialization writes the inner struct's fields directly, with no variant tag.
/// - Deserialization tries the variants in declaration order and keeps the first that matches.
///   Unknown fields are ignored, so an object with `text` plus extra keys matches [`TextPart`]
///   before any later variant. Do not reorder the variants casually.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
    FunctionCallPart(FunctionCallPart),
    FunctionResponsePart(FunctionResponsePart),
    ThoughtPart(ThoughtPart),
}
208
/// A plain text part: `{"text": ...}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
    pub text: String,
}
214
/// Inline binary data: `{"inlineData": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
    pub inline_data: GenerativeContentBlob,
}
220
/// Payload of an [`InlineDataPart`], serialized as `mimeType` and `data`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
    pub mime_type: String,
    /// NOTE(review): presumably base64-encoded bytes; confirm with the callers that build it.
    pub data: String,
}
227
/// A function call emitted by the model: `{"functionCall": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
    pub function_call: FunctionCall,
}
233
/// The result of a function call sent back to the model: `{"functionResponse": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
    pub function_response: FunctionResponse,
}
239
/// A thought marker, serialized as `thought` and `thoughtSignature`.
///
/// Both fields are required when deserializing. Because [`Part`] is untagged and lists this
/// variant last, an object that also matches an earlier variant never becomes a `ThoughtPart`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThoughtPart {
    pub thought: bool,
    pub thought_signature: String,
}
246
/// One cited source within [`CitationMetadata`]; every field is optional.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
}
259
/// Citations attached to a [`GenerateContentCandidate`], serialized as `citationSources`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    pub citation_sources: Vec<CitationSource>,
}
265
266#[derive(Debug, Serialize, Deserialize)]
267#[serde(rename_all = "camelCase")]
268pub struct PromptFeedback {
269 #[serde(skip_serializing_if = "Option::is_none")]
270 pub block_reason: Option<String>,
271 pub safety_ratings: Option<Vec<SafetyRating>>,
272 #[serde(skip_serializing_if = "Option::is_none")]
273 pub block_reason_message: Option<String>,
274}
275
/// Token accounting reported in a [`GenerateContentResponse`].
///
/// Every count is optional (and omitted on serialization when `None`). `Default` yields all
/// `None`.
#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_token_count: Option<u64>,
}
292
/// Thinking settings nested in [`GenerationConfig`], serialized as `thinkingBudget`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    pub thinking_budget: u32,
}
298
/// Generation parameters, sent as `generationConfig` in a [`GenerateContentRequest`].
///
/// Field names are camelCase on the wire. Every field is optional and omitted when `None`, so
/// the server's defaults apply.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}
317
/// Blocking threshold for one harm category, sent in `safetySettings`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    pub category: HarmCategory,
    pub threshold: HarmBlockThreshold,
}
324
/// Harm categories, serialized under their `HARM_CATEGORY_*` API names.
///
/// NOTE(review): there is no catch-all variant. Deserializing a category name not listed here
/// (e.g. one added to the API later) fails, and so does the whole containing response.
/// Consider whether a `#[serde(other)]` fallback is wanted.
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}
350
/// Blocking threshold for a [`SafetySetting`].
///
/// Variants serialize in SCREAMING_SNAKE_CASE (e.g. `BLOCK_LOW_AND_ABOVE`).
/// `Unspecified` is explicitly renamed to `HARM_BLOCK_THRESHOLD_UNSPECIFIED`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
    Unspecified,
    BlockLowAndAbove,
    BlockMediumAndAbove,
    BlockOnlyHigh,
    BlockNone,
}
361
/// Probability level in a [`SafetyRating`].
///
/// Variants serialize in SCREAMING_SNAKE_CASE (`NEGLIGIBLE`, `LOW`, ...). `Unspecified` is
/// explicitly renamed to `HARM_PROBABILITY_UNSPECIFIED`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
    Unspecified,
    Negligible,
    Low,
    Medium,
    High,
}
372
/// A safety rating for one [`HarmCategory`], as reported on candidates and prompt feedback.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    pub category: HarmCategory,
    pub probability: HarmProbability,
}
379
/// Body of a `countTokens` call: a full [`GenerateContentRequest`] nested under
/// `generateContentRequest`.
///
/// [`count_tokens`] leaves the model name inside the nested request.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
    pub generate_content_request: GenerateContentRequest,
}
385
/// Result of a `countTokens` call; `totalTokens` is required.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
    pub total_tokens: u64,
}
391
/// A function invocation requested by the model.
///
/// `args` is kept as untyped JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub args: serde_json::Value,
}
397
/// The output of a function, returned to the model under the function's `name`.
///
/// `response` is kept as untyped JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
    pub name: String,
    pub response: serde_json::Value,
}
403
/// A tool entry offering function declarations to the model, serialized as
/// `functionDeclarations`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    pub function_declarations: Vec<FunctionDeclaration>,
}
409
/// Tool configuration, sent as `toolConfig` in a [`GenerateContentRequest`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    pub function_calling_config: FunctionCallingConfig,
}
415
/// Controls how the model may call functions.
///
/// `allowedFunctionNames` is omitted when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    pub mode: FunctionCallingMode,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}
423
/// Function-calling mode; serialized in lowercase as `"auto"`, `"any"` or `"none"`.
///
/// NOTE(review): the REST reference lists these values in upper case. Lowercase is presumably
/// accepted by the server; confirm before relying on strict parsing elsewhere.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FunctionCallingMode {
    Auto,
    Any,
    None,
}
431
/// A function the model is allowed to call.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    /// Untyped JSON. NOTE(review): presumably a JSON-Schema object describing the arguments;
    /// confirm.
    pub parameters: serde_json::Value,
}
438
/// A model identifier stored without its `models/` prefix (e.g. `gemini-2.5-flash`).
///
/// The prefix is added when serializing and required when deserializing. `Default` is an
/// empty id.
#[derive(Debug, Default)]
pub struct ModelName {
    pub model_id: String,
}

impl ModelName {
    /// Returns `true` when no model id has been set.
    pub fn is_empty(&self) -> bool {
        let Self { model_id } = self;
        model_id.is_empty()
    }
}
449
/// Prefix added when serializing a [`ModelName`] and required when deserializing one.
const MODEL_NAME_PREFIX: &str = "models/";
451
452impl Serialize for ModelName {
453 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
454 where
455 S: Serializer,
456 {
457 serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
458 }
459}
460
461impl<'de> Deserialize<'de> for ModelName {
462 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
463 where
464 D: Deserializer<'de>,
465 {
466 let string = String::deserialize(deserializer)?;
467 if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
468 Ok(Self {
469 model_id: id.to_string(),
470 })
471 } else {
472 Err(serde::de::Error::custom(format!(
473 "Expected model name to begin with {}, got: {}",
474 MODEL_NAME_PREFIX, string
475 )))
476 }
477 }
478}
479
/// A built-in Gemini model, or a user-configured [`Model::Custom`] one.
///
/// Built-in variants serialize to their `rename` id. Each `alias` is also accepted when
/// deserializing and maps onto the same variant, while serialization always writes the
/// `rename` value.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
    #[serde(rename = "gemini-1.5-pro")]
    Gemini15Pro,
    #[serde(rename = "gemini-1.5-flash-8b")]
    Gemini15Flash8b,
    #[serde(rename = "gemini-1.5-flash")]
    Gemini15Flash,
    #[serde(
        rename = "gemini-2.0-flash-lite",
        alias = "gemini-2.0-flash-lite-preview"
    )]
    Gemini20FlashLite,
    #[serde(rename = "gemini-2.0-flash")]
    Gemini20Flash,
    /// Serialized as `gemini-2.5-flash-lite-preview`, but requested under a dated id
    /// (see [`Model::request_id`]).
    #[serde(
        rename = "gemini-2.5-flash-lite-preview",
        alias = "gemini-2.5-flash-lite-preview-06-17"
    )]
    Gemini25FlashLitePreview,
    /// The `Default` variant.
    #[serde(
        rename = "gemini-2.5-flash",
        alias = "gemini-2.0-flash-thinking-exp",
        alias = "gemini-2.5-flash-preview-04-17",
        alias = "gemini-2.5-flash-preview-05-20",
        alias = "gemini-2.5-flash-preview-latest"
    )]
    #[default]
    Gemini25Flash,
    #[serde(
        rename = "gemini-2.5-pro",
        alias = "gemini-2.0-pro-exp",
        alias = "gemini-2.5-pro-preview-latest",
        alias = "gemini-2.5-pro-exp-03-25",
        alias = "gemini-2.5-pro-preview-03-25",
        alias = "gemini-2.5-pro-preview-05-06",
        alias = "gemini-2.5-pro-preview-06-05"
    )]
    Gemini25Pro,
    /// A model described entirely by configuration; serialized under the `custom` tag.
    #[serde(rename = "custom")]
    Custom {
        /// Model id, returned by both [`Model::id`] and [`Model::request_id`].
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Context size reported by [`Model::max_token_count`].
        max_tokens: u64,
        /// Falls back to `GoogleModelMode`'s `Default` when omitted.
        #[serde(default)]
        mode: GoogleModelMode,
    },
}
530
531impl Model {
532 pub fn default_fast() -> Self {
533 Self::Gemini20FlashLite
534 }
535
536 pub fn id(&self) -> &str {
537 match self {
538 Self::Gemini15Pro => "gemini-1.5-pro",
539 Self::Gemini15Flash8b => "gemini-1.5-flash-8b",
540 Self::Gemini15Flash => "gemini-1.5-flash",
541 Self::Gemini20FlashLite => "gemini-2.0-flash-lite",
542 Self::Gemini20Flash => "gemini-2.0-flash",
543 Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview",
544 Self::Gemini25Flash => "gemini-2.5-flash",
545 Self::Gemini25Pro => "gemini-2.5-pro",
546 Self::Custom { name, .. } => name,
547 }
548 }
549 pub fn request_id(&self) -> &str {
550 match self {
551 Self::Gemini15Pro => "gemini-1.5-pro",
552 Self::Gemini15Flash8b => "gemini-1.5-flash-8b",
553 Self::Gemini15Flash => "gemini-1.5-flash",
554 Self::Gemini20FlashLite => "gemini-2.0-flash-lite",
555 Self::Gemini20Flash => "gemini-2.0-flash",
556 Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview-06-17",
557 Self::Gemini25Flash => "gemini-2.5-flash",
558 Self::Gemini25Pro => "gemini-2.5-pro",
559 Self::Custom { name, .. } => name,
560 }
561 }
562
563 pub fn display_name(&self) -> &str {
564 match self {
565 Self::Gemini15Pro => "Gemini 1.5 Pro",
566 Self::Gemini15Flash8b => "Gemini 1.5 Flash-8b",
567 Self::Gemini15Flash => "Gemini 1.5 Flash",
568 Self::Gemini20FlashLite => "Gemini 2.0 Flash-Lite",
569 Self::Gemini20Flash => "Gemini 2.0 Flash",
570 Self::Gemini25FlashLitePreview => "Gemini 2.5 Flash-Lite Preview",
571 Self::Gemini25Flash => "Gemini 2.5 Flash",
572 Self::Gemini25Pro => "Gemini 2.5 Pro",
573 Self::Custom {
574 name, display_name, ..
575 } => display_name.as_ref().unwrap_or(name),
576 }
577 }
578
579 pub fn max_token_count(&self) -> u64 {
580 match self {
581 Self::Gemini15Pro => 2_097_152,
582 Self::Gemini15Flash8b => 1_048_576,
583 Self::Gemini15Flash => 1_048_576,
584 Self::Gemini20FlashLite => 1_048_576,
585 Self::Gemini20Flash => 1_048_576,
586 Self::Gemini25FlashLitePreview => 1_000_000,
587 Self::Gemini25Flash => 1_048_576,
588 Self::Gemini25Pro => 1_048_576,
589 Self::Custom { max_tokens, .. } => *max_tokens,
590 }
591 }
592
593 pub fn max_output_tokens(&self) -> Option<u64> {
594 match self {
595 Model::Gemini15Pro => Some(8_192),
596 Model::Gemini15Flash8b => Some(8_192),
597 Model::Gemini15Flash => Some(8_192),
598 Model::Gemini20FlashLite => Some(8_192),
599 Model::Gemini20Flash => Some(8_192),
600 Model::Gemini25FlashLitePreview => Some(64_000),
601 Model::Gemini25Flash => Some(65_536),
602 Model::Gemini25Pro => Some(65_536),
603 Model::Custom { .. } => None,
604 }
605 }
606
607 pub fn supports_tools(&self) -> bool {
608 true
609 }
610
611 pub fn supports_images(&self) -> bool {
612 true
613 }
614
615 pub fn mode(&self) -> GoogleModelMode {
616 match self {
617 Self::Gemini15Pro
618 | Self::Gemini15Flash8b
619 | Self::Gemini15Flash
620 | Self::Gemini20FlashLite
621 | Self::Gemini20Flash => GoogleModelMode::Default,
622 Self::Gemini25FlashLitePreview | Self::Gemini25Flash | Self::Gemini25Pro => {
623 GoogleModelMode::Thinking {
624 // By default these models are set to "auto", so we preserve that behavior
625 // but indicate they are capable of thinking mode
626 budget_tokens: None,
627 }
628 }
629 Self::Custom { mode, .. } => *mode,
630 }
631 }
632}
633
634impl std::fmt::Display for Model {
635 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
636 write!(f, "{}", self.id())
637 }
638}