1use std::mem;
2
3use anyhow::{Result, anyhow, bail};
4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6pub use language_model_core::ModelMode as GoogleModelMode;
7use serde::{Deserialize, Deserializer, Serialize, Serializer};
8pub mod completion;
9
10pub const API_URL: &str = "https://generativelanguage.googleapis.com";
11
12pub async fn stream_generate_content(
13 client: &dyn HttpClient,
14 api_url: &str,
15 api_key: &str,
16 mut request: GenerateContentRequest,
17) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
18 let api_key = api_key.trim();
19 validate_generate_content_request(&request)?;
20
21 // The `model` field is emptied as it is provided as a path parameter.
22 let model_id = mem::take(&mut request.model.model_id);
23
24 let uri =
25 format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
26
27 let request_builder = HttpRequest::builder()
28 .method(Method::POST)
29 .uri(uri)
30 .header("Content-Type", "application/json");
31
32 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
33 let mut response = client.send(request).await?;
34 if response.status().is_success() {
35 let reader = BufReader::new(response.into_body());
36 Ok(reader
37 .lines()
38 .filter_map(|line| async move {
39 match line {
40 Ok(line) => {
41 if let Some(line) = line.strip_prefix("data: ") {
42 match serde_json::from_str(line) {
43 Ok(response) => Some(Ok(response)),
44 Err(error) => Some(Err(anyhow!(format!(
45 "Error parsing JSON: {error:?}\n{line:?}"
46 )))),
47 }
48 } else {
49 None
50 }
51 }
52 Err(error) => Some(Err(anyhow!(error))),
53 }
54 })
55 .boxed())
56 } else {
57 let mut text = String::new();
58 response.body_mut().read_to_string(&mut text).await?;
59 Err(anyhow!(
60 "error during streamGenerateContent, status code: {:?}, body: {}",
61 response.status(),
62 text
63 ))
64 }
65}
66
67pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
68 if request.model.is_empty() {
69 bail!("Model must be specified");
70 }
71
72 if request.contents.is_empty() {
73 bail!("Request must contain at least one content item");
74 }
75
76 if let Some(user_content) = request
77 .contents
78 .iter()
79 .find(|content| content.role == Role::User)
80 && user_content.parts.is_empty()
81 {
82 bail!("User content must contain at least one part");
83 }
84
85 Ok(())
86}
87
88#[derive(Debug, Serialize, Deserialize)]
89pub enum Task {
90 #[serde(rename = "generateContent")]
91 GenerateContent,
92 #[serde(rename = "streamGenerateContent")]
93 StreamGenerateContent,
94 #[serde(rename = "embedContent")]
95 EmbedContent,
96 #[serde(rename = "batchEmbedContents")]
97 BatchEmbedContents,
98}
99
/// Request body for Gemini content generation, as sent by `stream_generate_content`.
///
/// Field names are camelCased on the wire. Every `Option` field is omitted from the JSON
/// when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    /// Target model.
    ///
    /// When absent from the JSON it defaults to an empty name, and an empty name is omitted
    /// when serializing. `stream_generate_content` empties it because the model id is sent
    /// in the URL path instead.
    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
    pub model: ModelName,
    /// The conversation, as a list of role-tagged part lists.
    pub contents: Vec<Content>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<SystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
}
117
/// One chunk of a content-generation response. `stream_generate_content` yields one of these
/// per `data:` line of the event stream.
///
/// All fields are optional and omitted from the JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
}

/// A single candidate inside a [`GenerateContentResponse`].
///
/// `content` is required when deserializing. Every other field is optional and omitted from
/// the JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
    pub content: Content,
    /// Kept as the raw string from the API rather than an enum.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
}
144
/// A role-tagged list of [`Part`]s: one entry of `GenerateContentRequest::contents`, or the
/// content of a response candidate.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    /// Treated as empty when the key is missing from the JSON.
    #[serde(default)]
    pub parts: Vec<Part>,
    pub role: Role,
}

/// System instruction for a request, expressed as a list of parts (no role).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
    pub parts: Vec<Part>,
}

/// Author of a [`Content`] entry; serialized as `"user"` or `"model"`.
#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
    User,
    Model,
}
165
/// One piece of a [`Content`] or [`SystemInstruction`].
///
/// Because the enum is `untagged`, serde tries the variants in declaration order and keeps
/// the first one that deserializes. None of the part structs use `deny_unknown_fields`, so
/// extra keys are silently ignored.
// NOTE(review): since `TextPart` is tried first, any part containing `text` deserializes as
// `TextPart` even if it also carries `thought`/`thoughtSignature`, and those keys are dropped.
// Confirm this is intended.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
    FunctionCallPart(FunctionCallPart),
    FunctionResponsePart(FunctionResponsePart),
    ThoughtPart(ThoughtPart),
}

