wip override the google_ai extension with the hardcoded implementation

Richard Feldman and Mikayla Maki created

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>

Change summary

extensions/google-ai/src/google_ai.rs | 1252 +++++++++++++---------------
1 file changed, 586 insertions(+), 666 deletions(-)

Detailed changes

extensions/google-ai/src/google_ai.rs 🔗

@@ -1,798 +1,718 @@
-use std::collections::HashMap;
-use std::sync::atomic::{AtomicU64, Ordering};
-use std::sync::Mutex;
-
-use serde::{Deserialize, Serialize};
-use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
-use zed_extension_api::{self as zed, *};
-
-static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
-
-struct GoogleAiProvider {
-    streams: Mutex<HashMap<String, StreamState>>,
-    next_stream_id: Mutex<u64>,
-}
-
-struct StreamState {
-    response_stream: Option<HttpResponseStream>,
-    buffer: String,
-    started: bool,
-    stop_reason: Option<LlmStopReason>,
-    wants_tool_use: bool,
-}
-
-struct ModelDefinition {
-    real_id: &'static str,
-    display_name: &'static str,
-    max_tokens: u64,
-    max_output_tokens: Option<u64>,
-    supports_images: bool,
-    supports_thinking: bool,
-    is_default: bool,
-    is_default_fast: bool,
-}
-
-const MODELS: &[ModelDefinition] = &[
-    ModelDefinition {
-        real_id: "gemini-2.5-flash-lite",
-        display_name: "Gemini 2.5 Flash-Lite",
-        max_tokens: 1_048_576,
-        max_output_tokens: Some(65_536),
-        supports_images: true,
-        supports_thinking: true,
-        is_default: false,
-        is_default_fast: true,
-    },
-    ModelDefinition {
-        real_id: "gemini-2.5-flash",
-        display_name: "Gemini 2.5 Flash",
-        max_tokens: 1_048_576,
-        max_output_tokens: Some(65_536),
-        supports_images: true,
-        supports_thinking: true,
-        is_default: true,
-        is_default_fast: false,
-    },
-    ModelDefinition {
-        real_id: "gemini-2.5-pro",
-        display_name: "Gemini 2.5 Pro",
-        max_tokens: 1_048_576,
-        max_output_tokens: Some(65_536),
-        supports_images: true,
-        supports_thinking: true,
-        is_default: false,
-        is_default_fast: false,
-    },
-    ModelDefinition {
-        real_id: "gemini-3-pro-preview",
-        display_name: "Gemini 3 Pro",
-        max_tokens: 1_048_576,
-        max_output_tokens: Some(65_536),
-        supports_images: true,
-        supports_thinking: true,
-        is_default: false,
-        is_default_fast: false,
-    },
-];
 
-fn get_real_model_id(display_name: &str) -> Option<&'static str> {
-    MODELS
-        .iter()
-        .find(|m| m.display_name == display_name)
-        .map(|m| m.real_id)
+use std::mem;
+
+use anyhow::{Result, anyhow, bail};
+use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+pub use settings::ModelMode as GoogleModelMode;
+
+pub const API_URL: &str = "https://generativelanguage.googleapis.com";
+
+pub async fn stream_generate_content(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    mut request: GenerateContentRequest,
+) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
+    let api_key = api_key.trim();
+    validate_generate_content_request(&request)?;
+
+    // The `model` field is emptied as it is provided as a path parameter.
+    let model_id = mem::take(&mut request.model.model_id);
+
+    let uri =
+        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
+
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json");
+
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let mut response = client.send(request).await?;
+    if response.status().is_success() {
+        let reader = BufReader::new(response.into_body());
+        Ok(reader
+            .lines()
+            .filter_map(|line| async move {
+                match line {
+                    Ok(line) => {
+                        if let Some(line) = line.strip_prefix("data: ") {
+                            match serde_json::from_str(line) {
+                                Ok(response) => Some(Ok(response)),
+                                Err(error) => Some(Err(anyhow!(format!(
+                                    "Error parsing JSON: {error:?}\n{line:?}"
+                                )))),
+                            }
+                        } else {
+                            None
+                        }
+                    }
+                    Err(error) => Some(Err(anyhow!(error))),
+                }
+            })
+            .boxed())
+    } else {
+        let mut text = String::new();
+        response.body_mut().read_to_string(&mut text).await?;
+        Err(anyhow!(
+            "error during streamGenerateContent, status code: {:?}, body: {}",
+            response.status(),
+            text
+        ))
+    }
 }
 
