Fix extensions/google-ai/src/google_ai.rs

Created by Richard Feldman and Mikayla Maki

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>

Change summary

extensions/google-ai/src/google_ai.rs | 882 +++++++++++++++++-----------
1 file changed, 545 insertions(+), 337 deletions(-)

Detailed changes

extensions/google-ai/src/google_ai.rs

@@ -1,136 +1,579 @@
+use std::collections::{HashMap, VecDeque};
 
-use std::mem;
-
-use anyhow::{Result, anyhow, bail};
-use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
-pub use settings::ModelMode as GoogleModelMode;
+use zed_extension_api::{
+    self as zed, http_client::HttpMethod, http_client::HttpRequest, llm_get_env_var,
+    LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmMessageContent,
+    LlmMessageRole, LlmModelCapabilities, LlmModelInfo, LlmProviderInfo, LlmStopReason,
+    LlmThinkingContent, LlmTokenUsage, LlmToolInputFormat, LlmToolUse,
+};
 
 pub const API_URL: &str = "https://generativelanguage.googleapis.com";
 
-pub async fn stream_generate_content(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    mut request: GenerateContentRequest,
-) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
-    let api_key = api_key.trim();
-    validate_generate_content_request(&request)?;
-
-    // The `model` field is emptied as it is provided as a path parameter.
-    let model_id = mem::take(&mut request.model.model_id);
-
-    let uri =
-        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
-
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(uri)
-        .header("Content-Type", "application/json");
-
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
-    let mut response = client.send(request).await?;
-    if response.status().is_success() {
-        let reader = BufReader::new(response.into_body());
-        Ok(reader
-            .lines()
-            .filter_map(|line| async move {
-                match line {
-                    Ok(line) => {
-                        if let Some(line) = line.strip_prefix("data: ") {
-                            match serde_json::from_str(line) {
-                                Ok(response) => Some(Ok(response)),
-                                Err(error) => Some(Err(anyhow!(format!(
-                                    "Error parsing JSON: {error:?}\n{line:?}"
-                                )))),
-                            }
-                        } else {
-                            None
-                        }
-                    }
-                    Err(error) => Some(Err(anyhow!(error))),
-                }
-            })
-            .boxed())
-    } else {
-        let mut text = String::new();
-        response.body_mut().read_to_string(&mut text).await?;
-        Err(anyhow!(
-            "error during streamGenerateContent, status code: {:?}, body: {}",
-            response.status(),
-            text
-        ))
-    }
-}
+fn stream_generate_content(
+    model_id: &str,
+    request: &LlmCompletionRequest,
+    streams: &mut HashMap<String, StreamState>,
+    next_stream_id: &mut u64,
+) -> Result<String, String> {
+    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;
 
-pub async fn count_tokens(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    request: CountTokensRequest,
-) -> Result<CountTokensResponse> {
-    validate_generate_content_request(&request.generate_content_request)?;
+    let generate_content_request = build_generate_content_request(model_id, request)?;
+    validate_generate_content_request(&generate_content_request)?;
 
     let uri = format!(
-        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
-        model_id = &request.generate_content_request.model.model_id,
+        "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
+        API_URL, model_id, api_key
+    );
+
+    let body = serde_json::to_vec(&generate_content_request)
+        .map_err(|e| format!("Failed to serialize request: {}", e))?;
+
+    let http_request = HttpRequest::builder()
+        .method(HttpMethod::Post)
+        .url(&uri)
+        .header("Content-Type", "application/json")
+        .body(body)
+        .build()?;
+
+    let response_stream = http_request.fetch_stream()?;
+
+    let stream_id = format!("stream-{}", *next_stream_id);
+    *next_stream_id += 1;
+
+    streams.insert(
+        stream_id.clone(),
+        StreamState {
+            response_stream,
+            buffer: String::new(),
+            usage: None,
+            pending_events: VecDeque::new(),
+        },
     );
 
-    let request = serde_json::to_string(&request)?;
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(&uri)
-        .header("Content-Type", "application/json");
-    let http_request = request_builder.body(AsyncBody::from(request))?;
-
-    let mut response = client.send(http_request).await?;
-    let mut text = String::new();
-    response.body_mut().read_to_string(&mut text).await?;
-    anyhow::ensure!(
-        response.status().is_success(),
-        "error during countTokens, status code: {:?}, body: {}",
-        response.status(),
-        text
+    Ok(stream_id)
+}
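+
+// Each successful call registers a StreamState under a fresh "stream-N" id;
+// Zed polls llm_stream_completion_next with that id until the stream is
+// exhausted, and llm_stream_completion_close drops the entry.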
+
+fn count_tokens(model_id: &str, request: &LlmCompletionRequest) -> Result<u64, String> {
+    let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;
+
+    let generate_content_request = build_generate_content_request(model_id, request)?;
+    validate_generate_content_request(&generate_content_request)?;
+    let count_request = CountTokensRequest {
+        generate_content_request,
+    };
+
+    let uri = format!(
+        "{}/v1beta/models/{}:countTokens?key={}",
+        API_URL, model_id, api_key
     );
-    Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
+
+    let body = serde_json::to_vec(&count_request)
+        .map_err(|e| format!("Failed to serialize request: {}", e))?;
+
+    let http_request = HttpRequest::builder()
+        .method(HttpMethod::Post)
+        .url(&uri)
+        .header("Content-Type", "application/json")
+        .body(body)
+        .build()?;
+
+    let response = http_request.fetch()?;
+    let response_body: CountTokensResponse = serde_json::from_slice(&response.body)
+        .map_err(|e| format!("Failed to parse response: {}", e))?;
+
+    Ok(response_body.total_tokens)
 }
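+
+// For reference, a successful countTokens response body looks roughly like
+// `{"totalTokens": 42}`; illustrative only, the exact shape is whatever
+// CountTokensResponse deserializes.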
 
-pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
+fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<(), String> {
     if request.model.is_empty() {
-        bail!("Model must be specified");
+        return Err("Model must be specified".to_string());
     }
 
     if request.contents.is_empty() {
-        bail!("Request must contain at least one content item");
+        return Err("Request must contain at least one content item".to_string());
     }
 
     if let Some(user_content) = request
         .contents
         .iter()
         .find(|content| content.role == Role::User)
-        && user_content.parts.is_empty()
     {
-        bail!("User content must contain at least one part");
+        if user_content.parts.is_empty() {
+            return Err("User content must contain at least one part".to_string());
+        }
     }
 
     Ok(())
 }
 
-#[derive(Debug, Serialize, Deserialize)]
-pub enum Task {
-    #[serde(rename = "generateContent")]
-    GenerateContent,
-    #[serde(rename = "streamGenerateContent")]
-    StreamGenerateContent,
-    #[serde(rename = "countTokens")]
-    CountTokens,
-    #[serde(rename = "embedContent")]
-    EmbedContent,
-    #[serde(rename = "batchEmbedContents")]
-    BatchEmbedContents,
+// Extension implementation
+
+const PROVIDER_ID: &str = "google-ai";
+const PROVIDER_NAME: &str = "Google AI";
+
+struct GoogleAiExtension {
+    streams: HashMap<String, StreamState>,
+    next_stream_id: u64,
+}
+
+// Per-stream state: the raw HTTP byte stream, a buffer of not-yet-complete
+// SSE lines, the most recent usage metadata, and events already parsed but
+// not yet handed to the host.
+struct StreamState {
+    response_stream: zed::http_client::HttpResponseStream,
+    buffer: String,
+    usage: Option<UsageMetadata>,
+    pending_events: VecDeque<LlmCompletionEvent>,
+}
+
+impl zed::Extension for GoogleAiExtension {
+    fn new() -> Self {
+        Self {
+            streams: HashMap::new(),
+            next_stream_id: 0,
+        }
+    }
+
+    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
+        vec![LlmProviderInfo {
+            id: PROVIDER_ID.to_string(),
+            name: PROVIDER_NAME.to_string(),
+            icon: Some("icons/google-ai.svg".to_string()),
+        }]
+    }
+
+    fn llm_provider_models(&self, provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
+        if provider_id != PROVIDER_ID {
+            return Err(format!("Unknown provider: {}", provider_id));
+        }
+        Ok(get_models())
+    }
+
+    fn llm_provider_settings_markdown(&self, provider_id: &str) -> Option<String> {
+        if provider_id != PROVIDER_ID {
+            return None;
+        }
+
+        Some(
+            r#"## Google AI Setup
+
+To use Google AI models in Zed, you need a Gemini API key.
+
+1. Go to [Google AI Studio](https://aistudio.google.com/apikey)
+2. Create or select a project
+3. Generate an API key
+4. Set the `GEMINI_API_KEY` or `GOOGLE_AI_API_KEY` environment variable
+
+You can set this in your shell profile (e.g. `export GEMINI_API_KEY=...`) or use a `.envrc` file with [direnv](https://direnv.net/).
+"#
+            .to_string(),
+        )
+    }
+
+    fn llm_provider_is_authenticated(&self, provider_id: &str) -> bool {
+        if provider_id != PROVIDER_ID {
+            return false;
+        }
+        get_api_key().is_some()
+    }
+
+    fn llm_provider_reset_credentials(&mut self, provider_id: &str) -> Result<(), String> {
+        if provider_id != PROVIDER_ID {
+            return Err(format!("Unknown provider: {}", provider_id));
+        }
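+        // API keys come from environment variables, so there is nothing to
+        // clear here.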
+        Ok(())
+    }
+
+    fn llm_count_tokens(
+        &self,
+        provider_id: &str,
+        model_id: &str,
+        request: &LlmCompletionRequest,
+    ) -> Result<u64, String> {
+        if provider_id != PROVIDER_ID {
+            return Err(format!("Unknown provider: {}", provider_id));
+        }
+        count_tokens(model_id, request)
+    }
+
+    fn llm_stream_completion_start(
+        &mut self,
+        provider_id: &str,
+        model_id: &str,
+        request: &LlmCompletionRequest,
+    ) -> Result<String, String> {
+        if provider_id != PROVIDER_ID {
+            return Err(format!("Unknown provider: {}", provider_id));
+        }
+        stream_generate_content(model_id, request, &mut self.streams, &mut self.next_stream_id)
+    }
+
+    fn llm_stream_completion_next(
+        &mut self,
+        stream_id: &str,
+    ) -> Result<Option<LlmCompletionEvent>, String> {
+        stream_generate_content_next(stream_id, &mut self.streams)
+    }
+
+    fn llm_stream_completion_close(&mut self, stream_id: &str) {
+        self.streams.remove(stream_id);
+    }
+
+    fn llm_cache_configuration(
+        &self,
+        provider_id: &str,
+        _model_id: &str,
+    ) -> Option<LlmCacheConfiguration> {
+        if provider_id != PROVIDER_ID {
+            return None;
+        }
+
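+        // One cache anchor, no caching of tool definitions, and no caching
+        // at all below 32,768 total tokens.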
+        Some(LlmCacheConfiguration {
+            max_cache_anchors: 1,
+            should_cache_tool_definitions: false,
+            min_total_token_count: 32768,
+        })
+    }
+}
+
+zed::register_extension!(GoogleAiExtension);
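+
+// register_extension! exports the entry points Zed uses to drive this
+// extension; every llm_* method above is invoked through that interface.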
+
+// Helper functions
+
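+// GEMINI_API_KEY takes precedence; GOOGLE_AI_API_KEY is accepted as a
+// fallback name.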
+fn get_api_key() -> Option<String> {
+    llm_get_env_var("GEMINI_API_KEY").or_else(|| llm_get_env_var("GOOGLE_AI_API_KEY"))
+}
+
+fn get_models() -> Vec<LlmModelInfo> {
+    vec![
+        LlmModelInfo {
+            id: "gemini-2.5-flash-lite".to_string(),
+            name: "Gemini 2.5 Flash-Lite".to_string(),
+            max_token_count: 1_048_576,
+            max_output_tokens: Some(65_536),
+            capabilities: LlmModelCapabilities {
+                supports_images: true,
+                supports_tools: true,
+                supports_tool_choice_auto: true,
+                supports_tool_choice_any: true,
+                supports_tool_choice_none: true,
+                supports_thinking: true,
+                tool_input_format: LlmToolInputFormat::JsonSchema,
+            },
+            is_default: false,
+            is_default_fast: true,
+        },
+        LlmModelInfo {
+            id: "gemini-2.5-flash".to_string(),
+            name: "Gemini 2.5 Flash".to_string(),
+            max_token_count: 1_048_576,
+            max_output_tokens: Some(65_536),
+            capabilities: LlmModelCapabilities {
+                supports_images: true,
+                supports_tools: true,
+                supports_tool_choice_auto: true,
+                supports_tool_choice_any: true,
+                supports_tool_choice_none: true,
+                supports_thinking: true,
+                tool_input_format: LlmToolInputFormat::JsonSchema,
+            },
+            is_default: true,
+            is_default_fast: false,
+        },
+        LlmModelInfo {
+            id: "gemini-2.5-pro".to_string(),
+            name: "Gemini 2.5 Pro".to_string(),
+            max_token_count: 1_048_576,
+            max_output_tokens: Some(65_536),
+            capabilities: LlmModelCapabilities {
+                supports_images: true,
+                supports_tools: true,
+                supports_tool_choice_auto: true,
+                supports_tool_choice_any: true,
+                supports_tool_choice_none: true,
+                supports_thinking: true,
+                tool_input_format: LlmToolInputFormat::JsonSchema,
+            },
+            is_default: false,
+            is_default_fast: false,
+        },
+        LlmModelInfo {
+            id: "gemini-3-pro-preview".to_string(),
+            name: "Gemini 3 Pro".to_string(),
+            max_token_count: 1_048_576,
+            max_output_tokens: Some(65_536),
+            capabilities: LlmModelCapabilities {
+                supports_images: true,
+                supports_tools: true,
+                supports_tool_choice_auto: true,
+                supports_tool_choice_any: true,
+                supports_tool_choice_none: true,
+                supports_thinking: true,
+                tool_input_format: LlmToolInputFormat::JsonSchema,
+            },
+            is_default: false,
+            is_default_fast: false,
+        },
+        LlmModelInfo {
+            id: "gemini-3-flash-preview".to_string(),
+            name: "Gemini 3 Flash".to_string(),
+            max_token_count: 1_048_576,
+            max_output_tokens: Some(65_536),
+            capabilities: LlmModelCapabilities {
+                supports_images: true,
+                supports_tools: true,
+                supports_tool_choice_auto: true,
+                supports_tool_choice_any: true,
+                supports_tool_choice_none: true,
+                supports_thinking: true,
+                tool_input_format: LlmToolInputFormat::JsonSchema,
+            },
+            is_default: false,
+            is_default_fast: false,
+        },
+    ]
+}
+
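+// An illustrative SSE line this parser consumes (shape per the Gemini
+// streamGenerateContent API; not a captured response):
+//
+//   data: {"candidates":[{"content":{"parts":[{"text":"Hi"}],"role":"model"},
+//     "finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":7,
+//     "candidatesTokenCount":2}}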
+fn stream_generate_content_next(
+    stream_id: &str,
+    streams: &mut HashMap<String, StreamState>,
+) -> Result<Option<LlmCompletionEvent>, String> {
+    let state = streams
+        .get_mut(stream_id)
+        .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
+
+    loop {
+        // Drain events queued from a previously parsed SSE line first. A
+        // single chunk can carry several parts plus a finish reason, so
+        // parsing queues everything and this loop hands events out one at a
+        // time instead of dropping whatever follows the first part.
+        if let Some(event) = state.pending_events.pop_front() {
+            return Ok(Some(event));
+        }
+
+        if let Some(newline_pos) = state.buffer.find('\n') {
+            let line = state.buffer[..newline_pos].to_string();
+            state.buffer = state.buffer[newline_pos + 1..].to_string();
+
+            if let Some(data) = line.strip_prefix("data: ") {
+                if data.trim().is_empty() {
+                    continue;
+                }
+
+                let response: GenerateContentResponse = serde_json::from_str(data)
+                    .map_err(|e| format!("Failed to parse SSE data: {} - {}", e, data))?;
+
+                if let Some(usage) = response.usage_metadata {
+                    state.usage = Some(usage);
+                }
+
+                if let Some(candidates) = response.candidates {
+                    for candidate in candidates {
+                        for part in candidate.content.parts {
+                            match part {
+                                Part::TextPart(text_part) => {
+                                    state
+                                        .pending_events
+                                        .push_back(LlmCompletionEvent::Text(text_part.text));
+                                }
+                                Part::ThoughtPart(thought_part) => {
+                                    state.pending_events.push_back(
+                                        LlmCompletionEvent::Thinking(LlmThinkingContent {
+                                            text: String::new(),
+                                            signature: Some(thought_part.thought_signature),
+                                        }),
+                                    );
+                                }
+                                Part::FunctionCallPart(fc_part) => {
+                                    // Gemini function calls carry no id of
+                                    // their own, so the name doubles as one.
+                                    state.pending_events.push_back(
+                                        LlmCompletionEvent::ToolUse(LlmToolUse {
+                                            id: fc_part.function_call.name.clone(),
+                                            name: fc_part.function_call.name,
+                                            input: serde_json::to_string(
+                                                &fc_part.function_call.args,
+                                            )
+                                            .unwrap_or_default(),
+                                            is_input_complete: true,
+                                            thought_signature: fc_part.thought_signature,
+                                        }),
+                                    );
+                                }
+                                _ => {}
+                            }
+                        }
+
+                        if let Some(finish_reason) = candidate.finish_reason {
+                            let stop_reason = match finish_reason.as_str() {
+                                "STOP" => LlmStopReason::EndTurn,
+                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
+                                "TOOL_USE" | "FUNCTION_CALL" => LlmStopReason::ToolUse,
+                                "SAFETY" | "RECITATION" | "OTHER" => LlmStopReason::Refusal,
+                                _ => LlmStopReason::EndTurn,
+                            };
+
+                            // Report usage (when the API sent any) before the
+                            // final stop event.
+                            if let Some(usage) = state.usage.take() {
+                                state.pending_events.push_back(LlmCompletionEvent::Usage(
+                                    LlmTokenUsage {
+                                        input_tokens: usage.prompt_token_count.unwrap_or(0),
+                                        output_tokens: usage.candidates_token_count.unwrap_or(0),
+                                        cache_creation_input_tokens: None,
+                                        cache_read_input_tokens: usage.cached_content_token_count,
+                                    },
+                                ));
+                            }
+
+                            state
+                                .pending_events
+                                .push_back(LlmCompletionEvent::Stop(stop_reason));
+                        }
+                    }
+                }
+            }
+
+            continue;
+        }
+
+        match state.response_stream.next_chunk() {
+            Ok(Some(chunk)) => {
+                let chunk_str = String::from_utf8_lossy(&chunk);
+                state.buffer.push_str(&chunk_str);
+            }
+            Ok(None) => {
+                streams.remove(stream_id);
+                return Ok(None);
+            }
+            Err(e) => {
+                streams.remove(stream_id);
+                return Err(e);
+            }
+        }
+    }
+}
+
+fn build_generate_content_request(
+    model_id: &str,
+    request: &LlmCompletionRequest,
+) -> Result<GenerateContentRequest, String> {
+    let mut contents: Vec<Content> = Vec::new();
+    let mut system_instruction: Option<SystemInstruction> = None;
+
+    for message in &request.messages {
+        match message.role {
+            LlmMessageRole::System => {
+                // Merge multiple system messages rather than letting a later
+                // one silently replace an earlier one.
+                let parts = convert_content_to_parts(&message.content)?;
+                system_instruction
+                    .get_or_insert_with(|| SystemInstruction { parts: Vec::new() })
+                    .parts
+                    .extend(parts);
+            }
+            LlmMessageRole::User | LlmMessageRole::Assistant => {
+                let role = match message.role {
+                    LlmMessageRole::User => Role::User,
+                    LlmMessageRole::Assistant => Role::Model,
+                    _ => continue,
+                };
+                let parts = convert_content_to_parts(&message.content)?;
+                contents.push(Content { parts, role });
+            }
+        }
+    }
+
+    let tools = if !request.tools.is_empty() {
+        Some(vec![Tool {
+            function_declarations: request
+                .tools
+                .iter()
+                .map(|t| FunctionDeclaration {
+                    name: t.name.clone(),
+                    description: t.description.clone(),
+                    parameters: serde_json::from_str(&t.input_schema).unwrap_or_default(),
+                })
+                .collect(),
+        }])
+    } else {
+        None
+    };
+
+    let tool_config = request.tool_choice.as_ref().map(|choice| {
+        let mode = match choice {
+            zed::LlmToolChoice::Auto => FunctionCallingMode::Auto,
+            zed::LlmToolChoice::Any => FunctionCallingMode::Any,
+            zed::LlmToolChoice::None => FunctionCallingMode::None,
+        };
+        ToolConfig {
+            function_calling_config: FunctionCallingConfig {
+                mode,
+                allowed_function_names: None,
+            },
+        }
+    });
+
+    let generation_config = Some(GenerationConfig {
+        candidate_count: Some(1),
+        stop_sequences: if request.stop_sequences.is_empty() {
+            None
+        } else {
+            Some(request.stop_sequences.clone())
+        },
+        max_output_tokens: request.max_tokens.map(|t| t as usize),
+        temperature: request.temperature.map(|t| t as f64),
+        top_p: None,
+        top_k: None,
+        thinking_config: if request.thinking_allowed {
+            Some(ThinkingConfig {
+                thinking_budget: 8192,
+            })
+        } else {
+            None
+        },
+    });
+
+    Ok(GenerateContentRequest {
+        model: ModelName {
+            model_id: model_id.to_string(),
+        },
+        contents,
+        system_instruction,
+        generation_config,
+        safety_settings: None,
+        tools,
+        tool_config,
+    })
+}
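+
+// The struct above serializes (rename_all = "camelCase") to roughly:
+//
+//   {"contents":[{"parts":[{"text":"..."}],"role":"user"}],
+//    "systemInstruction":{"parts":[...]},
+//    "generationConfig":{"candidateCount":1,"maxOutputTokens":65536,
+//      "thinkingConfig":{"thinkingBudget":8192}}}
+//
+// Field names follow the Gemini generateContent schema; values are
+// illustrative.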
+
+fn convert_content_to_parts(content: &[LlmMessageContent]) -> Result<Vec<Part>, String> {
+    let mut parts = Vec::new();
+
+    for item in content {
+        match item {
+            LlmMessageContent::Text(text) => {
+                parts.push(Part::TextPart(TextPart { text: text.clone() }));
+            }
+            LlmMessageContent::Image(image) => {
+                parts.push(Part::InlineDataPart(InlineDataPart {
+                    inline_data: GenerativeContentBlob {
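+                        // PNG is assumed for all image content here.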
+                        mime_type: "image/png".to_string(),
+                        data: image.source.clone(),
+                    },
+                }));
+            }
+            LlmMessageContent::ToolUse(tool_use) => {
+                parts.push(Part::FunctionCallPart(FunctionCallPart {
+                    function_call: FunctionCall {
+                        name: tool_use.name.clone(),
+                        args: serde_json::from_str(&tool_use.input).unwrap_or_default(),
+                    },
+                    thought_signature: tool_use.thought_signature.clone(),
+                }));
+            }
+            LlmMessageContent::ToolResult(tool_result) => {
+                let response_value = match &tool_result.content {
+                    zed::LlmToolResultContent::Text(text) => {
+                        serde_json::json!({ "result": text })
+                    }
+                    zed::LlmToolResultContent::Image(_) => {
+                        serde_json::json!({ "error": "Image results not supported" })
+                    }
+                };
+                parts.push(Part::FunctionResponsePart(FunctionResponsePart {
+                    function_response: FunctionResponse {
+                        name: tool_result.tool_name.clone(),
+                        response: response_value,
+                    },
+                }));
+            }
+            LlmMessageContent::Thinking(thinking) => {
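+                // Only the thought signature is round-tripped; the thinking
+                // text itself is not sent back.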
+                if let Some(signature) = &thinking.signature {
+                    parts.push(Part::ThoughtPart(ThoughtPart {
+                        thought: true,
+                        thought_signature: signature.clone(),
+                    }));
+                }
+            }
+            LlmMessageContent::RedactedThinking(_) => {}
+        }
+    }
+
+    Ok(parts)
 }
 
+// Data structures for Google AI API
+
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GenerateContentRequest {
@@ -481,238 +924,3 @@ impl<'de> Deserialize<'de> for ModelName {
         }
     }
 }
-
-#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
-pub enum Model {
-    #[serde(
-        rename = "gemini-2.5-flash-lite",
-        alias = "gemini-2.5-flash-lite-preview-06-17",
-        alias = "gemini-2.0-flash-lite-preview"
-    )]
-    Gemini25FlashLite,
-    #[serde(
-        rename = "gemini-2.5-flash",
-        alias = "gemini-2.0-flash-thinking-exp",
-        alias = "gemini-2.5-flash-preview-04-17",
-        alias = "gemini-2.5-flash-preview-05-20",
-        alias = "gemini-2.5-flash-preview-latest",
-        alias = "gemini-2.0-flash"
-    )]
-    #[default]
-    Gemini25Flash,
-    #[serde(
-        rename = "gemini-2.5-pro",
-        alias = "gemini-2.0-pro-exp",
-        alias = "gemini-2.5-pro-preview-latest",
-        alias = "gemini-2.5-pro-exp-03-25",
-        alias = "gemini-2.5-pro-preview-03-25",
-        alias = "gemini-2.5-pro-preview-05-06",
-        alias = "gemini-2.5-pro-preview-06-05"
-    )]
-    Gemini25Pro,
-    #[serde(rename = "gemini-3-pro-preview")]
-    Gemini3Pro,
-    #[serde(rename = "custom")]
-    Custom {
-        name: String,
-        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
-        display_name: Option<String>,
-        max_tokens: u64,
-        #[serde(default)]
-        mode: GoogleModelMode,
-    },
-}
-
-impl Model {
-    pub fn default_fast() -> Self {
-        Self::Gemini25FlashLite
-    }
-
-    pub fn id(&self) -> &str {
-        match self {
-            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
-            Self::Gemini25Flash => "gemini-2.5-flash",
-            Self::Gemini25Pro => "gemini-2.5-pro",
-            Self::Gemini3Pro => "gemini-3-pro-preview",
-            Self::Custom { name, .. } => name,
-        }
-    }
-    pub fn request_id(&self) -> &str {
-        match self {
-            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
-            Self::Gemini25Flash => "gemini-2.5-flash",
-            Self::Gemini25Pro => "gemini-2.5-pro",
-            Self::Gemini3Pro => "gemini-3-pro-preview",
-            Self::Custom { name, .. } => name,
-        }
-    }
-
-    pub fn display_name(&self) -> &str {
-        match self {
-            Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite",
-            Self::Gemini25Flash => "Gemini 2.5 Flash",
-            Self::Gemini25Pro => "Gemini 2.5 Pro",
-            Self::Gemini3Pro => "Gemini 3 Pro",
-            Self::Custom {
-                name, display_name, ..
-            } => display_name.as_ref().unwrap_or(name),
-        }
-    }
-
-    pub fn max_token_count(&self) -> u64 {
-        match self {
-            Self::Gemini25FlashLite => 1_048_576,
-            Self::Gemini25Flash => 1_048_576,
-            Self::Gemini25Pro => 1_048_576,
-            Self::Gemini3Pro => 1_048_576,
-            Self::Custom { max_tokens, .. } => *max_tokens,
-        }
-    }
-
-    pub fn max_output_tokens(&self) -> Option<u64> {
-        match self {
-            Model::Gemini25FlashLite => Some(65_536),
-            Model::Gemini25Flash => Some(65_536),
-            Model::Gemini25Pro => Some(65_536),
-            Model::Gemini3Pro => Some(65_536),
-            Model::Custom { .. } => None,
-        }
-    }
-
-    pub fn supports_tools(&self) -> bool {
-        true
-    }
-
-    pub fn supports_images(&self) -> bool {
-        true
-    }
-
-    pub fn mode(&self) -> GoogleModelMode {
-        match self {
-            Self::Gemini25FlashLite
-            | Self::Gemini25Flash
-            | Self::Gemini25Pro
-            | Self::Gemini3Pro => {
-                GoogleModelMode::Thinking {
-                    // By default these models are set to "auto", so we preserve that behavior
-                    // but indicate they are capable of thinking mode
-                    budget_tokens: None,
-                }
-            }
-            Self::Custom { mode, .. } => *mode,
-        }
-    }
-}
-
-impl std::fmt::Display for Model {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.id())
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use serde_json::json;
-
-    #[test]
-    fn test_function_call_part_with_signature_serializes_correctly() {
-        let part = FunctionCallPart {
-            function_call: FunctionCall {
-                name: "test_function".to_string(),
-                args: json!({"arg": "value"}),
-            },
-            thought_signature: Some("test_signature".to_string()),
-        };
-
-        let serialized = serde_json::to_value(&part).unwrap();
-
-        assert_eq!(serialized["functionCall"]["name"], "test_function");
-        assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
-        assert_eq!(serialized["thoughtSignature"], "test_signature");
-    }
-
-    #[test]
-    fn test_function_call_part_without_signature_omits_field() {
-        let part = FunctionCallPart {
-            function_call: FunctionCall {
-                name: "test_function".to_string(),
-                args: json!({"arg": "value"}),
-            },
-            thought_signature: None,
-        };
-
-        let serialized = serde_json::to_value(&part).unwrap();
-
-        assert_eq!(serialized["functionCall"]["name"], "test_function");
-        assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
-        // thoughtSignature field should not be present when None
-        assert!(serialized.get("thoughtSignature").is_none());
-    }
-
-    #[test]
-    fn test_function_call_part_deserializes_with_signature() {
-        let json = json!({
-            "functionCall": {
-                "name": "test_function",
-                "args": {"arg": "value"}
-            },
-            "thoughtSignature": "test_signature"
-        });
-
-        let part: FunctionCallPart = serde_json::from_value(json).unwrap();
-
-        assert_eq!(part.function_call.name, "test_function");
-        assert_eq!(part.thought_signature, Some("test_signature".to_string()));
-    }
-
-    #[test]
-    fn test_function_call_part_deserializes_without_signature() {
-        let json = json!({
-            "functionCall": {
-                "name": "test_function",
-                "args": {"arg": "value"}
-            }
-        });
-
-        let part: FunctionCallPart = serde_json::from_value(json).unwrap();
-
-        assert_eq!(part.function_call.name, "test_function");
-        assert_eq!(part.thought_signature, None);
-    }
-
-    #[test]
-    fn test_function_call_part_round_trip() {
-        let original = FunctionCallPart {
-            function_call: FunctionCall {
-                name: "test_function".to_string(),
-                args: json!({"arg": "value", "nested": {"key": "val"}}),
-            },
-            thought_signature: Some("round_trip_signature".to_string()),
-        };
-
-        let serialized = serde_json::to_value(&original).unwrap();
-        let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap();
-
-        assert_eq!(deserialized.function_call.name, original.function_call.name);
-        assert_eq!(deserialized.function_call.args, original.function_call.args);
-        assert_eq!(deserialized.thought_signature, original.thought_signature);
-    }
-
-    #[test]
-    fn test_function_call_part_with_empty_signature_serializes() {
-        let part = FunctionCallPart {
-            function_call: FunctionCall {
-                name: "test_function".to_string(),
-                args: json!({"arg": "value"}),
-            },
-            thought_signature: Some("".to_string()),
-        };
-
-        let serialized = serde_json::to_value(&part).unwrap();
-
-        // Empty string should still be serialized (normalization happens at a higher level)
-        assert_eq!(serialized["thoughtSignature"], "");
-    }
-}