/// Plain text content.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
    pub text: String,
}

/// Inline data, sent under the `inlineData` key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
    pub inline_data: GenerativeContentBlob,
}

/// Inline data blob: a MIME type plus the data as a string.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
    pub mime_type: String,
    // NOTE(review): presumably base64-encoded bytes — confirm against the callers that build it.
    pub data: String,
}

/// A function call emitted by the model (`functionCall` on the wire).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
    pub function_call: FunctionCall,
    /// Thought signature returned by the model for function calls.
    /// Only present on the first function call in parallel call scenarios.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

/// Wraps a [`FunctionResponse`] (`functionResponse` on the wire).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
    pub function_response: FunctionResponse,
}

/// A part carrying `thought` and `thoughtSignature`. Both keys are required for this variant
/// to match during untagged deserialization.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThoughtPart {
    pub thought: bool,
    pub thought_signature: String,
}
217
/// One cited source. Every field is optional and omitted from the JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
}

/// Citation information attached to a `GenerateContentCandidate`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    pub citation_sources: Vec<CitationSource>,
}
236
237#[derive(Debug, Serialize, Deserialize)]
238#[serde(rename_all = "camelCase")]
239pub struct PromptFeedback {
240 #[serde(skip_serializing_if = "Option::is_none")]
241 pub block_reason: Option<String>,
242 pub safety_ratings: Option<Vec<SafetyRating>>,
243 #[serde(skip_serializing_if = "Option::is_none")]
244 pub block_reason_message: Option<String>,
245}
246
/// Token counts reported in a `GenerateContentResponse`.
///
/// Every field is optional and omitted from the JSON when `None`. `Default` produces all
/// `None`.
#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_token_count: Option<u64>,
}
263
/// Thinking settings nested in [`GenerationConfig`]; sent as `{"thinkingBudget": N}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    pub thinking_budget: u32,
}

/// Generation settings for a `GenerateContentRequest`.
///
/// Every field is optional and omitted from the JSON when `None`, so unset values are left
/// to the server.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}
288
/// One entry of `GenerateContentRequest::safety_settings`: a category plus its threshold.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    pub category: HarmCategory,
    pub threshold: HarmBlockThreshold,
}

/// Harm categories, each mapped to its `HARM_CATEGORY_*` wire string.
// NOTE(review): there is no `#[serde(other)]` fallback, so an unlisted category string fails
// deserialization. `SafetyRating` (used in response candidates) contains this type, so such
// a value would make the whole response chunk fail to parse.
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}
321
/// Blocking threshold for a [`SafetySetting`].
///
/// Variants serialize in SCREAMING_SNAKE_CASE (e.g. `BlockLowAndAbove` ->
/// `"BLOCK_LOW_AND_ABOVE"`). `Unspecified` is overridden to
/// `"HARM_BLOCK_THRESHOLD_UNSPECIFIED"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
    Unspecified,
    BlockLowAndAbove,
    BlockMediumAndAbove,
    BlockOnlyHigh,
    BlockNone,
}

/// Probability level in a [`SafetyRating`].
///
/// Variants serialize in SCREAMING_SNAKE_CASE (e.g. `Negligible` -> `"NEGLIGIBLE"`).
/// `Unspecified` is overridden to `"HARM_PROBABILITY_UNSPECIFIED"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
    Unspecified,
    Negligible,
    Low,
    Medium,
    High,
}

/// A category/probability pair, as found in candidate and prompt-feedback `safetyRatings`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    pub category: HarmCategory,
    pub probability: HarmProbability,
}
350
/// A named function call with arbitrary JSON `args`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub args: serde_json::Value,
}

/// The result of a function call, identified by `name`, with an arbitrary JSON `response`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
    pub name: String,
    pub response: serde_json::Value,
}

/// A tool entry for `GenerateContentRequest::tools`; sent as `{"functionDeclarations": [...]}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    pub function_declarations: Vec<FunctionDeclaration>,
}

/// Tool configuration for a request; wraps a [`FunctionCallingConfig`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    pub function_calling_config: FunctionCallingConfig,
}

/// Function-calling mode plus an optional allow-list of function names.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    pub mode: FunctionCallingMode,
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}

/// Serialized in lowercase: `"auto"`, `"any"`, `"none"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FunctionCallingMode {
    Auto,
    Any,
    None,
}

