Revert "Replace extensions google_ai with the hardcoded one."

Authored by Richard Feldman

This reverts commit 6f05a4b6dfe9a58c0f62fe437dfdc8fddbdf0065.

Change summary

extensions/google-ai/src/google_ai.rs | 1253 +++++++++++++++-------------
1 file changed, 667 insertions(+), 586 deletions(-)

Detailed changes

extensions/google-ai/src/google_ai.rs

@@ -1,717 +1,798 @@
-use std::mem;
-
-use anyhow::{Result, anyhow, bail};
-use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
-pub use settings::ModelMode as GoogleModelMode;
-
-pub const API_URL: &str = "https://generativelanguage.googleapis.com";
-
-pub async fn stream_generate_content(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    mut request: GenerateContentRequest,
-) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
-    let api_key = api_key.trim();
-    validate_generate_content_request(&request)?;
-
-    // The `model` field is emptied as it is provided as a path parameter.
-    let model_id = mem::take(&mut request.model.model_id);
-
-    let uri =
-        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
-
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(uri)
-        .header("Content-Type", "application/json");
-
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
-    let mut response = client.send(request).await?;
-    if response.status().is_success() {
-        let reader = BufReader::new(response.into_body());
-        Ok(reader
-            .lines()
-            .filter_map(|line| async move {
-                match line {
-                    Ok(line) => {
-                        if let Some(line) = line.strip_prefix("data: ") {
-                            match serde_json::from_str(line) {
-                                Ok(response) => Some(Ok(response)),
-                                Err(error) => Some(Err(anyhow!(format!(
-                                    "Error parsing JSON: {error:?}\n{line:?}"
-                                )))),
-                            }
-                        } else {
-                            None
-                        }
-                    }
-                    Err(error) => Some(Err(anyhow!(error))),
-                }
-            })
-            .boxed())
-    } else {
-        let mut text = String::new();
-        response.body_mut().read_to_string(&mut text).await?;
-        Err(anyhow!(
-            "error during streamGenerateContent, status code: {:?}, body: {}",
-            response.status(),
-            text
-        ))
-    }
-}
+use std::collections::HashMap;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::Mutex;
+
+use serde::{Deserialize, Serialize};
+use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
+use zed_extension_api::{self as zed, *};
+
+static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+
+struct GoogleAiProvider {
+    streams: Mutex<HashMap<String, StreamState>>,
+    next_stream_id: Mutex<u64>,
+}
+
+struct StreamState {
+    response_stream: Option<HttpResponseStream>,
+    buffer: String,
+    started: bool,
+    stop_reason: Option<LlmStopReason>,
+    wants_tool_use: bool,
+}
+
+struct ModelDefinition {
+    real_id: &'static str,
+    display_name: &'static str,
+    max_tokens: u64,
+    max_output_tokens: Option<u64>,
+    supports_images: bool,
+    supports_thinking: bool,
+    is_default: bool,
+    is_default_fast: bool,
+}
+
+const MODELS: &[ModelDefinition] = &[
+    ModelDefinition {
+        real_id: "gemini-2.5-flash-lite",
+        display_name: "Gemini 2.5 Flash-Lite",
+        max_tokens: 1_048_576,
+        max_output_tokens: Some(65_536),
+        supports_images: true,
+        supports_thinking: true,
+        is_default: false,
+        is_default_fast: true,
+    },
+    ModelDefinition {
+        real_id: "gemini-2.5-flash",
+        display_name: "Gemini 2.5 Flash",
+        max_tokens: 1_048_576,
+        max_output_tokens: Some(65_536),
+        supports_images: true,
+        supports_thinking: true,
+        is_default: true,
+        is_default_fast: false,
+    },
+    ModelDefinition {
+        real_id: "gemini-2.5-pro",
+        display_name: "Gemini 2.5 Pro",
+        max_tokens: 1_048_576,
+        max_output_tokens: Some(65_536),
+        supports_images: true,
+        supports_thinking: true,
+        is_default: false,
+        is_default_fast: false,
+    },
+    ModelDefinition {
+        real_id: "gemini-3-pro-preview",
+        display_name: "Gemini 3 Pro",
+        max_tokens: 1_048_576,
+        max_output_tokens: Some(65_536),
+        supports_images: true,
+        supports_thinking: true,
+        is_default: false,
+        is_default_fast: false,
+    },
+];
 
-pub async fn count_tokens(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    request: CountTokensRequest,
-) -> Result<CountTokensResponse> {
-    validate_generate_content_request(&request.generate_content_request)?;
-
-    let uri = format!(
-        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
-        model_id = &request.generate_content_request.model.model_id,
-    );
-
-    let request = serde_json::to_string(&request)?;
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(&uri)
-        .header("Content-Type", "application/json");
-    let http_request = request_builder.body(AsyncBody::from(request))?;
-
-    let mut response = client.send(http_request).await?;
-    let mut text = String::new();
-    response.body_mut().read_to_string(&mut text).await?;
-    anyhow::ensure!(
-        response.status().is_success(),
-        "error during countTokens, status code: {:?}, body: {}",
-        response.status(),
-        text
-    );
-    Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
+fn get_real_model_id(display_name: &str) -> Option<&'static str> {
+    MODELS
+        .iter()
+        .find(|m| m.display_name == display_name)
+        .map(|m| m.real_id)
 }
 
-pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
-    if request.model.is_empty() {
-        bail!("Model must be specified");
-    }
-
-    if request.contents.is_empty() {
-        bail!("Request must contain at least one content item");
-    }
-
-    if let Some(user_content) = request
-        .contents
+fn get_model_supports_thinking(display_name: &str) -> bool {
+    MODELS
         .iter()
-        .find(|content| content.role == Role::User)
-        && user_content.parts.is_empty()
-    {
-        bail!("User content must contain at least one part");
-    }
-
-    Ok(())
-}
+        .find(|m| m.display_name == display_name)
+        .map(|m| m.supports_thinking)
+        .unwrap_or(false)
+}
+
+/// Adapts a JSON schema to be compatible with Google's API subset.
+/// Google only supports a specific subset of JSON Schema fields.
+/// See: https://ai.google.dev/api/caching#Schema
+fn adapt_schema_for_google(json: &mut serde_json::Value) {
+    adapt_schema_for_google_impl(json, true);
+}
+
+fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) {
+    if let serde_json::Value::Object(obj) = json {
+        // Google's Schema only supports these fields:
+        // type, format, title, description, nullable, enum, maxItems, minItems,
+        // properties, required, minProperties, maxProperties, minLength, maxLength,
+        // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
+        const ALLOWED_KEYS: &[&str] = &[
+            "type",
+            "format",
+            "title",
+            "description",
+            "nullable",
+            "enum",
+            "maxItems",
+            "minItems",
+            "properties",
+            "required",
+            "minProperties",
+            "maxProperties",
+            "minLength",
+            "maxLength",
+            "pattern",
+            "example",
+            "anyOf",
+            "propertyOrdering",
+            "default",
+            "items",
+            "minimum",
+            "maximum",
+        ];
+
+        // Convert oneOf to anyOf before filtering keys
+        if let Some(one_of) = obj.remove("oneOf") {
+            obj.insert("anyOf".to_string(), one_of);
+        }
 
-#[derive(Debug, Serialize, Deserialize)]
-pub enum Task {
-    #[serde(rename = "generateContent")]
-    GenerateContent,
-    #[serde(rename = "streamGenerateContent")]
-    StreamGenerateContent,
-    #[serde(rename = "countTokens")]
-    CountTokens,
-    #[serde(rename = "embedContent")]
-    EmbedContent,
-    #[serde(rename = "batchEmbedContents")]
-    BatchEmbedContents,
-}
+        // If type is an array (e.g., ["string", "null"]), take just the first type
+        if let Some(type_field) = obj.get_mut("type") {
+            if let serde_json::Value::Array(types) = type_field {
+                if let Some(first_type) = types.first().cloned() {
+                    *type_field = first_type;
+                }
+            }
+        }
 
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct GenerateContentRequest {
-    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
-    pub model: ModelName,
-    pub contents: Vec<Content>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub system_instruction: Option<SystemInstruction>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub generation_config: Option<GenerationConfig>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub safety_settings: Option<Vec<SafetySetting>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub tools: Option<Vec<Tool>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub tool_config: Option<ToolConfig>,
-}
+        // Only filter keys if this is a schema object, not a properties map
+        if is_schema {
+            obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));
+        }
 
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct GenerateContentResponse {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub candidates: Option<Vec<GenerateContentCandidate>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub prompt_feedback: Option<PromptFeedback>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub usage_metadata: Option<UsageMetadata>,
+        // Recursively process nested values
+        // "properties" contains a map of property names -> schemas
+        // "items" and "anyOf" contain schemas directly
+        for (key, value) in obj.iter_mut() {
+            if key == "properties" {
+                // properties is a map of property_name -> schema
+                if let serde_json::Value::Object(props) = value {
+                    for (_, prop_schema) in props.iter_mut() {
+                        adapt_schema_for_google_impl(prop_schema, true);
+                    }
+                }
+            } else if key == "items" {
+                // items is a schema
+                adapt_schema_for_google_impl(value, true);
+            } else if key == "anyOf" {
+                // anyOf is an array of schemas
+                if let serde_json::Value::Array(arr) = value {
+                    for item in arr.iter_mut() {
+                        adapt_schema_for_google_impl(item, true);
+                    }
+                }
+            }
+        }
+    } else if let serde_json::Value::Array(arr) = json {
+        for item in arr.iter_mut() {
+            adapt_schema_for_google_impl(item, true);
+        }
+    }
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct GenerateContentCandidate {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub index: Option<usize>,
-    pub content: Content,
+struct GoogleRequest {
+    contents: Vec<GoogleContent>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub finish_reason: Option<String>,
+    system_instruction: Option<GoogleSystemInstruction>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub finish_message: Option<String>,
+    generation_config: Option<GoogleGenerationConfig>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub safety_ratings: Option<Vec<SafetyRating>>,
+    tools: Option<Vec<GoogleTool>>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub citation_metadata: Option<CitationMetadata>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct Content {
-    #[serde(default)]
-    pub parts: Vec<Part>,
-    pub role: Role,
+    tool_config: Option<GoogleToolConfig>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct SystemInstruction {
-    pub parts: Vec<Part>,
+struct GoogleSystemInstruction {
+    parts: Vec<GooglePart>,
 }
 
-#[derive(Debug, PartialEq, Deserialize, Serialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub enum Role {
-    User,
-    Model,
+struct GoogleContent {
+    parts: Vec<GooglePart>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    role: Option<String>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(untagged)]
-pub enum Part {
-    TextPart(TextPart),
-    InlineDataPart(InlineDataPart),
-    FunctionCallPart(FunctionCallPart),
-    FunctionResponsePart(FunctionResponsePart),
-    ThoughtPart(ThoughtPart),
+enum GooglePart {
+    Text(GoogleTextPart),
+    InlineData(GoogleInlineDataPart),
+    FunctionCall(GoogleFunctionCallPart),
+    FunctionResponse(GoogleFunctionResponsePart),
+    Thought(GoogleThoughtPart),
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct TextPart {
-    pub text: String,
+struct GoogleTextPart {
+    text: String,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct InlineDataPart {
-    pub inline_data: GenerativeContentBlob,
+struct GoogleInlineDataPart {
+    inline_data: GoogleBlob,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct GenerativeContentBlob {
-    pub mime_type: String,
-    pub data: String,
+struct GoogleBlob {
+    mime_type: String,
+    data: String,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct FunctionCallPart {
-    pub function_call: FunctionCall,
-    /// Thought signature returned by the model for function calls.
-    /// Only present on the first function call in parallel call scenarios.
+struct GoogleFunctionCallPart {
+    function_call: GoogleFunctionCall,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub thought_signature: Option<String>,
+    thought_signature: Option<String>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct FunctionResponsePart {
-    pub function_response: FunctionResponse,
+struct GoogleFunctionCall {
+    name: String,
+    args: serde_json::Value,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct ThoughtPart {
-    pub thought: bool,
-    pub thought_signature: String,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct CitationSource {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub start_index: Option<usize>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub end_index: Option<usize>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub uri: Option<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub license: Option<String>,
+struct GoogleFunctionResponsePart {
+    function_response: GoogleFunctionResponse,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct CitationMetadata {
-    pub citation_sources: Vec<CitationSource>,
+struct GoogleFunctionResponse {
+    name: String,
+    response: serde_json::Value,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct PromptFeedback {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub block_reason: Option<String>,
-    pub safety_ratings: Option<Vec<SafetyRating>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub block_reason_message: Option<String>,
+struct GoogleThoughtPart {
+    thought: bool,
+    thought_signature: String,
 }
 
-#[derive(Debug, Serialize, Deserialize, Default)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct UsageMetadata {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub prompt_token_count: Option<u64>,
+struct GoogleGenerationConfig {
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub cached_content_token_count: Option<u64>,
+    candidate_count: Option<usize>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub candidates_token_count: Option<u64>,
+    stop_sequences: Option<Vec<String>>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub tool_use_prompt_token_count: Option<u64>,
+    max_output_tokens: Option<usize>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub thoughts_token_count: Option<u64>,
+    temperature: Option<f64>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub total_token_count: Option<u64>,
+    thinking_config: Option<GoogleThinkingConfig>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct ThinkingConfig {
-    pub thinking_budget: u32,
-}
-
-#[derive(Debug, Deserialize, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub struct GenerationConfig {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub candidate_count: Option<usize>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub stop_sequences: Option<Vec<String>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub max_output_tokens: Option<usize>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub temperature: Option<f64>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_p: Option<f64>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_k: Option<usize>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub thinking_config: Option<ThinkingConfig>,
+struct GoogleThinkingConfig {
+    thinking_budget: u32,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct SafetySetting {
-    pub category: HarmCategory,
-    pub threshold: HarmBlockThreshold,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub enum HarmCategory {
-    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
-    Unspecified,
-    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
-    Derogatory,
-    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
-    Toxicity,
-    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
-    Violence,
-    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
-    Sexual,
-    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
-    Medical,
-    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
-    Dangerous,
-    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
-    Harassment,
-    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
-    HateSpeech,
-    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
-    SexuallyExplicit,
-    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
-    DangerousContent,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
-pub enum HarmBlockThreshold {
-    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
-    Unspecified,
-    BlockLowAndAbove,
-    BlockMediumAndAbove,
-    BlockOnlyHigh,
-    BlockNone,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
-pub enum HarmProbability {
-    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
-    Unspecified,
-    Negligible,
-    Low,
-    Medium,
-    High,
+struct GoogleTool {
+    function_declarations: Vec<GoogleFunctionDeclaration>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct SafetyRating {
-    pub category: HarmCategory,
-    pub probability: HarmProbability,
+struct GoogleFunctionDeclaration {
+    name: String,
+    description: String,
+    parameters: serde_json::Value,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct CountTokensRequest {
-    pub generate_content_request: GenerateContentRequest,
+struct GoogleToolConfig {
+    function_calling_config: GoogleFunctionCallingConfig,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct CountTokensResponse {
-    pub total_tokens: u64,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct FunctionCall {
-    pub name: String,
-    pub args: serde_json::Value,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct FunctionResponse {
-    pub name: String,
-    pub response: serde_json::Value,
+struct GoogleFunctionCallingConfig {
+    mode: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    allowed_function_names: Option<Vec<String>>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Deserialize, Debug)]
 #[serde(rename_all = "camelCase")]
-pub struct Tool {
-    pub function_declarations: Vec<FunctionDeclaration>,
+struct GoogleStreamResponse {
+    #[serde(default)]
+    candidates: Vec<GoogleCandidate>,
+    #[serde(default)]
+    usage_metadata: Option<GoogleUsageMetadata>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Deserialize, Debug)]
 #[serde(rename_all = "camelCase")]
-pub struct ToolConfig {
-    pub function_calling_config: FunctionCallingConfig,
+struct GoogleCandidate {
+    #[serde(default)]
+    content: Option<GoogleContent>,
+    #[serde(default)]
+    finish_reason: Option<String>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Deserialize, Debug)]
 #[serde(rename_all = "camelCase")]
-pub struct FunctionCallingConfig {
-    pub mode: FunctionCallingMode,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub allowed_function_names: Option<Vec<String>>,
+struct GoogleUsageMetadata {
+    #[serde(default)]
+    prompt_token_count: u64,
+    #[serde(default)]
+    candidates_token_count: u64,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "lowercase")]
-pub enum FunctionCallingMode {
-    Auto,
-    Any,
-    None,
-}
+fn convert_request(
+    model_id: &str,
+    request: &LlmCompletionRequest,
+) -> Result<(GoogleRequest, String), String> {
+    let real_model_id =
+        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;
 
-#[derive(Debug, Serialize, Deserialize)]
-pub struct FunctionDeclaration {
-    pub name: String,
-    pub description: String,
-    pub parameters: serde_json::Value,
-}
+    let supports_thinking = get_model_supports_thinking(model_id);
 
-#[derive(Debug, Default)]
-pub struct ModelName {
-    pub model_id: String,
-}
+    let mut contents: Vec<GoogleContent> = Vec::new();
+    let mut system_parts: Vec<GooglePart> = Vec::new();
 
-impl ModelName {
-    pub fn is_empty(&self) -> bool {
-        self.model_id.is_empty()
-    }
-}
+    for msg in &request.messages {
+        match msg.role {
+            LlmMessageRole::System => {
+                for content in &msg.content {
+                    if let LlmMessageContent::Text(text) = content {
+                        if !text.is_empty() {
+                            system_parts
+                                .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
+                        }
+                    }
+                }
+            }
+            LlmMessageRole::User => {
+                let mut parts: Vec<GooglePart> = Vec::new();
+
+                for content in &msg.content {
+                    match content {
+                        LlmMessageContent::Text(text) => {
+                            if !text.is_empty() {
+                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
+                            }
+                        }
+                        LlmMessageContent::Image(img) => {
+                            parts.push(GooglePart::InlineData(GoogleInlineDataPart {
+                                inline_data: GoogleBlob {
+                                    mime_type: "image/png".to_string(),
+                                    data: img.source.clone(),
+                                },
+                            }));
+                        }
+                        LlmMessageContent::ToolResult(result) => {
+                            let response_value = match &result.content {
+                                LlmToolResultContent::Text(t) => {
+                                    serde_json::json!({ "output": t })
+                                }
+                                LlmToolResultContent::Image(_) => {
+                                    serde_json::json!({ "output": "Tool responded with an image" })
+                                }
+                            };
+                            parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
+                                function_response: GoogleFunctionResponse {
+                                    name: result.tool_name.clone(),
+                                    response: response_value,
+                                },
+                            }));
+                        }
+                        _ => {}
+                    }
+                }
 
-const MODEL_NAME_PREFIX: &str = "models/";
+                if !parts.is_empty() {
+                    contents.push(GoogleContent {
+                        parts,
+                        role: Some("user".to_string()),
+                    });
+                }
+            }
+            LlmMessageRole::Assistant => {
+                let mut parts: Vec<GooglePart> = Vec::new();
+
+                for content in &msg.content {
+                    match content {
+                        LlmMessageContent::Text(text) => {
+                            if !text.is_empty() {
+                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
+                            }
+                        }
+                        LlmMessageContent::ToolUse(tool_use) => {
+                            let thought_signature =
+                                tool_use.thought_signature.clone().filter(|s| !s.is_empty());
+
+                            let args: serde_json::Value =
+                                serde_json::from_str(&tool_use.input).unwrap_or_default();
+
+                            parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
+                                function_call: GoogleFunctionCall {
+                                    name: tool_use.name.clone(),
+                                    args,
+                                },
+                                thought_signature,
+                            }));
+                        }
+                        LlmMessageContent::Thinking(thinking) => {
+                            if let Some(ref signature) = thinking.signature {
+                                if !signature.is_empty() {
+                                    parts.push(GooglePart::Thought(GoogleThoughtPart {
+                                        thought: true,
+                                        thought_signature: signature.clone(),
+                                    }));
+                                }
+                            }
+                        }
+                        _ => {}
+                    }
+                }
 
-impl Serialize for ModelName {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
+                if !parts.is_empty() {
+                    contents.push(GoogleContent {
+                        parts,
+                        role: Some("model".to_string()),
+                    });
+                }
+            }
+        }
     }
-}
 
-impl<'de> Deserialize<'de> for ModelName {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let string = String::deserialize(deserializer)?;
-        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
-            Ok(Self {
-                model_id: id.to_string(),
+    let system_instruction = if system_parts.is_empty() {
+        None
+    } else {
+        Some(GoogleSystemInstruction {
+            parts: system_parts,
+        })
+    };
+
+    let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
+        None
+    } else {
+        let declarations: Vec<GoogleFunctionDeclaration> = request
+            .tools
+            .iter()
+            .map(|t| {
+                let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
+                    .unwrap_or(serde_json::Value::Object(Default::default()));
+                adapt_schema_for_google(&mut parameters);
+                GoogleFunctionDeclaration {
+                    name: t.name.clone(),
+                    description: t.description.clone(),
+                    parameters,
+                }
             })
-        } else {
-            Err(serde::de::Error::custom(format!(
-                "Expected model name to begin with {}, got: {}",
-                MODEL_NAME_PREFIX, string
-            )))
+            .collect();
+        Some(vec![GoogleTool {
+            function_declarations: declarations,
+        }])
+    };
+
+    let tool_config = request.tool_choice.as_ref().map(|tc| {
+        let mode = match tc {
+            LlmToolChoice::Auto => "AUTO",
+            LlmToolChoice::Any => "ANY",
+            LlmToolChoice::None => "NONE",
+        };
+        GoogleToolConfig {
+            function_calling_config: GoogleFunctionCallingConfig {
+                mode: mode.to_string(),
+                allowed_function_names: None,
+            },
         }
-    }
-}
+    });
 
-#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
-pub enum Model {
-    #[serde(
-        rename = "gemini-2.5-flash-lite",
-        alias = "gemini-2.5-flash-lite-preview-06-17",
-        alias = "gemini-2.0-flash-lite-preview"
-    )]
-    Gemini25FlashLite,
-    #[serde(
-        rename = "gemini-2.5-flash",
-        alias = "gemini-2.0-flash-thinking-exp",
-        alias = "gemini-2.5-flash-preview-04-17",
-        alias = "gemini-2.5-flash-preview-05-20",
-        alias = "gemini-2.5-flash-preview-latest",
-        alias = "gemini-2.0-flash"
-    )]
-    #[default]
-    Gemini25Flash,
-    #[serde(
-        rename = "gemini-2.5-pro",
-        alias = "gemini-2.0-pro-exp",
-        alias = "gemini-2.5-pro-preview-latest",
-        alias = "gemini-2.5-pro-exp-03-25",
-        alias = "gemini-2.5-pro-preview-03-25",
-        alias = "gemini-2.5-pro-preview-05-06",
-        alias = "gemini-2.5-pro-preview-06-05"
-    )]
-    Gemini25Pro,
-    #[serde(rename = "gemini-3-pro-preview")]
-    Gemini3Pro,
-    #[serde(rename = "custom")]
-    Custom {
-        name: String,
-        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
-        display_name: Option<String>,
-        max_tokens: u64,
-        #[serde(default)]
-        mode: GoogleModelMode,
-    },
-}
+    let thinking_config = if supports_thinking && request.thinking_allowed {
+        Some(GoogleThinkingConfig {
+            thinking_budget: 8192,
+        })
+    } else {
+        None
+    };
 
-impl Model {
-    pub fn default_fast() -> Self {
-        Self::Gemini25FlashLite
+    let generation_config = Some(GoogleGenerationConfig {
+        candidate_count: Some(1),
+        stop_sequences: if request.stop_sequences.is_empty() {
+            None
+        } else {
+            Some(request.stop_sequences.clone())
+        },
+        max_output_tokens: None,
+        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
+        thinking_config,
+    });
+
+    Ok((
+        GoogleRequest {
+            contents,
+            system_instruction,
+            generation_config,
+            tools,
+            tool_config,
+        },
+        real_model_id.to_string(),
+    ))
+}
+
+fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
+    let trimmed = line.trim();
+    if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," {
+        return None;
     }
 
-    pub fn id(&self) -> &str {
-        match self {
-            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
-            Self::Gemini25Flash => "gemini-2.5-flash",
-            Self::Gemini25Pro => "gemini-2.5-pro",
-            Self::Gemini3Pro => "gemini-3-pro-preview",
-            Self::Custom { name, .. } => name,
-        }
-    }
-    pub fn request_id(&self) -> &str {
-        match self {
-            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
-            Self::Gemini25Flash => "gemini-2.5-flash",
-            Self::Gemini25Pro => "gemini-2.5-pro",
-            Self::Gemini3Pro => "gemini-3-pro-preview",
-            Self::Custom { name, .. } => name,
-        }
-    }
+    let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed);
+    let json_str = json_str.trim_start_matches(',').trim();
 
-    pub fn display_name(&self) -> &str {
-        match self {
-            Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite",
-            Self::Gemini25Flash => "Gemini 2.5 Flash",
-            Self::Gemini25Pro => "Gemini 2.5 Pro",
-            Self::Gemini3Pro => "Gemini 3 Pro",
-            Self::Custom {
-                name, display_name, ..
-            } => display_name.as_ref().unwrap_or(name),
-        }
+    if json_str.is_empty() {
+        return None;
     }
 
-    pub fn max_token_count(&self) -> u64 {
-        match self {
-            Self::Gemini25FlashLite => 1_048_576,
-            Self::Gemini25Flash => 1_048_576,
-            Self::Gemini25Pro => 1_048_576,
-            Self::Gemini3Pro => 1_048_576,
-            Self::Custom { max_tokens, .. } => *max_tokens,
+    serde_json::from_str(json_str).ok()
+}
+
+impl zed::Extension for GoogleAiProvider {
+    fn new() -> Self {
+        Self {
+            streams: Mutex::new(HashMap::new()),
+            next_stream_id: Mutex::new(0),
         }
     }
 
-    pub fn max_output_tokens(&self) -> Option<u64> {
-        match self {
-            Model::Gemini25FlashLite => Some(65_536),
-            Model::Gemini25Flash => Some(65_536),
-            Model::Gemini25Pro => Some(65_536),
-            Model::Gemini3Pro => Some(65_536),
-            Model::Custom { .. } => None,
-        }
+    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
+        vec![LlmProviderInfo {
+            id: "google-ai".into(),
+            name: "Google AI".into(),
+            icon: Some("icons/google-ai.svg".into()),
+        }]
     }
 
-    pub fn supports_tools(&self) -> bool {
-        true
+    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
+        Ok(MODELS
+            .iter()
+            .map(|m| LlmModelInfo {
+                id: m.display_name.to_string(),
+                name: m.display_name.to_string(),
+                max_token_count: m.max_tokens,
+                max_output_tokens: m.max_output_tokens,
+                capabilities: LlmModelCapabilities {
+                    supports_images: m.supports_images,
+                    supports_tools: true,
+                    supports_tool_choice_auto: true,
+                    supports_tool_choice_any: true,
+                    supports_tool_choice_none: true,
+                    supports_thinking: m.supports_thinking,
+                    tool_input_format: LlmToolInputFormat::JsonSchema,
+                },
+                is_default: m.is_default,
+                is_default_fast: m.is_default_fast,
+            })
+            .collect())
     }
 
-    pub fn supports_images(&self) -> bool {
-        true
+    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
+        llm_get_credential("google-ai").is_some()
     }
 
-    pub fn mode(&self) -> GoogleModelMode {
-        match self {
-            Self::Gemini25FlashLite
-            | Self::Gemini25Flash
-            | Self::Gemini25Pro
-            | Self::Gemini3Pro => {
-                GoogleModelMode::Thinking {
-                    // By default these models are set to "auto", so we preserve that behavior
-                    // but indicate they are capable of thinking mode
-                    budget_tokens: None,
-                }
-            }
-            Self::Custom { mode, .. } => *mode,
-        }
+    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
+        Some(
+            "[Create an API key](https://aistudio.google.com/apikey) to use Google AI as your LLM provider.".to_string(),
+        )
     }
-}
 
-impl std::fmt::Display for Model {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.id())
+    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
+        llm_delete_credential("google-ai")
     }
-}
 
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use serde_json::json;
-
-    #[test]
-    fn test_function_call_part_with_signature_serializes_correctly() {
-        let part = FunctionCallPart {
-            function_call: FunctionCall {
-                name: "test_function".to_string(),
-                args: json!({"arg": "value"}),
-            },
-            thought_signature: Some("test_signature".to_string()),
+    fn llm_stream_completion_start(
+        &mut self,
+        _provider_id: &str,
+        model_id: &str,
+        request: &LlmCompletionRequest,
+    ) -> Result<String, String> {
+        let api_key = llm_get_credential("google-ai").ok_or_else(|| {
+            "No API key configured. Please add your Google AI API key in settings.".to_string()
+        })?;
+
+        let (google_request, real_model_id) = convert_request(model_id, request)?;
+
+        let body = serde_json::to_vec(&google_request)
+            .map_err(|e| format!("Failed to serialize request: {}", e))?;
+
+        let url = format!(
+            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
+            real_model_id, api_key
+        );
+
+        let http_request = HttpRequest {
+            method: HttpMethod::Post,
+            url,
+            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
+            body: Some(body),
+            redirect_policy: RedirectPolicy::FollowAll,
         };
 
-        let serialized = serde_json::to_value(&part).unwrap();
-
-        assert_eq!(serialized["functionCall"]["name"], "test_function");
-        assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
-        assert_eq!(serialized["thoughtSignature"], "test_signature");
-    }
+        let response_stream = http_request
+            .fetch_stream()
+            .map_err(|e| format!("HTTP request failed: {}", e))?;
 
-    #[test]
-    fn test_function_call_part_without_signature_omits_field() {
-        let part = FunctionCallPart {
-            function_call: FunctionCall {
-                name: "test_function".to_string(),
-                args: json!({"arg": "value"}),
-            },
-            thought_signature: None,
+        let stream_id = {
+            let mut id_counter = self.next_stream_id.lock().unwrap();
+            let id = format!("google-ai-stream-{}", *id_counter);
+            *id_counter += 1;
+            id
         };
 
-        let serialized = serde_json::to_value(&part).unwrap();
+        self.streams.lock().unwrap().insert(
+            stream_id.clone(),
+            StreamState {
+                response_stream: Some(response_stream),
+                buffer: String::new(),
+                started: false,
+                stop_reason: None,
+                wants_tool_use: false,
+            },
+        );
 
-        assert_eq!(serialized["functionCall"]["name"], "test_function");
-        assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
-        // thoughtSignature field should not be present when None
-        assert!(serialized.get("thoughtSignature").is_none());
+        Ok(stream_id)
     }
 
-    #[test]
-    fn test_function_call_part_deserializes_with_signature() {
-        let json = json!({
-            "functionCall": {
-                "name": "test_function",
-                "args": {"arg": "value"}
-            },
-            "thoughtSignature": "test_signature"
-        });
+    fn llm_stream_completion_next(
+        &mut self,
+        stream_id: &str,
+    ) -> Result<Option<LlmCompletionEvent>, String> {
+        let mut streams = self.streams.lock().unwrap();
+        let state = streams
+            .get_mut(stream_id)
+            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
+
+        if !state.started {
+            state.started = true;
+            return Ok(Some(LlmCompletionEvent::Started));
+        }
 
-        let part: FunctionCallPart = serde_json::from_value(json).unwrap();
+        let response_stream = state
+            .response_stream
+            .as_mut()
+            .ok_or_else(|| "Stream already closed".to_string())?;
+
+        loop {
+            if let Some(newline_pos) = state.buffer.find('\n') {
+                let line = state.buffer[..newline_pos].to_string();
+                state.buffer = state.buffer[newline_pos + 1..].to_string();
+
+                if let Some(response) = parse_stream_line(&line) {
+                    for candidate in response.candidates {
+                        if let Some(finish_reason) = &candidate.finish_reason {
+                            state.stop_reason = Some(match finish_reason.as_str() {
+                                "STOP" => {
+                                    if state.wants_tool_use {
+                                        LlmStopReason::ToolUse
+                                    } else {
+                                        LlmStopReason::EndTurn
+                                    }
+                                }
+                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
+                                "SAFETY" => LlmStopReason::Refusal,
+                                _ => LlmStopReason::EndTurn,
+                            });
+                        }
 
-        assert_eq!(part.function_call.name, "test_function");
-        assert_eq!(part.thought_signature, Some("test_signature".to_string()));
-    }
+                        if let Some(content) = candidate.content {
+                            for part in content.parts {
+                                match part {
+                                    GooglePart::Text(text_part) => {
+                                        if !text_part.text.is_empty() {
+                                            return Ok(Some(LlmCompletionEvent::Text(
+                                                text_part.text,
+                                            )));
+                                        }
+                                    }
+                                    GooglePart::FunctionCall(fc_part) => {
+                                        state.wants_tool_use = true;
+                                        let next_tool_id =
+                                            TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
+                                        let id = format!(
+                                            "{}-{}",
+                                            fc_part.function_call.name, next_tool_id
+                                        );
+
+                                        let thought_signature =
+                                            fc_part.thought_signature.filter(|s| !s.is_empty());
+
+                                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
+                                            id,
+                                            name: fc_part.function_call.name,
+                                            input: fc_part.function_call.args.to_string(),
+                                            is_input_complete: true,
+                                            thought_signature,
+                                        })));
+                                    }
+                                    GooglePart::Thought(thought_part) => {
+                                        return Ok(Some(LlmCompletionEvent::Thinking(
+                                            LlmThinkingContent {
+                                                text: "(Encrypted thought)".to_string(),
+                                                signature: Some(thought_part.thought_signature),
+                                            },
+                                        )));
+                                    }
+                                    _ => {}
+                                }
+                            }
+                        }
+                    }
 
-    #[test]
-    fn test_function_call_part_deserializes_without_signature() {
-        let json = json!({
-            "functionCall": {
-                "name": "test_function",
-                "args": {"arg": "value"}
-            }
-        });
+                    if let Some(usage) = response.usage_metadata {
+                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
+                            input_tokens: usage.prompt_token_count,
+                            output_tokens: usage.candidates_token_count,
+                            cache_creation_input_tokens: None,
+                            cache_read_input_tokens: None,
+                        })));
+                    }
+                }
 
-        let part: FunctionCallPart = serde_json::from_value(json).unwrap();
+                continue;
+            }
 
-        assert_eq!(part.function_call.name, "test_function");
-        assert_eq!(part.thought_signature, None);
-    }
+            match response_stream.next_chunk() {
+                Ok(Some(chunk)) => {
+                    let text = String::from_utf8_lossy(&chunk);
+                    state.buffer.push_str(&text);
+                }
+                Ok(None) => {
+                    // Stream ended - check if we have a stop reason
+                    if let Some(stop_reason) = state.stop_reason.take() {
+                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
+                    }
 
-    #[test]
-    fn test_function_call_part_round_trip() {
-        let original = FunctionCallPart {
-            function_call: FunctionCall {
-                name: "test_function".to_string(),
-                args: json!({"arg": "value", "nested": {"key": "val"}}),
-            },
-            thought_signature: Some("round_trip_signature".to_string()),
-        };
+                    // No stop reason - this is unexpected. Check if buffer contains error info
+                    let mut error_msg = String::from("Stream ended unexpectedly.");
 
-        let serialized = serde_json::to_value(&original).unwrap();
-        let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap();
+                    // Try to parse remaining buffer as potential error response
+                    if !state.buffer.is_empty() {
+                        error_msg.push_str(&format!(
+                            "\nRemaining buffer: {}",
+                            &state.buffer[..state.buffer.len().min(1000)]
+                        ));
+                    }
 
-        assert_eq!(deserialized.function_call.name, original.function_call.name);
-        assert_eq!(deserialized.function_call.args, original.function_call.args);
-        assert_eq!(deserialized.thought_signature, original.thought_signature);
+                    return Err(error_msg);
+                }
+                Err(e) => {
+                    return Err(format!("Stream error: {}", e));
+                }
+            }
+        }
     }
 
-    #[test]
-    fn test_function_call_part_with_empty_signature_serializes() {
-        let part = FunctionCallPart {
-            function_call: FunctionCall {
-                name: "test_function".to_string(),
-                args: json!({"arg": "value"}),
-            },
-            thought_signature: Some("".to_string()),
-        };
-
-        let serialized = serde_json::to_value(&part).unwrap();
-
-        // Empty string should still be serialized (normalization happens at a higher level)
-        assert_eq!(serialized["thoughtSignature"], "");
+    fn llm_stream_completion_close(&mut self, stream_id: &str) {
+        self.streams.lock().unwrap().remove(stream_id);
     }
 }
+
+zed::register_extension!(GoogleAiProvider);
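For reference, here is a minimal sketch of how two of the helpers restored by this revert behave: `adapt_schema_for_google`, which rewrites tool input schemas into Google's supported subset, and `parse_stream_line`, which filters SSE framing before JSON decoding. This is not part of the commit; the test module name and sample schema are invented for illustration, and it assumes the tests sit in the same file so the private functions are reachable via `super::*`.

    #[cfg(test)]
    mod revert_behavior_sketch {
        use super::*;
        use serde_json::json;

        // `oneOf` becomes `anyOf`, union types collapse to their first member,
        // and keys outside Google's supported schema subset are dropped.
        #[test]
        fn adapts_schemas_to_googles_subset() {
            let mut schema = json!({
                "type": ["object", "null"],
                "additionalProperties": false, // unsupported key: removed
                "oneOf": [
                    { "type": "string", "minLength": 1 },
                    { "type": "number", "exclusiveMinimum": 0 } // "exclusiveMinimum" removed
                ]
            });

            adapt_schema_for_google(&mut schema);

            assert_eq!(
                schema,
                json!({
                    "type": "object",
                    "anyOf": [
                        { "type": "string", "minLength": 1 },
                        { "type": "number" }
                    ]
                })
            );
        }

        // SSE framing tokens yield no event; `data:` lines parse into responses.
        #[test]
        fn parses_sse_data_lines_and_skips_framing() {
            assert!(parse_stream_line("[").is_none());
            assert!(parse_stream_line(",").is_none());
            assert!(parse_stream_line("").is_none());

            let response = parse_stream_line(r#"data: {"candidates": []}"#)
                .expect("a data: line with a JSON body should parse");
            assert!(response.candidates.is_empty());
        }
    }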