-fn get_model_supports_thinking(display_name: &str) -> bool {
-    MODELS
-        .iter()
-        .find(|m| m.display_name == display_name)
-        .map(|m| m.supports_thinking)
-        .unwrap_or(false)
-}
-
-/// Adapts a JSON schema to be compatible with Google's API subset.
-/// Google only supports a specific subset of JSON Schema fields.
-/// See: https://ai.google.dev/api/caching#Schema
-fn adapt_schema_for_google(json: &mut serde_json::Value) {
-    adapt_schema_for_google_impl(json, true);
-}
-
-fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) {
-    if let serde_json::Value::Object(obj) = json {
-        // Google's Schema only supports these fields:
-        // type, format, title, description, nullable, enum, maxItems, minItems,
-        // properties, required, minProperties, maxProperties, minLength, maxLength,
-        // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
-        const ALLOWED_KEYS: &[&str] = &[
-            "type",
-            "format",
-            "title",
-            "description",
-            "nullable",
-            "enum",
-            "maxItems",
-            "minItems",
-            "properties",
-            "required",
-            "minProperties",
-            "maxProperties",
-            "minLength",
-            "maxLength",
-            "pattern",
-            "example",
-            "anyOf",
-            "propertyOrdering",
-            "default",
-            "items",
-            "minimum",
-            "maximum",
-        ];
-
-        // Convert oneOf to anyOf before filtering keys
-        if let Some(one_of) = obj.remove("oneOf") {
-            obj.insert("anyOf".to_string(), one_of);
-        }
+pub async fn count_tokens(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: CountTokensRequest,
+) -> Result<CountTokensResponse> {
+    validate_generate_content_request(&request.generate_content_request)?;
+
+    let uri = format!(
+        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
+        model_id = &request.generate_content_request.model.model_id,
+    );
+
+    let request = serde_json::to_string(&request)?;
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(&uri)
+        .header("Content-Type", "application/json");
+    let http_request = request_builder.body(AsyncBody::from(request))?;
+
+    let mut response = client.send(http_request).await?;
+    let mut text = String::new();
+    response.body_mut().read_to_string(&mut text).await?;
+    anyhow::ensure!(
+        response.status().is_success(),
+        "error during countTokens, status code: {:?}, body: {}",
+        response.status(),
+        text
+    );
+    Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
+}
 
-        // If type is an array (e.g., ["string", "null"]), take just the first type
-        if let Some(type_field) = obj.get_mut("type") {
-            if let serde_json::Value::Array(types) = type_field {
-                if let Some(first_type) = types.first().cloned() {
-                    *type_field = first_type;
-                }
-            }
-        }
+pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
+    if request.model.is_empty() {
+        bail!("Model must be specified");
+    }
 
-        // Only filter keys if this is a schema object, not a properties map
-        if is_schema {
-            obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));
-        }
+    if request.contents.is_empty() {
+        bail!("Request must contain at least one content item");
+    }
 