/// Declaration of a callable function exposed to the model.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    // NOTE(review): presumably a JSON-Schema-style object describing the arguments — confirm
    // against the code that builds these declarations.
    pub parameters: serde_json::Value,
}
397
398#[derive(Debug, Default)]
399pub struct ModelName {
400 pub model_id: String,
401}
402
403impl ModelName {
404 pub fn is_empty(&self) -> bool {
405 self.model_id.is_empty()
406 }
407}
408
409const MODEL_NAME_PREFIX: &str = "models/";
410
411impl Serialize for ModelName {
412 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
413 where
414 S: Serializer,
415 {
416 serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
417 }
418}
419
420impl<'de> Deserialize<'de> for ModelName {
421 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
422 where
423 D: Deserializer<'de>,
424 {
425 let string = String::deserialize(deserializer)?;
426 if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
427 Ok(Self {
428 model_id: id.to_string(),
429 })
430 } else {
431 Err(serde::de::Error::custom(format!(
432 "Expected model name to begin with {}, got: {}",
433 MODEL_NAME_PREFIX, string
434 )))
435 }
436 }
437}
438
// Gemini models with built-in metadata, plus `Custom` for user-configured models.
//
// Serde: each built-in variant serializes to its `rename` string, and every listed `alias` is
// also accepted when deserializing, so older/preview ids map onto the current variant.
// `#[default]` makes `Gemini25Flash` the `Default`.
//
// Plain `//` comments are used deliberately: with the `schemars` feature enabled, `///` doc
// comments would become descriptions in the generated JSON schema.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
    #[serde(
        rename = "gemini-2.5-flash-lite",
        alias = "gemini-2.5-flash-lite-preview-06-17",
        alias = "gemini-2.0-flash-lite-preview"
    )]
    Gemini25FlashLite,
    #[serde(
        rename = "gemini-2.5-flash",
        alias = "gemini-2.0-flash-thinking-exp",
        alias = "gemini-2.5-flash-preview-04-17",
        alias = "gemini-2.5-flash-preview-05-20",
        alias = "gemini-2.5-flash-preview-latest",
        alias = "gemini-2.0-flash"
    )]
    #[default]
    Gemini25Flash,
    #[serde(
        rename = "gemini-2.5-pro",
        alias = "gemini-2.0-pro-exp",
        alias = "gemini-2.5-pro-preview-latest",
        alias = "gemini-2.5-pro-exp-03-25",
        alias = "gemini-2.5-pro-preview-03-25",
        alias = "gemini-2.5-pro-preview-05-06",
        alias = "gemini-2.5-pro-preview-06-05"
    )]
    Gemini25Pro,
    #[serde(rename = "gemini-3-flash-preview")]
    Gemini3Flash,
    #[serde(rename = "gemini-3.1-pro-preview", alias = "gemini-3-pro-preview")]
    Gemini31Pro,
    #[serde(rename = "custom")]
    Custom {
        // Model id; returned by `Model::id` and `Model::request_id` for this variant.
        name: String,
        /// The name displayed in the UI, such as in the agent panel model dropdown menu.
        display_name: Option<String>,
        // Returned by `Model::max_token_count` for this variant.
        max_tokens: u64,
        // Falls back to `GoogleModelMode::default()` when absent during deserialization.
        #[serde(default)]
        mode: GoogleModelMode,
    },
}
482
483impl Model {
484 pub fn default_fast() -> Self {
485 Self::Gemini25FlashLite
486 }
487
488 pub fn id(&self) -> &str {
489 match self {
490 Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
491 Self::Gemini25Flash => "gemini-2.5-flash",
492 Self::Gemini25Pro => "gemini-2.5-pro",
493 Self::Gemini3Flash => "gemini-3-flash-preview",
494 Self::Gemini31Pro => "gemini-3.1-pro-preview",
495 Self::Custom { name, .. } => name,
496 }
497 }
498 pub fn request_id(&self) -> &str {
499 match self {
500 Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
501 Self::Gemini25Flash => "gemini-2.5-flash",
502 Self::Gemini25Pro => "gemini-2.5-pro",
503 Self::Gemini3Flash => "gemini-3-flash-preview",
504 Self::Gemini31Pro => "gemini-3.1-pro-preview",
505 Self::Custom { name, .. } => name,
506 }
507 }
508
509 pub fn display_name(&self) -> &str {
510 match self {
511 Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite",
512 Self::Gemini25Flash => "Gemini 2.5 Flash",
513 Self::Gemini25Pro => "Gemini 2.5 Pro",
514 Self::Gemini3Flash => "Gemini 3 Flash",
515 Self::Gemini31Pro => "Gemini 3.1 Pro",
516 Self::Custom {
517 name, display_name, ..
518 } => display_name.as_ref().unwrap_or(name),
519 }
520 }
521
522 pub fn max_token_count(&self) -> u64 {
523 match self {
524 Self::Gemini25FlashLite
525 | Self::Gemini25Flash
526 | Self::Gemini25Pro
527 | Self::Gemini3Flash
528 | Self::Gemini31Pro => 1_048_576,
529 Self::Custom { max_tokens, .. } => *max_tokens,
530 }
531 }
532
533 pub fn max_output_tokens(&self) -> Option<u64> {
534 match self {
535 Model::Gemini25FlashLite
536 | Model::Gemini25Flash
537 | Model::Gemini25Pro
538 | Model::Gemini3Flash
539 | Model::Gemini31Pro => Some(65_536),
540 Model::Custom { .. } => None,
541 }
542 }
543
544 pub fn supports_tools(&self) -> bool {
545 true
546 }
547
548 pub fn supports_images(&self) -> bool {
549 true
550 }
551
552 pub fn mode(&self) -> GoogleModelMode {
553 match self {
554 Self::Gemini25FlashLite | Self::Gemini25Flash | Self::Gemini25Pro => {
555 GoogleModelMode::Thinking {
556 // By default these models are set to "auto", so we preserve that behavior
557 // but indicate they are capable of thinking mode
558 budget_tokens: None,
559 }
560 }
561 Self::Gemini3Flash => GoogleModelMode::Default,
562 Self::Gemini31Pro => GoogleModelMode::Thinking {
563 budget_tokens: None,
564 },
565 Self::Custom { mode, .. } => *mode,
566 }
567 }
568}
569
570impl std::fmt::Display for Model {
571 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
572 write!(f, "{}", self.id())
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `FunctionCallPart` calling `test_function` with the given args and signature.
    fn make_part(args: serde_json::Value, signature: Option<&str>) -> FunctionCallPart {
        FunctionCallPart {
            function_call: FunctionCall {
                name: "test_function".to_string(),
                args,
            },
            thought_signature: signature.map(ToString::to_string),
        }
    }

    #[test]
    fn test_function_call_part_with_signature_serializes_correctly() {
        let part = make_part(json!({"arg": "value"}), Some("test_signature"));
        let value = serde_json::to_value(&part).unwrap();

        assert_eq!(value["functionCall"]["name"], "test_function");
        assert_eq!(value["functionCall"]["args"]["arg"], "value");
        assert_eq!(value["thoughtSignature"], "test_signature");
    }

    #[test]
    fn test_function_call_part_without_signature_omits_field() {
        let part = make_part(json!({"arg": "value"}), None);
        let value = serde_json::to_value(&part).unwrap();

        assert_eq!(value["functionCall"]["name"], "test_function");
        assert_eq!(value["functionCall"]["args"]["arg"], "value");
        // `skip_serializing_if` must drop the key entirely rather than emit `null`.
        assert!(value.get("thoughtSignature").is_none());
    }

    #[test]
    fn test_function_call_part_deserializes_with_signature() {
        let part: FunctionCallPart = serde_json::from_value(json!({
            "functionCall": { "name": "test_function", "args": {"arg": "value"} },
            "thoughtSignature": "test_signature"
        }))
        .unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert_eq!(part.thought_signature.as_deref(), Some("test_signature"));
    }

    #[test]
    fn test_function_call_part_deserializes_without_signature() {
        let part: FunctionCallPart = serde_json::from_value(json!({
            "functionCall": { "name": "test_function", "args": {"arg": "value"} }
        }))
        .unwrap();

        assert_eq!(part.function_call.name, "test_function");
        assert!(part.thought_signature.is_none());
    }

    #[test]
    fn test_function_call_part_round_trip() {
        let original = make_part(
            json!({"arg": "value", "nested": {"key": "val"}}),
            Some("round_trip_signature"),
        );

        let restored: FunctionCallPart =
            serde_json::from_value(serde_json::to_value(&original).unwrap()).unwrap();

        assert_eq!(restored.function_call.name, original.function_call.name);
        assert_eq!(restored.function_call.args, original.function_call.args);
        assert_eq!(restored.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_function_call_part_with_empty_signature_serializes() {
        let part = make_part(json!({"arg": "value"}), Some(""));
        let value = serde_json::to_value(&part).unwrap();

        // An empty signature is still emitted; normalization happens at a higher level.
        assert_eq!(value["thoughtSignature"], "");
    }
}