google_ai.rs
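//! Google AI (Gemini) language-model provider for Zed, written against the
//! `zed_extension_api` LLM extension interface. Covers model listing, token
//! counting, streaming completions over SSE, and prompt-cache configuration.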

use std::collections::{HashMap, VecDeque};

use serde::{Deserialize, Deserializer, Serialize, Serializer};
use zed_extension_api::{
    self as zed, http_client::HttpMethod, http_client::HttpRequest, llm_get_env_var,
    LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmMessageContent,
    LlmMessageRole, LlmModelCapabilities, LlmModelInfo, LlmProviderInfo, LlmStopReason,
    LlmThinkingContent, LlmTokenUsage, LlmToolInputFormat, LlmToolUse,
};

pub const API_URL: &str = "https://generativelanguage.googleapis.com";

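/// Kicks off a streaming `generateContent` call and registers the live
/// response under a fresh stream id; events are then pulled one at a time via
/// `stream_generate_content_next`.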
fn stream_generate_content(
    model_id: &str,
    request: &LlmCompletionRequest,
    streams: &mut HashMap<String, StreamState>,
    next_stream_id: &mut u64,
) -> Result<String, String> {
    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;

    let generate_content_request = build_generate_content_request(model_id, request)?;
    validate_generate_content_request(&generate_content_request)?;

    let uri = format!(
        "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
        API_URL, model_id, api_key
    );

    let body = serde_json::to_vec(&generate_content_request)
        .map_err(|e| format!("Failed to serialize request: {}", e))?;

    let http_request = HttpRequest::builder()
        .method(HttpMethod::Post)
        .url(&uri)
        .header("Content-Type", "application/json")
        .body(body)
        .build()?;

    let response_stream = http_request.fetch_stream()?;

    let stream_id = format!("stream-{}", *next_stream_id);
    *next_stream_id += 1;

    streams.insert(
        stream_id.clone(),
        StreamState {
            response_stream,
            buffer: String::new(),
            usage: None,
            pending_events: VecDeque::new(),
        },
    );

    Ok(stream_id)
}

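/// Counts tokens via the `countTokens` endpoint, which takes the full
/// `generateContent` request so the count reflects the system instruction and
/// tools as well as the conversation.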
fn count_tokens(model_id: &str, request: &LlmCompletionRequest) -> Result<u64, String> {
    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;

    let generate_content_request = build_generate_content_request(model_id, request)?;
    validate_generate_content_request(&generate_content_request)?;
    let count_request = CountTokensRequest {
        generate_content_request,
    };

    let uri = format!(
        "{}/v1beta/models/{}:countTokens?key={}",
        API_URL, model_id, api_key
    );

    let body = serde_json::to_vec(&count_request)
        .map_err(|e| format!("Failed to serialize request: {}", e))?;

    let http_request = HttpRequest::builder()
        .method(HttpMethod::Post)
        .url(&uri)
        .header("Content-Type", "application/json")
        .body(body)
        .build()?;

    let response = http_request.fetch()?;
    let response_body: CountTokensResponse = serde_json::from_slice(&response.body)
        .map_err(|e| format!("Failed to parse response: {}", e))?;

    Ok(response_body.total_tokens)
}

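/// Catches requests that would otherwise fail server-side with an opaque
/// error: a missing model id, an empty conversation, or a user message with
/// no parts.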
fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<(), String> {
    if request.model.is_empty() {
        return Err("Model must be specified".to_string());
    }

    if request.contents.is_empty() {
        return Err("Request must contain at least one content item".to_string());
    }

    if let Some(user_content) = request
        .contents
        .iter()
        .find(|content| content.role == Role::User)
    {
        if user_content.parts.is_empty() {
            return Err("User content must contain at least one part".to_string());
        }
    }

    Ok(())
}

// Extension implementation

const PROVIDER_ID: &str = "google-ai";
const PROVIDER_NAME: &str = "Google AI";

struct GoogleAiExtension {
    streams: HashMap<String, StreamState>,
    next_stream_id: u64,
}

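/// Book-keeping for one in-flight completion: the live HTTP response, a
/// buffer of not-yet-parsed SSE bytes, the latest usage metadata, and parsed
/// events waiting to be handed to the host one at a time.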
struct StreamState {
    response_stream: zed::http_client::HttpResponseStream,
    buffer: String,
    usage: Option<UsageMetadata>,
    pending_events: VecDeque<LlmCompletionEvent>,
}

impl zed::Extension for GoogleAiExtension {
    fn new() -> Self {
        Self {
            streams: HashMap::new(),
            next_stream_id: 0,
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: PROVIDER_ID.to_string(),
            name: PROVIDER_NAME.to_string(),
            icon: Some("icons/google-ai.svg".to_string()),
        }]
    }

    fn llm_provider_models(&self, provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        Ok(get_models())
    }

    fn llm_provider_settings_markdown(&self, provider_id: &str) -> Option<String> {
        if provider_id != PROVIDER_ID {
            return None;
        }

        Some(
            r#"## Google AI Setup

To use Google AI models in Zed, you need a Gemini API key.

1. Go to [Google AI Studio](https://aistudio.google.com/apikey)
2. Create or select a project
3. Generate an API key
4. Set the `GEMINI_API_KEY` or `GOOGLE_AI_API_KEY` environment variable

You can set this in your shell profile or use a `.envrc` file with [direnv](https://direnv.net/).
"#
            .to_string(),
        )
    }

    fn llm_provider_is_authenticated(&self, provider_id: &str) -> bool {
        if provider_id != PROVIDER_ID {
            return false;
        }
        get_api_key().is_some()
    }

    fn llm_provider_reset_credentials(&mut self, provider_id: &str) -> Result<(), String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        // Credentials come from environment variables, so there is nothing to clear.
        Ok(())
    }

    fn llm_count_tokens(
        &self,
        provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<u64, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        count_tokens(model_id, request)
    }

    fn llm_stream_completion_start(
        &mut self,
        provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        if provider_id != PROVIDER_ID {
            return Err(format!("Unknown provider: {}", provider_id));
        }
        stream_generate_content(model_id, request, &mut self.streams, &mut self.next_stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        stream_generate_content_next(stream_id, &mut self.streams)
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.remove(stream_id);
    }

    fn llm_cache_configuration(
        &self,
        provider_id: &str,
        _model_id: &str,
    ) -> Option<LlmCacheConfiguration> {
        if provider_id != PROVIDER_ID {
            return None;
        }

        Some(LlmCacheConfiguration {
            max_cache_anchors: 1,
            should_cache_tool_definitions: false,
            min_total_token_count: 32768,
        })
    }
}

zed::register_extension!(GoogleAiExtension);

// Helper functions

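/// Reads the API key from `GEMINI_API_KEY`, falling back to
/// `GOOGLE_AI_API_KEY`.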
fn get_api_key() -> Option<String> {
    llm_get_env_var("GEMINI_API_KEY").or_else(|| llm_get_env_var("GOOGLE_AI_API_KEY"))
}

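/// Static catalog of the Gemini models exposed to Zed. Every entry currently
/// shares the same context window, output limit, and capability set.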
fn get_models() -> Vec<LlmModelInfo> {
    // All models currently share the same limits and capabilities, so build
    // each entry from a common template.
    let model = |id: &str, name: &str, is_default: bool, is_default_fast: bool| LlmModelInfo {
        id: id.to_string(),
        name: name.to_string(),
        max_token_count: 1_048_576,
        max_output_tokens: Some(65_536),
        capabilities: LlmModelCapabilities {
            supports_images: true,
            supports_tools: true,
            supports_tool_choice_auto: true,
            supports_tool_choice_any: true,
            supports_tool_choice_none: true,
            supports_thinking: true,
            tool_input_format: LlmToolInputFormat::JsonSchema,
        },
        is_default,
        is_default_fast,
    };

    vec![
        model("gemini-2.5-flash-lite", "Gemini 2.5 Flash-Lite", false, true),
        model("gemini-2.5-flash", "Gemini 2.5 Flash", true, false),
        model("gemini-2.5-pro", "Gemini 2.5 Pro", false, false),
        model("gemini-3-pro-preview", "Gemini 3 Pro", false, false),
        model("gemini-3-flash-preview", "Gemini 3 Flash", false, false),
    ]
}

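/// Returns the next completion event for a stream: drains events already
/// parsed from an earlier chunk, then parses buffered SSE lines, and only
/// reads from the network once both are exhausted. Queuing every event from a
/// chunk ensures multi-part candidates are not dropped.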
fn stream_generate_content_next(
    stream_id: &str,
    streams: &mut HashMap<String, StreamState>,
) -> Result<Option<LlmCompletionEvent>, String> {
    let state = streams
        .get_mut(stream_id)
        .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

    loop {
        // Hand out events parsed from an earlier line before touching the
        // buffer again; a single SSE line can yield text, tool-use, usage,
        // and stop events all at once, and returning only the first would
        // silently drop the rest.
        if let Some(event) = state.pending_events.pop_front() {
            return Ok(Some(event));
        }

        if let Some(newline_pos) = state.buffer.find('\n') {
            let line = state.buffer[..newline_pos].to_string();
            state.buffer = state.buffer[newline_pos + 1..].to_string();

            if let Some(data) = line.strip_prefix("data: ") {
                if data.trim().is_empty() {
                    continue;
                }

                let response: GenerateContentResponse = serde_json::from_str(data)
                    .map_err(|e| format!("Failed to parse SSE data: {} - {}", e, data))?;

                if let Some(usage) = response.usage_metadata {
                    state.usage = Some(usage);
                }

                if let Some(candidates) = response.candidates {
                    for candidate in candidates {
                        for part in candidate.content.parts {
                            match part {
                                Part::TextPart(text_part) => {
                                    state
                                        .pending_events
                                        .push_back(LlmCompletionEvent::Text(text_part.text));
                                }
                                Part::ThoughtPart(thought_part) => {
                                    state.pending_events.push_back(LlmCompletionEvent::Thinking(
                                        LlmThinkingContent {
                                            text: String::new(),
                                            signature: Some(thought_part.thought_signature),
                                        },
                                    ));
                                }
                                Part::FunctionCallPart(fc_part) => {
                                    state.pending_events.push_back(LlmCompletionEvent::ToolUse(
                                        LlmToolUse {
                                            // Function calls carry no separate
                                            // id here, so the name doubles as one.
                                            id: fc_part.function_call.name.clone(),
                                            name: fc_part.function_call.name,
                                            input: serde_json::to_string(
                                                &fc_part.function_call.args,
                                            )
                                            .unwrap_or_default(),
                                            is_input_complete: true,
                                            thought_signature: fc_part.thought_signature,
                                        },
                                    ));
                                }
                                _ => {}
                            }
                        }

                        if let Some(finish_reason) = candidate.finish_reason {
                            let stop_reason = match finish_reason.as_str() {
                                "STOP" => LlmStopReason::EndTurn,
                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
                                "TOOL_USE" | "FUNCTION_CALL" => LlmStopReason::ToolUse,
                                "SAFETY" | "RECITATION" | "OTHER" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };

                            if let Some(usage) = state.usage.take() {
                                state.pending_events.push_back(LlmCompletionEvent::Usage(
                                    LlmTokenUsage {
                                        input_tokens: usage.prompt_token_count.unwrap_or(0),
                                        output_tokens: usage.candidates_token_count.unwrap_or(0),
                                        cache_creation_input_tokens: None,
                                        cache_read_input_tokens: usage.cached_content_token_count,
                                    },
                                ));
                            }

                            state
                                .pending_events
                                .push_back(LlmCompletionEvent::Stop(stop_reason));
                        }
                    }
                }
            }

            continue;
        }

        match state.response_stream.next_chunk() {
            Ok(Some(chunk)) => {
                let chunk_str = String::from_utf8_lossy(&chunk);
                state.buffer.push_str(&chunk_str);
            }
            Ok(None) => {
                streams.remove(stream_id);
                return Ok(None);
            }
            Err(e) => {
                streams.remove(stream_id);
                return Err(e);
            }
        }
    }
}

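/// Translates Zed's provider-agnostic completion request into the Gemini
/// `generateContent` shape: system messages become the system instruction,
/// user/assistant turns become `contents`, and tool definitions become
/// function declarations.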
fn build_generate_content_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<GenerateContentRequest, String> {
    let mut contents: Vec<Content> = Vec::new();
    let mut system_instruction: Option<SystemInstruction> = None;

    for message in &request.messages {
        match message.role {
            LlmMessageRole::System => {
                // Merge repeated system messages instead of letting the last
                // one overwrite the others.
                let parts = convert_content_to_parts(&message.content)?;
                match system_instruction.as_mut() {
                    Some(instruction) => instruction.parts.extend(parts),
                    None => system_instruction = Some(SystemInstruction { parts }),
                }
            }
            LlmMessageRole::User | LlmMessageRole::Assistant => {
                let role = match message.role {
                    LlmMessageRole::User => Role::User,
                    LlmMessageRole::Assistant => Role::Model,
                    _ => continue,
                };
                let parts = convert_content_to_parts(&message.content)?;
                contents.push(Content { parts, role });
            }
        }
    }

    let tools = if request.tools.is_empty() {
        None
    } else {
        Some(vec![Tool {
            function_declarations: request
                .tools
                .iter()
                .map(|t| FunctionDeclaration {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters: serde_json::from_str(&t.input_schema).unwrap_or_default(),
                })
                .collect(),
        }])
    };

    let tool_config = request.tool_choice.as_ref().map(|choice| {
        let mode = match choice {
            zed::LlmToolChoice::Auto => FunctionCallingMode::Auto,
            zed::LlmToolChoice::Any => FunctionCallingMode::Any,
            zed::LlmToolChoice::None => FunctionCallingMode::None,
        };
        ToolConfig {
            function_calling_config: FunctionCallingConfig {
                mode,
                allowed_function_names: None,
            },
        }
    });

    let generation_config = Some(GenerationConfig {
        candidate_count: Some(1),
        stop_sequences: if request.stop_sequences.is_empty() {
            None
        } else {
            Some(request.stop_sequences.clone())
        },
        max_output_tokens: request.max_tokens.map(|t| t as usize),
        temperature: request.temperature.map(|t| t as f64),
        top_p: None,
        top_k: None,
        thinking_config: if request.thinking_allowed {
            Some(ThinkingConfig {
                // Fixed budget; the host request only signals whether
                // thinking is allowed, not how much.
                thinking_budget: 8192,
            })
        } else {
            None
        },
    });

    Ok(GenerateContentRequest {
        model: ModelName {
            model_id: model_id.to_string(),
        },
        contents,
        system_instruction,
        generation_config,
        safety_settings: None,
        tools,
        tool_config,
    })
}

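/// Maps Zed message content onto Gemini `Part`s. Images are sent as inline
/// data, tool calls and results become function call/response parts, and
/// thinking blocks survive only as their opaque signatures.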
fn convert_content_to_parts(content: &[LlmMessageContent]) -> Result<Vec<Part>, String> {
    let mut parts = Vec::new();

    for item in content {
        match item {
            LlmMessageContent::Text(text) => {
                parts.push(Part::TextPart(TextPart { text: text.clone() }));
            }
            LlmMessageContent::Image(image) => {
                parts.push(Part::InlineDataPart(InlineDataPart {
                    inline_data: GenerativeContentBlob {
                        // The image's MIME type is not tracked here; PNG is
                        // assumed.
                        mime_type: "image/png".to_string(),
                        data: image.source.clone(),
                    },
                }));
            }
            LlmMessageContent::ToolUse(tool_use) => {
                parts.push(Part::FunctionCallPart(FunctionCallPart {
                    function_call: FunctionCall {
                        name: tool_use.name.clone(),
                        args: serde_json::from_str(&tool_use.input).unwrap_or_default(),
                    },
                    thought_signature: tool_use.thought_signature.clone(),
                }));
            }
            LlmMessageContent::ToolResult(tool_result) => {
                let response_value = match &tool_result.content {
                    zed::LlmToolResultContent::Text(text) => {
                        serde_json::json!({ "result": text })
                    }
                    zed::LlmToolResultContent::Image(_) => {
                        serde_json::json!({ "error": "Image results not supported" })
                    }
                };
                parts.push(Part::FunctionResponsePart(FunctionResponsePart {
                    function_response: FunctionResponse {
                        name: tool_result.tool_name.clone(),
                        response: response_value,
                    },
                }));
            }
            LlmMessageContent::Thinking(thinking) => {
                if let Some(signature) = &thinking.signature {
                    parts.push(Part::ThoughtPart(ThoughtPart {
                        thought: true,
                        thought_signature: signature.clone(),
                    }));
                }
            }
            LlmMessageContent::RedactedThinking(_) => {}
        }
    }

    Ok(parts)
}

// Data structures for Google AI API

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
    pub model: ModelName,
    pub contents: Vec<Content>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<SystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
    pub content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    #[serde(default)]
    pub parts: Vec<Part>,
    pub role: Role,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
    pub parts: Vec<Part>,
}

#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
    User,
    Model,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
    FunctionCallPart(FunctionCallPart),
    FunctionResponsePart(FunctionResponsePart),
    ThoughtPart(ThoughtPart),
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
    pub text: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
    pub inline_data: GenerativeContentBlob,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
    pub mime_type: String,
    pub data: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
    pub function_call: FunctionCall,
    /// Thought signature returned by the model for function calls.
    /// Only present on the first function call in parallel call scenarios.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
    pub function_response: FunctionResponse,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThoughtPart {
    pub thought: bool,
    pub thought_signature: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    pub citation_sources: Vec<CitationSource>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason: Option<String>,
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason_message: Option<String>,
}

#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_token_count: Option<u64>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    pub thinking_budget: u32,
}

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    pub category: HarmCategory,
    pub threshold: HarmBlockThreshold,
}

#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
    Unspecified,
    BlockLowAndAbove,
    BlockMediumAndAbove,
    BlockOnlyHigh,
    BlockNone,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
    Unspecified,
    Negligible,
    Low,
    Medium,
    High,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    pub category: HarmCategory,
    pub probability: HarmProbability,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
    pub generate_content_request: GenerateContentRequest,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
    pub total_tokens: u64,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub args: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
    pub name: String,
    pub response: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    pub function_declarations: Vec<FunctionDeclaration>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    pub function_calling_config: FunctionCallingConfig,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    pub mode: FunctionCallingMode,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}

#[derive(Debug, Serialize, Deserialize)]
// The API expects the protobuf enum names: "AUTO", "ANY", "NONE".
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
    Auto,
    Any,
    None,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

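/// Model identifier that serializes as the API's `models/<id>` resource name
/// while the rest of the extension works with the bare id.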
#[derive(Debug, Default)]
pub struct ModelName {
    pub model_id: String,
}

impl ModelName {
    pub fn is_empty(&self) -> bool {
        self.model_id.is_empty()
    }
}

const MODEL_NAME_PREFIX: &str = "models/";

impl Serialize for ModelName {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
    }
}

impl<'de> Deserialize<'de> for ModelName {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let string = String::deserialize(deserializer)?;
        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
            Ok(Self {
                model_id: id.to_string(),
            })
        } else {
            Err(serde::de::Error::custom(format!(
                "Expected model name to begin with {}, got: {}",
                MODEL_NAME_PREFIX, string
            )))
        }
    }
}
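
// A minimal sketch of unit tests for the pure helpers above (serde
// round-trips and request validation); it assumes only `serde_json` and the
// types in this file, not the Zed host API.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_name_round_trips_through_serde() {
        let json = serde_json::to_string(&ModelName {
            model_id: "gemini-2.5-flash".to_string(),
        })
        .unwrap();
        assert_eq!(json, "\"models/gemini-2.5-flash\"");

        let name: ModelName = serde_json::from_str(&json).unwrap();
        assert_eq!(name.model_id, "gemini-2.5-flash");
    }

    #[test]
    fn validation_rejects_empty_requests() {
        let request = GenerateContentRequest {
            model: ModelName {
                model_id: "gemini-2.5-flash".to_string(),
            },
            contents: Vec::new(),
            system_instruction: None,
            generation_config: None,
            safety_settings: None,
            tools: None,
            tool_config: None,
        };
        assert!(validate_generate_content_request(&request).is_err());
    }
}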