-        // Recursively process nested values
-        // "properties" contains a map of property names -> schemas
-        // "items" and "anyOf" contain schemas directly
-        for (key, value) in obj.iter_mut() {
-            if key == "properties" {
-                // properties is a map of property_name -> schema
-                if let serde_json::Value::Object(props) = value {
-                    for (_, prop_schema) in props.iter_mut() {
-                        adapt_schema_for_google_impl(prop_schema, true);
-                    }
-                }
-            } else if key == "items" {
-                // items is a schema
-                adapt_schema_for_google_impl(value, true);
-            } else if key == "anyOf" {
-                // anyOf is an array of schemas
-                if let serde_json::Value::Array(arr) = value {
-                    for item in arr.iter_mut() {
-                        adapt_schema_for_google_impl(item, true);
-                    }
-                }
-            }
-        }
-    } else if let serde_json::Value::Array(arr) = json {
-        for item in arr.iter_mut() {
-            adapt_schema_for_google_impl(item, true);
-        }
+    if let Some(user_content) = request
+        .contents
+        .iter()
+        .find(|content| content.role == Role::User)
+        && user_content.parts.is_empty()
+    {
+        bail!("User content must contain at least one part");
     }
+
+    Ok(())
 }
 
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Task {
+    #[serde(rename = "generateContent")]
+    GenerateContent,
+    #[serde(rename = "streamGenerateContent")]
+    StreamGenerateContent,
+    #[serde(rename = "countTokens")]
+    CountTokens,
+    #[serde(rename = "embedContent")]
+    EmbedContent,
+    #[serde(rename = "batchEmbedContents")]
+    BatchEmbedContents,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleRequest {
-    contents: Vec<GoogleContent>,
+pub struct GenerateContentRequest {
+    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
+    pub model: ModelName,
+    pub contents: Vec<Content>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_instruction: Option<SystemInstruction>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    system_instruction: Option<GoogleSystemInstruction>,
+    pub generation_config: Option<GenerationConfig>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    generation_config: Option<GoogleGenerationConfig>,
+    pub safety_settings: Option<Vec<SafetySetting>>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    tools: Option<Vec<GoogleTool>>,
+    pub tools: Option<Vec<Tool>>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    tool_config: Option<GoogleToolConfig>,
+    pub tool_config: Option<ToolConfig>,
 }
 
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleSystemInstruction {
-    parts: Vec<GooglePart>,
+pub struct GenerateContentResponse {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub candidates: Option<Vec<GenerateContentCandidate>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt_feedback: Option<PromptFeedback>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub usage_metadata: Option<UsageMetadata>,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleContent {
-    parts: Vec<GooglePart>,
+pub struct GenerateContentCandidate {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub index: Option<usize>,
+    pub content: Content,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub finish_reason: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub finish_message: Option<String>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    role: Option<String>,
+    pub safety_ratings: Option<Vec<SafetyRating>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub citation_metadata: Option<CitationMetadata>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Content {
+    #[serde(default)]
+    pub parts: Vec<Part>,
+    pub role: Role,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SystemInstruction {
+    pub parts: Vec<Part>,
+}
+
+#[derive(Debug, PartialEq, Deserialize, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub enum Role {
+    User,
+    Model,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(untagged)]
-enum GooglePart {
-    Text(GoogleTextPart),
-    InlineData(GoogleInlineDataPart),
-    FunctionCall(GoogleFunctionCallPart),
-    FunctionResponse(GoogleFunctionResponsePart),
-    Thought(GoogleThoughtPart),
+pub enum Part {
+    TextPart(TextPart),
+    InlineDataPart(InlineDataPart),
+    FunctionCallPart(FunctionCallPart),
+    FunctionResponsePart(FunctionResponsePart),
+    ThoughtPart(ThoughtPart),
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleTextPart {
-    text: String,
+pub struct TextPart {
+    pub text: String,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleInlineDataPart {
-    inline_data: GoogleBlob,
+pub struct InlineDataPart {
+    pub inline_data: GenerativeContentBlob,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleBlob {
-    mime_type: String,
-    data: String,
+pub struct GenerativeContentBlob {
+    pub mime_type: String,
+    pub data: String,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleFunctionCallPart {
-    function_call: GoogleFunctionCall,
+pub struct FunctionCallPart {
+    pub function_call: FunctionCall,
+    /// Thought signature returned by the model for function calls.
+    /// Only present on the first function call in parallel call scenarios.
     #[serde(skip_serializing_if = "Option::is_none")]
-    thought_signature: Option<String>,
+    pub thought_signature: Option<String>,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleFunctionCall {
-    name: String,
-    args: serde_json::Value,
+pub struct FunctionResponsePart {
+    pub function_response: FunctionResponse,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleFunctionResponsePart {
-    function_response: GoogleFunctionResponse,
+pub struct ThoughtPart {
+    pub thought: bool,
+    pub thought_signature: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CitationSource {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub start_index: Option<usize>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub end_index: Option<usize>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub uri: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub license: Option<String>,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleFunctionResponse {
-    name: String,
-    response: serde_json::Value,
+pub struct CitationMetadata {
+    pub citation_sources: Vec<CitationSource>,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleThoughtPart {
-    thought: bool,
-    thought_signature: String,
+pub struct PromptFeedback {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub block_reason: Option<String>,
+    pub safety_ratings: Option<Vec<SafetyRating>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub block_reason_message: Option<String>,
 }
 
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize, Default)]
 #[serde(rename_all = "camelCase")]
-struct GoogleGenerationConfig {
+pub struct UsageMetadata {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt_token_count: Option<u64>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    candidate_count: Option<usize>,
+    pub cached_content_token_count: Option<u64>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    stop_sequences: Option<Vec<String>>,
+    pub candidates_token_count: Option<u64>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    max_output_tokens: Option<usize>,
+    pub tool_use_prompt_token_count: Option<u64>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    temperature: Option<f64>,
+    pub thoughts_token_count: Option<u64>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    thinking_config: Option<GoogleThinkingConfig>,
+    pub total_token_count: Option<u64>,
 }
 
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleThinkingConfig {
-    thinking_budget: u32,
+pub struct ThinkingConfig {
+    pub thinking_budget: u32,
 }
 
-#[derive(Serialize)]
+#[derive(Debug, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleTool {
-    function_declarations: Vec<GoogleFunctionDeclaration>,
+pub struct GenerationConfig {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub candidate_count: Option<usize>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop_sequences: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_output_tokens: Option<usize>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_k: Option<usize>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub thinking_config: Option<ThinkingConfig>,
 }
 
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleFunctionDeclaration {
-    name: String,
-    description: String,
-    parameters: serde_json::Value,
+pub struct SafetySetting {
+    pub category: HarmCategory,
+    pub threshold: HarmBlockThreshold,
 }
 
-#[derive(Serialize)]
-#[serde(rename_all = "camelCase")]
-struct GoogleToolConfig {
-    function_calling_config: GoogleFunctionCallingConfig,
+#[derive(Debug, Serialize, Deserialize)]
+pub enum HarmCategory {
+    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
+    Unspecified,
+    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
+    Derogatory,
+    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
+    Toxicity,
+    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
+    Violence,
+    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
+    Sexual,
+    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
+    Medical,
+    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
+    Dangerous,
+    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
+    Harassment,
+    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
+    HateSpeech,
+    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
+    SexuallyExplicit,
+    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
+    DangerousContent,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum HarmBlockThreshold {
+    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
+    Unspecified,
+    BlockLowAndAbove,
+    BlockMediumAndAbove,
+    BlockOnlyHigh,
+    BlockNone,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum HarmProbability {
+    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
+    Unspecified,
+    Negligible,
+    Low,
+    Medium,
+    High,
 }
 
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleFunctionCallingConfig {
-    mode: String,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    allowed_function_names: Option<Vec<String>>,
+pub struct SafetyRating {
+    pub category: HarmCategory,
+    pub probability: HarmProbability,
 }
 
-#[derive(Deserialize, Debug)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleStreamResponse {
-    #[serde(default)]
-    candidates: Vec<GoogleCandidate>,
-    #[serde(default)]
-    usage_metadata: Option<GoogleUsageMetadata>,
+pub struct CountTokensRequest {
+    pub generate_content_request: GenerateContentRequest,
 }
 
-#[derive(Deserialize, Debug)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleCandidate {
-    #[serde(default)]
-    content: Option<GoogleContent>,
-    #[serde(default)]
-    finish_reason: Option<String>,
+pub struct CountTokensResponse {
+    pub total_tokens: u64,
 }
 
-#[derive(Deserialize, Debug)]
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FunctionCall {
+    pub name: String,
+    pub args: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FunctionResponse {
+    pub name: String,
+    pub response: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct GoogleUsageMetadata {
-    #[serde(default)]
-    prompt_token_count: u64,
-    #[serde(default)]
-    candidates_token_count: u64,
+pub struct Tool {
+    pub function_declarations: Vec<FunctionDeclaration>,
 }
 
-fn convert_request(
-    model_id: &str,
-    request: &LlmCompletionRequest,
-) -> Result<(GoogleRequest, String), String> {
-    let real_model_id =
-        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolConfig {
+    pub function_calling_config: FunctionCallingConfig,
+}
 
-    let supports_thinking = get_model_supports_thinking(model_id);
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct FunctionCallingConfig {
+    pub mode: FunctionCallingMode,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub allowed_function_names: Option<Vec<String>>,
+}
 
-    let mut contents: Vec<GoogleContent> = Vec::new();
-    let mut system_parts: Vec<GooglePart> = Vec::new();
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum FunctionCallingMode {
+    Auto,
+    Any,
+    None,
+}
 
-    for msg in &request.messages {
-        match msg.role {
-            LlmMessageRole::System => {
-                for content in &msg.content {
-                    if let LlmMessageContent::Text(text) = content {
-                        if !text.is_empty() {
-                            system_parts
-                                .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
-                        }
-                    }
-                }
-            }
-            LlmMessageRole::User => {
-                let mut parts: Vec<GooglePart> = Vec::new();
-
-                for content in &msg.content {
-                    match content {
-                        LlmMessageContent::Text(text) => {
-                            if !text.is_empty() {
-                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
-                            }
-                        }
-                        LlmMessageContent::Image(img) => {
-                            parts.push(GooglePart::InlineData(GoogleInlineDataPart {
-                                inline_data: GoogleBlob {
-                                    mime_type: "image/png".to_string(),
-                                    data: img.source.clone(),
-                                },
-                            }));
-                        }
-                        LlmMessageContent::ToolResult(result) => {
-                            let response_value = match &result.content {
-                                LlmToolResultContent::Text(t) => {
-                                    serde_json::json!({ "output": t })
-                                }
-                                LlmToolResultContent::Image(_) => {
-                                    serde_json::json!({ "output": "Tool responded with an image" })
-                                }
-                            };
-                            parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
-                                function_response: GoogleFunctionResponse {
-                                    name: result.tool_name.clone(),
-                                    response: response_value,
-                                },
-                            }));
-                        }
-                        _ => {}
-                    }
-                }
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FunctionDeclaration {
+    pub name: String,
+    pub description: String,
+    pub parameters: serde_json::Value,
+}
 
-                if !parts.is_empty() {
-                    contents.push(GoogleContent {
-                        parts,
-                        role: Some("user".to_string()),
-                    });
-                }
-            }
-            LlmMessageRole::Assistant => {
-                let mut parts: Vec<GooglePart> = Vec::new();
-
-                for content in &msg.content {
-                    match content {
-                        LlmMessageContent::Text(text) => {
-                            if !text.is_empty() {
-                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
-                            }
-                        }
-                        LlmMessageContent::ToolUse(tool_use) => {
-                            let thought_signature =
-                                tool_use.thought_signature.clone().filter(|s| !s.is_empty());
-
-                            let args: serde_json::Value =
-                                serde_json::from_str(&tool_use.input).unwrap_or_default();
-
-                            parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
-                                function_call: GoogleFunctionCall {
-                                    name: tool_use.name.clone(),
-                                    args,
-                                },
-                                thought_signature,
-                            }));
-                        }
-                        LlmMessageContent::Thinking(thinking) => {
-                            if let Some(ref signature) = thinking.signature {
-                                if !signature.is_empty() {
-                                    parts.push(GooglePart::Thought(GoogleThoughtPart {
-                                        thought: true,
-                                        thought_signature: signature.clone(),
-                                    }));
-                                }
-                            }
-                        }
-                        _ => {}
-                    }
-                }
+#[derive(Debug, Default)]
+pub struct ModelName {
+    pub model_id: String,
+}
 
-                if !parts.is_empty() {
-                    contents.push(GoogleContent {
-                        parts,
-                        role: Some("model".to_string()),
-                    });
-                }
-            }
-        }
+impl ModelName {
+    pub fn is_empty(&self) -> bool {
+        self.model_id.is_empty()
     }
+}
 
-    let system_instruction = if system_parts.is_empty() {
-        None
-    } else {
-        Some(GoogleSystemInstruction {
-            parts: system_parts,
-        })
-    };
+const MODEL_NAME_PREFIX: &str = "models/";
 
-    let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
-        None
-    } else {
-        let declarations: Vec<GoogleFunctionDeclaration> = request
-            .tools
-            .iter()
-            .map(|t| {
-                let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
-                    .unwrap_or(serde_json::Value::Object(Default::default()));
-                adapt_schema_for_google(&mut parameters);
-                GoogleFunctionDeclaration {
-                    name: t.name.clone(),
-                    description: t.description.clone(),
-                    parameters,
-                }
+impl Serialize for ModelName {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
+    }
+}
+
+impl<'de> Deserialize<'de> for ModelName {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let string = String::deserialize(deserializer)?;
+        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
+            Ok(Self {
+                model_id: id.to_string(),
             })
-            .collect();
-        Some(vec![GoogleTool {
-            function_declarations: declarations,
-        }])
-    };
-
-    let tool_config = request.tool_choice.as_ref().map(|tc| {
-        let mode = match tc {
-            LlmToolChoice::Auto => "AUTO",
-            LlmToolChoice::Any => "ANY",
-            LlmToolChoice::None => "NONE",
-        };
-        GoogleToolConfig {
-            function_calling_config: GoogleFunctionCallingConfig {
-                mode: mode.to_string(),
-                allowed_function_names: None,
-            },
+        } else {
+            Err(serde::de::Error::custom(format!(
+                "Expected model name to begin with {}, got: {}",
+                MODEL_NAME_PREFIX, string
+            )))
         }
-    });
+    }
+}
 
-    let thinking_config = if supports_thinking && request.thinking_allowed {
-        Some(GoogleThinkingConfig {
-            thinking_budget: 8192,
-        })
-    } else {
-        None
-    };
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
+pub enum Model {
+    #[serde(
+        rename = "gemini-2.5-flash-lite",
+        alias = "gemini-2.5-flash-lite-preview-06-17",
+        alias = "gemini-2.0-flash-lite-preview"
+    )]
+    Gemini25FlashLite,
+    #[serde(
+        rename = "gemini-2.5-flash",
+        alias = "gemini-2.0-flash-thinking-exp",
+        alias = "gemini-2.5-flash-preview-04-17",
+        alias = "gemini-2.5-flash-preview-05-20",
+        alias = "gemini-2.5-flash-preview-latest",
+        alias = "gemini-2.0-flash"
+    )]
+    #[default]
+    Gemini25Flash,
+    #[serde(
+        rename = "gemini-2.5-pro",
+        alias = "gemini-2.0-pro-exp",
+        alias = "gemini-2.5-pro-preview-latest",
+        alias = "gemini-2.5-pro-exp-03-25",
+        alias = "gemini-2.5-pro-preview-03-25",
+        alias = "gemini-2.5-pro-preview-05-06",
+        alias = "gemini-2.5-pro-preview-06-05"
+    )]
+    Gemini25Pro,
+    #[serde(rename = "gemini-3-pro-preview")]
+    Gemini3Pro,
+    #[serde(rename = "custom")]
+    Custom {
+        name: String,
+        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
+        display_name: Option<String>,
+        max_tokens: u64,
+        #[serde(default)]
+        mode: GoogleModelMode,
+    },
+}
 
-    let generation_config = Some(GoogleGenerationConfig {
-        candidate_count: Some(1),
-        stop_sequences: if request.stop_sequences.is_empty() {
-            None
-        } else {
-            Some(request.stop_sequences.clone())
-        },
-        max_output_tokens: None,
-        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
-        thinking_config,
-    });
-
-    Ok((
-        GoogleRequest {
-            contents,
-            system_instruction,
-            generation_config,
-            tools,
-            tool_config,
-        },
-        real_model_id.to_string(),
-    ))
-}
-
-fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
-    let trimmed = line.trim();
-    if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," {
-        return None;
+impl Model {
+    pub fn default_fast() -> Self {
+        Self::Gemini25FlashLite
     }
 
-    let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed);
-    let json_str = json_str.trim_start_matches(',').trim();
-
-    if json_str.is_empty() {
-        return None;
+    pub fn id(&self) -> &str {
+        match self {
+            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
+            Self::Gemini25Flash => "gemini-2.5-flash",
+            Self::Gemini25Pro => "gemini-2.5-pro",
+            Self::Gemini3Pro => "gemini-3-pro-preview",
+            Self::Custom { name, .. } => name,
+        }
+    }
+    pub fn request_id(&self) -> &str {
+        match self {
+            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
+            Self::Gemini25Flash => "gemini-2.5-flash",
+            Self::Gemini25Pro => "gemini-2.5-pro",
+            Self::Gemini3Pro => "gemini-3-pro-preview",
+            Self::Custom { name, .. } => name,
+        }
     }
 
-    serde_json::from_str(json_str).ok()
-}
+    pub fn display_name(&self) -> &str {
+        match self {
+            Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite",
+            Self::Gemini25Flash => "Gemini 2.5 Flash",
+            Self::Gemini25Pro => "Gemini 2.5 Pro",
+            Self::Gemini3Pro => "Gemini 3 Pro",
+            Self::Custom {
+                name, display_name, ..
+            } => display_name.as_ref().unwrap_or(name),
+        }
+    }
 
-impl zed::Extension for GoogleAiProvider {
-    fn new() -> Self {
-        Self {
-            streams: Mutex::new(HashMap::new()),
-            next_stream_id: Mutex::new(0),
+    pub fn max_token_count(&self) -> u64 {
+        match self {
+            Self::Gemini25FlashLite => 1_048_576,
+            Self::Gemini25Flash => 1_048_576,
+            Self::Gemini25Pro => 1_048_576,
+            Self::Gemini3Pro => 1_048_576,
+            Self::Custom { max_tokens, .. } => *max_tokens,
         }
     }
 
-    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
-        vec![LlmProviderInfo {
-            id: "google-ai".into(),
-            name: "Google AI".into(),
-            icon: Some("icons/google-ai.svg".into()),
-        }]
+    pub fn max_output_tokens(&self) -> Option<u64> {
+        match self {
+            Model::Gemini25FlashLite => Some(65_536),
+            Model::Gemini25Flash => Some(65_536),
+            Model::Gemini25Pro => Some(65_536),
+            Model::Gemini3Pro => Some(65_536),
+            Model::Custom { .. } => None,
+        }
     }
 
-    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
-        Ok(MODELS
-            .iter()
-            .map(|m| LlmModelInfo {
-                id: m.display_name.to_string(),
-                name: m.display_name.to_string(),
-                max_token_count: m.max_tokens,
-                max_output_tokens: m.max_output_tokens,
-                capabilities: LlmModelCapabilities {
-                    supports_images: m.supports_images,
-                    supports_tools: true,
-                    supports_tool_choice_auto: true,
-                    supports_tool_choice_any: true,
-                    supports_tool_choice_none: true,
-                    supports_thinking: m.supports_thinking,
-                    tool_input_format: LlmToolInputFormat::JsonSchema,
-                },
-                is_default: m.is_default,
-                is_default_fast: m.is_default_fast,
-            })
-            .collect())
+    pub fn supports_tools(&self) -> bool {
+        true
     }
 
-    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
-        llm_get_credential("google-ai").is_some()
+    pub fn supports_images(&self) -> bool {
+        true
     }
 
-    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
-        Some(
-            "[Create an API key](https://aistudio.google.com/apikey) to use Google AI as your LLM provider.".to_string(),
-        )
+    pub fn mode(&self) -> GoogleModelMode {
+        match self {
+            Self::Gemini25FlashLite
+            | Self::Gemini25Flash
+            | Self::Gemini25Pro
+            | Self::Gemini3Pro => {
+                GoogleModelMode::Thinking {
+                    // By default these models are set to "auto", so we preserve that behavior
+                    // but indicate they are capable of thinking mode
+                    budget_tokens: None,
+                }
+            }
+            Self::Custom { mode, .. } => *mode,
+        }
     }
+}
 
-    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
-        llm_delete_credential("google-ai")
+impl std::fmt::Display for Model {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.id())
     }
+}
 
-    fn llm_stream_completion_start(
-        &mut self,
-        _provider_id: &str,
-        model_id: &str,
-        request: &LlmCompletionRequest,
-    ) -> Result<String, String> {
-        let api_key = llm_get_credential("google-ai").ok_or_else(|| {
-            "No API key configured. Please add your Google AI API key in settings.".to_string()
-        })?;
-
-        let (google_request, real_model_id) = convert_request(model_id, request)?;
-
-        let body = serde_json::to_vec(&google_request)
-            .map_err(|e| format!("Failed to serialize request: {}", e))?;
-
-        let url = format!(
-            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
-            real_model_id, api_key
-        );
-
-        let http_request = HttpRequest {
-            method: HttpMethod::Post,
-            url,
-            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
-            body: Some(body),
-            redirect_policy: RedirectPolicy::FollowAll,
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn test_function_call_part_with_signature_serializes_correctly() {
+        let part = FunctionCallPart {
+            function_call: FunctionCall {
+                name: "test_function".to_string(),
+                args: json!({"arg": "value"}),
+            },
+            thought_signature: Some("test_signature".to_string()),
         };
 
-        let response_stream = http_request
-            .fetch_stream()
-            .map_err(|e| format!("HTTP request failed: {}", e))?;
+        let serialized = serde_json::to_value(&part).unwrap();
 
-        let stream_id = {
-            let mut id_counter = self.next_stream_id.lock().unwrap();
-            let id = format!("google-ai-stream-{}", *id_counter);
-            *id_counter += 1;
-            id
-        };
+        assert_eq!(serialized["functionCall"]["name"], "test_function");
+        assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
+        assert_eq!(serialized["thoughtSignature"], "test_signature");
+    }
 
-        self.streams.lock().unwrap().insert(
-            stream_id.clone(),
-            StreamState {
-                response_stream: Some(response_stream),
-                buffer: String::new(),
-                started: false,
-                stop_reason: None,
-                wants_tool_use: false,
+    #[test]
+    fn test_function_call_part_without_signature_omits_field() {
+        let part = FunctionCallPart {
+            function_call: FunctionCall {
+                name: "test_function".to_string(),
+                args: json!({"arg": "value"}),
             },
-        );
+            thought_signature: None,
+        };
+
+        let serialized = serde_json::to_value(&part).unwrap();
 
-        Ok(stream_id)
+        assert_eq!(serialized["functionCall"]["name"], "test_function");
+        assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
+        // thoughtSignature field should not be present when None
+        assert!(serialized.get("thoughtSignature").is_none());
     }
 
-    fn llm_stream_completion_next(
-        &mut self,
-        stream_id: &str,
-    ) -> Result<Option<LlmCompletionEvent>, String> {
-        let mut streams = self.streams.lock().unwrap();
-        let state = streams
-            .get_mut(stream_id)
-            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
-
-        if !state.started {
-            state.started = true;
-            return Ok(Some(LlmCompletionEvent::Started));
-        }
+    #[test]
+    fn test_function_call_part_deserializes_with_signature() {
+        let json = json!({
+            "functionCall": {
+                "name": "test_function",
+                "args": {"arg": "value"}
+            },
+            "thoughtSignature": "test_signature"
+        });
 
-        let response_stream = state
-            .response_stream
-            .as_mut()
-            .ok_or_else(|| "Stream already closed".to_string())?;
-
-        loop {
-            if let Some(newline_pos) = state.buffer.find('\n') {
-                let line = state.buffer[..newline_pos].to_string();
-                state.buffer = state.buffer[newline_pos + 1..].to_string();
-
-                if let Some(response) = parse_stream_line(&line) {
-                    for candidate in response.candidates {
-                        if let Some(finish_reason) = &candidate.finish_reason {
-                            state.stop_reason = Some(match finish_reason.as_str() {
-                                "STOP" => {
-                                    if state.wants_tool_use {
-                                        LlmStopReason::ToolUse
-                                    } else {
-                                        LlmStopReason::EndTurn
-                                    }
-                                }
-                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
-                                "SAFETY" => LlmStopReason::Refusal,
-                                _ => LlmStopReason::EndTurn,
-                            });
-                        }
+        let part: FunctionCallPart = serde_json::from_value(json).unwrap();
 
-                        if let Some(content) = candidate.content {
-                            for part in content.parts {
-                                match part {
-                                    GooglePart::Text(text_part) => {
-                                        if !text_part.text.is_empty() {
-                                            return Ok(Some(LlmCompletionEvent::Text(
-                                                text_part.text,
-                                            )));
-                                        }
-                                    }
-                                    GooglePart::FunctionCall(fc_part) => {
-                                        state.wants_tool_use = true;
-                                        let next_tool_id =
-                                            TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
-                                        let id = format!(
-                                            "{}-{}",
-                                            fc_part.function_call.name, next_tool_id
-                                        );
-
-                                        let thought_signature =
-                                            fc_part.thought_signature.filter(|s| !s.is_empty());
-
-                                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
-                                            id,
-                                            name: fc_part.function_call.name,
-                                            input: fc_part.function_call.args.to_string(),
-                                            is_input_complete: true,
-                                            thought_signature,
-                                        })));
-                                    }
-                                    GooglePart::Thought(thought_part) => {
-                                        return Ok(Some(LlmCompletionEvent::Thinking(
-                                            LlmThinkingContent {
-                                                text: "(Encrypted thought)".to_string(),
-                                                signature: Some(thought_part.thought_signature),
-                                            },
-                                        )));
-                                    }
-                                    _ => {}
-                                }
-                            }
-                        }
-                    }
-
-                    if let Some(usage) = response.usage_metadata {
-                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
-                            input_tokens: usage.prompt_token_count,
-                            output_tokens: usage.candidates_token_count,
-                            cache_creation_input_tokens: None,
-                            cache_read_input_tokens: None,
-                        })));
-                    }
-                }
+        assert_eq!(part.function_call.name, "test_function");
+        assert_eq!(part.thought_signature, Some("test_signature".to_string()));
+    }
 
-                continue;
+    #[test]
+    fn test_function_call_part_deserializes_without_signature() {
+        let json = json!({
+            "functionCall": {
+                "name": "test_function",
+                "args": {"arg": "value"}
             }
+        });
 
-            match response_stream.next_chunk() {
-                Ok(Some(chunk)) => {
-                    let text = String::from_utf8_lossy(&chunk);
-                    state.buffer.push_str(&text);
-                }
-                Ok(None) => {
-                    // Stream ended - check if we have a stop reason
-                    if let Some(stop_reason) = state.stop_reason.take() {
-                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
-                    }
+        let part: FunctionCallPart = serde_json::from_value(json).unwrap();
 
-                    // No stop reason - this is unexpected. Check if buffer contains error info
-                    let mut error_msg = String::from("Stream ended unexpectedly.");
+        assert_eq!(part.function_call.name, "test_function");
+        assert_eq!(part.thought_signature, None);
+    }
 
-                    // Try to parse remaining buffer as potential error response
-                    if !state.buffer.is_empty() {
-                        error_msg.push_str(&format!(
-                            "\nRemaining buffer: {}",
-                            &state.buffer[..state.buffer.len().min(1000)]
-                        ));
-                    }
+    #[test]
+    fn test_function_call_part_round_trip() {
+        let original = FunctionCallPart {
+            function_call: FunctionCall {
+                name: "test_function".to_string(),
+                args: json!({"arg": "value", "nested": {"key": "val"}}),
+            },
+            thought_signature: Some("round_trip_signature".to_string()),
+        };
 
-                    return Err(error_msg);
-                }
-                Err(e) => {
-                    return Err(format!("Stream error: {}", e));
-                }
-            }
-        }
+        let serialized = serde_json::to_value(&original).unwrap();
+        let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap();
+
+        assert_eq!(deserialized.function_call.name, original.function_call.name);
+        assert_eq!(deserialized.function_call.args, original.function_call.args);
+        assert_eq!(deserialized.thought_signature, original.thought_signature);
     }
 
-    fn llm_stream_completion_close(&mut self, stream_id: &str) {
-        self.streams.lock().unwrap().remove(stream_id);
+    #[test]
+    fn test_function_call_part_with_empty_signature_serializes() {
+        let part = FunctionCallPart {
+            function_call: FunctionCall {
+                name: "test_function".to_string(),
+                args: json!({"arg": "value"}),
+            },
+            thought_signature: Some("".to_string()),
+        };
+
+        let serialized = serde_json::to_value(&part).unwrap();
+
+        // Empty string should still be serialized (normalization happens at a higher level)
+        assert_eq!(serialized["thoughtSignature"], "");
     }
 }
-
-zed::register_extension!(GoogleAiProvider);