From 6a07fe4e996e73fd4a07d4fc42d64ce7bda85854 Mon Sep 17 00:00:00 2001
From: Richard Feldman
Date: Wed, 17 Dec 2025 13:40:30 -0500
Subject: [PATCH] Revert "Replace extensions google_ai with the hardcoded one."

This reverts commit 6f05a4b6dfe9a58c0f62fe437dfdc8fddbdf0065.
---
 extensions/google-ai/src/google_ai.rs | 1253 +++++++++++++------------
 1 file changed, 667 insertions(+), 586 deletions(-)

diff --git a/extensions/google-ai/src/google_ai.rs b/extensions/google-ai/src/google_ai.rs
index 3eff860e16f15fae76d8f9cb2523d2b91b611125..846ad348b18bfd329c563be965a26419f8103f8b 100644
--- a/extensions/google-ai/src/google_ai.rs
+++ b/extensions/google-ai/src/google_ai.rs
@@ -1,717 +1,798 @@
-use std::mem;
-
-use anyhow::{Result, anyhow, bail};
-use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
-pub use settings::ModelMode as GoogleModelMode;
-
-pub const API_URL: &str = "https://generativelanguage.googleapis.com";
-
-pub async fn stream_generate_content(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    mut request: GenerateContentRequest,
-) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
-    let api_key = api_key.trim();
-    validate_generate_content_request(&request)?;
-
-    // The `model` field is emptied as it is provided as a path parameter.
-    let model_id = mem::take(&mut request.model.model_id);
-
-    let uri =
-        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
-
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(uri)
-        .header("Content-Type", "application/json");
-
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
-    let mut response = client.send(request).await?;
-    if response.status().is_success() {
-        let reader = BufReader::new(response.into_body());
-        Ok(reader
-            .lines()
-            .filter_map(|line| async move {
-                match line {
-                    Ok(line) => {
-                        if let Some(line) = line.strip_prefix("data: ") {
-                            match serde_json::from_str(line) {
-                                Ok(response) => Some(Ok(response)),
-                                Err(error) => Some(Err(anyhow!(format!(
-                                    "Error parsing JSON: {error:?}\n{line:?}"
-                                )))),
-                            }
-                        } else {
-                            None
-                        }
-                    }
-                    Err(error) => Some(Err(anyhow!(error))),
-                }
-            })
-            .boxed())
-    } else {
-        let mut text = String::new();
-        response.body_mut().read_to_string(&mut text).await?;
-        Err(anyhow!(
-            "error during streamGenerateContent, status code: {:?}, body: {}",
-            response.status(),
-            text
-        ))
-    }
-}
+use std::collections::HashMap;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::Mutex;
+
+use serde::{Deserialize, Serialize};
+use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
+use zed_extension_api::{self as zed, *};
+
+static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+
+struct GoogleAiProvider {
+    streams: Mutex<HashMap<String, StreamState>>,
+    next_stream_id: Mutex<u64>,
+}
+
+struct StreamState {
+    response_stream: Option<HttpResponseStream>,
+    buffer: String,
+    started: bool,
+    stop_reason: Option<LlmStopReason>,
+    wants_tool_use: bool,
+}
+
+struct ModelDefinition {
+    real_id: &'static str,
+    display_name: &'static str,
+    max_tokens: u64,
+    max_output_tokens: Option<u64>,
+    supports_images: bool,
+    supports_thinking: bool,
+    is_default: bool,
+    is_default_fast: bool,
+}
+
+const MODELS: &[ModelDefinition] = &[
+    ModelDefinition {
+        real_id: "gemini-2.5-flash-lite",
+        display_name: "Gemini 2.5 Flash-Lite",
+        max_tokens: 1_048_576,
+        max_output_tokens: Some(65_536),
+        supports_images: true,
+        supports_thinking: true,
+        is_default: false,
+        is_default_fast: true,
+    },
+    ModelDefinition {
+        real_id: "gemini-2.5-flash",
+        display_name: "Gemini 2.5 Flash",
+        max_tokens: 1_048_576,
+        max_output_tokens: Some(65_536),
+        supports_images: true,
+        supports_thinking: true,
+        is_default: true,
+        is_default_fast: false,
+    },
+    ModelDefinition {
+        real_id: "gemini-2.5-pro",
+        display_name: "Gemini 2.5 Pro",
+        max_tokens: 1_048_576,
+        max_output_tokens: Some(65_536),
+        supports_images: true,
+        supports_thinking: true,
+        is_default: false,
+        is_default_fast: false,
+    },
+    ModelDefinition {
+        real_id: "gemini-3-pro-preview",
+        display_name: "Gemini 3 Pro",
+        max_tokens: 1_048_576,
+        max_output_tokens: Some(65_536),
+        supports_images: true,
+        supports_thinking: true,
+        is_default: false,
+        is_default_fast: false,
+    },
+];
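+
+// The model ids surfaced to the host are the display names above: e.g.
+// get_real_model_id("Gemini 2.5 Flash") resolves to "gemini-2.5-flash",
+// which llm_stream_completion_start interpolates into the request URL.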

-pub async fn count_tokens(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    request: CountTokensRequest,
-) -> Result<CountTokensResponse> {
-    validate_generate_content_request(&request.generate_content_request)?;
-
-    let uri = format!(
-        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
-        model_id = &request.generate_content_request.model.model_id,
-    );
-
-    let request = serde_json::to_string(&request)?;
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(&uri)
-        .header("Content-Type", "application/json");
-    let http_request = request_builder.body(AsyncBody::from(request))?;
-
-    let mut response = client.send(http_request).await?;
-    let mut text = String::new();
-    response.body_mut().read_to_string(&mut text).await?;
-    anyhow::ensure!(
-        response.status().is_success(),
-        "error during countTokens, status code: {:?}, body: {}",
-        response.status(),
-        text
-    );
-    Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
+fn get_real_model_id(display_name: &str) -> Option<&'static str> {
+    MODELS
+        .iter()
+        .find(|m| m.display_name == display_name)
+        .map(|m| m.real_id)
 }

-pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
-    if request.model.is_empty() {
-        bail!("Model must be specified");
-    }
-
-    if request.contents.is_empty() {
-        bail!("Request must contain at least one content item");
-    }
-
-    if let Some(user_content) = request
-        .contents
+fn get_model_supports_thinking(display_name: &str) -> bool {
+    MODELS
         .iter()
-        .find(|content| content.role == Role::User)
-        && user_content.parts.is_empty()
-    {
-        bail!("User content must contain at least one part");
-    }
-
-    Ok(())
-}
+        .find(|m| m.display_name == display_name)
+        .map(|m| m.supports_thinking)
+        .unwrap_or(false)
+}
+
+/// Adapts a JSON schema to be compatible with Google's API subset.
+/// Google only supports a specific subset of JSON Schema fields.
+/// See: https://ai.google.dev/api/caching#Schema
+fn adapt_schema_for_google(json: &mut serde_json::Value) {
+    adapt_schema_for_google_impl(json, true);
+}
+
+fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) {
+    if let serde_json::Value::Object(obj) = json {
+        // Google's Schema only supports these fields:
+        // type, format, title, description, nullable, enum, maxItems, minItems,
+        // properties, required, minProperties, maxProperties, minLength, maxLength,
+        // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
+        const ALLOWED_KEYS: &[&str] = &[
+            "type",
+            "format",
+            "title",
+            "description",
+            "nullable",
+            "enum",
+            "maxItems",
+            "minItems",
+            "properties",
+            "required",
+            "minProperties",
+            "maxProperties",
+            "minLength",
+            "maxLength",
+            "pattern",
+            "example",
+            "anyOf",
+            "propertyOrdering",
+            "default",
+            "items",
+            "minimum",
+            "maximum",
+        ];
+
+        // Convert oneOf to anyOf before filtering keys
+        if let Some(one_of) = obj.remove("oneOf") {
+            obj.insert("anyOf".to_string(), one_of);
+        }

-#[derive(Debug, Serialize, Deserialize)]
-pub enum Task {
-    #[serde(rename = "generateContent")]
-    GenerateContent,
-    #[serde(rename = "streamGenerateContent")]
-    StreamGenerateContent,
-    #[serde(rename = "countTokens")]
-    CountTokens,
-    #[serde(rename = "embedContent")]
-    EmbedContent,
-    #[serde(rename = "batchEmbedContents")]
-    BatchEmbedContents,
-}
+        // If type is an array (e.g., ["string", "null"]), take just the first type
+        if let Some(type_field) = obj.get_mut("type") {
+            if let serde_json::Value::Array(types) = type_field {
+                if let Some(first_type) = types.first().cloned() {
+                    *type_field = first_type;
+                }
+            }
+        }

-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct GenerateContentRequest {
-    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
-    pub model: ModelName,
-    pub contents: Vec<Content>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub system_instruction: Option<SystemInstruction>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub generation_config: Option<GenerationConfig>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub safety_settings: Option<Vec<SafetySetting>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub tools: Option<Vec<Tool>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub tool_config: Option<ToolConfig>,
-}
+        // Only filter keys if this is a schema object, not a properties map
+        if is_schema {
+            obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));
+        }

-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct GenerateContentResponse {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub candidates: Option<Vec<GenerateContentCandidate>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub prompt_feedback: Option<PromptFeedback>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub usage_metadata: Option<UsageMetadata>,
+        // Recursively process nested values
+        // "properties" contains a map of property names -> schemas
+        // "items" and "anyOf" contain schemas directly
+        for (key, value) in obj.iter_mut() {
+            if key == "properties" {
+                // properties is a map of property_name -> schema
+                if let serde_json::Value::Object(props) = value {
+                    for (_, prop_schema) in props.iter_mut() {
+                        adapt_schema_for_google_impl(prop_schema, true);
+                    }
+                }
+            } else if key == "items" {
+                // items is a schema
+                adapt_schema_for_google_impl(value, true);
+            } else if key == "anyOf" {
+                // anyOf is an array of schemas
+                if let serde_json::Value::Array(arr) = value {
+                    for item in arr.iter_mut() {
+                        adapt_schema_for_google_impl(item, true);
+                    }
+                }
+            }
+        }
+    } else if let serde_json::Value::Array(arr) = json {
+        for item in arr.iter_mut() {
+            adapt_schema_for_google_impl(item, true);
+        }
+    }
 }
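+
+// For illustration (hypothetical input, not taken from the Google docs):
+//   {"type": ["string", "null"], "oneOf": [...], "additionalProperties": false}
+// is adapted to
+//   {"type": "string", "anyOf": [...]}
+// "oneOf" is renamed to "anyOf", the type array collapses to its first entry,
+// and "additionalProperties" is dropped because it is not in ALLOWED_KEYS.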

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct GenerateContentCandidate {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub index: Option<usize>,
-    pub content: Content,
+struct GoogleRequest {
+    contents: Vec<GoogleContent>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub finish_reason: Option<String>,
+    system_instruction: Option<GoogleSystemInstruction>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub finish_message: Option<String>,
+    generation_config: Option<GoogleGenerationConfig>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub safety_ratings: Option<Vec<SafetyRating>>,
+    tools: Option<Vec<GoogleTool>>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub citation_metadata: Option<CitationMetadata>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct Content {
-    #[serde(default)]
-    pub parts: Vec<Part>,
-    pub role: Role,
+    tool_config: Option<GoogleToolConfig>,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct SystemInstruction {
-    pub parts: Vec<Part>,
+struct GoogleSystemInstruction {
+    parts: Vec<GooglePart>,
 }

-#[derive(Debug, PartialEq, Deserialize, Serialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub enum Role {
-    User,
-    Model,
+struct GoogleContent {
+    parts: Vec<GooglePart>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    role: Option<String>,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(untagged)]
-pub enum Part {
-    TextPart(TextPart),
-    InlineDataPart(InlineDataPart),
-    FunctionCallPart(FunctionCallPart),
-    FunctionResponsePart(FunctionResponsePart),
-    ThoughtPart(ThoughtPart),
+enum GooglePart {
+    Text(GoogleTextPart),
+    InlineData(GoogleInlineDataPart),
+    FunctionCall(GoogleFunctionCallPart),
+    FunctionResponse(GoogleFunctionResponsePart),
+    Thought(GoogleThoughtPart),
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct TextPart {
-    pub text: String,
+struct GoogleTextPart {
+    text: String,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct InlineDataPart {
-    pub inline_data: GenerativeContentBlob,
+struct GoogleInlineDataPart {
+    inline_data: GoogleBlob,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct GenerativeContentBlob {
-    pub mime_type: String,
-    pub data: String,
+struct GoogleBlob {
+    mime_type: String,
+    data: String,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct FunctionCallPart {
-    pub function_call: FunctionCall,
-    /// Thought signature returned by the model for function calls.
-    /// Only present on the first function call in parallel call scenarios.
+struct GoogleFunctionCallPart {
+    function_call: GoogleFunctionCall,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub thought_signature: Option<String>,
+    thought_signature: Option<String>,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct FunctionResponsePart {
-    pub function_response: FunctionResponse,
+struct GoogleFunctionCall {
+    name: String,
+    args: serde_json::Value,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct ThoughtPart {
-    pub thought: bool,
-    pub thought_signature: String,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct CitationSource {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub start_index: Option<usize>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub end_index: Option<usize>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub uri: Option<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub license: Option<String>,
+struct GoogleFunctionResponsePart {
+    function_response: GoogleFunctionResponse,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct CitationMetadata {
-    pub citation_sources: Vec<CitationSource>,
+struct GoogleFunctionResponse {
+    name: String,
+    response: serde_json::Value,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
-pub struct PromptFeedback {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub block_reason: Option<String>,
-    pub safety_ratings: Option<Vec<SafetyRating>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub block_reason_message: Option<String>,
+struct GoogleThoughtPart {
+    thought: bool,
+    thought_signature: String,
 }

-#[derive(Debug, Serialize, Deserialize, Default)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct UsageMetadata {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub prompt_token_count: Option<u64>,
+struct GoogleGenerationConfig {
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub cached_content_token_count: Option<u64>,
+    candidate_count: Option<u32>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub candidates_token_count: Option<u64>,
+    stop_sequences: Option<Vec<String>>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub tool_use_prompt_token_count: Option<u64>,
+    max_output_tokens: Option<u64>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub thoughts_token_count: Option<u64>,
+    temperature: Option<f64>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub total_token_count: Option<u64>,
+    thinking_config: Option<GoogleThinkingConfig>,
 }

-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct ThinkingConfig {
-    pub thinking_budget: u32,
-}
-
-#[derive(Debug, Deserialize, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub struct GenerationConfig {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub candidate_count: Option<usize>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub stop_sequences: Option<Vec<String>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub max_output_tokens: Option<u64>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub temperature: Option<f64>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_p: Option<f64>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_k: Option<u64>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub thinking_config: Option<ThinkingConfig>,
+#[derive(Serialize)]
+#[serde(rename_all = "camelCase")]
+struct GoogleThinkingConfig {
+    thinking_budget: u32,
 }
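+
+// With serde's camelCase renaming, the budget set in convert_request
+// serializes as {"generationConfig": {"thinkingConfig": {"thinkingBudget": 8192}}}.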

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct SafetySetting {
-    pub category: HarmCategory,
-    pub threshold: HarmBlockThreshold,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub enum HarmCategory {
-    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
-    Unspecified,
-    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
-    Derogatory,
-    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
-    Toxicity,
-    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
-    Violence,
-    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
-    Sexual,
-    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
-    Medical,
-    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
-    Dangerous,
-    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
-    Harassment,
-    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
-    HateSpeech,
-    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
-    SexuallyExplicit,
-    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
-    DangerousContent,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
-pub enum HarmBlockThreshold {
-    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
-    Unspecified,
-    BlockLowAndAbove,
-    BlockMediumAndAbove,
-    BlockOnlyHigh,
-    BlockNone,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
-pub enum HarmProbability {
-    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
-    Unspecified,
-    Negligible,
-    Low,
-    Medium,
-    High,
+struct GoogleTool {
+    function_declarations: Vec<GoogleFunctionDeclaration>,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct SafetyRating {
-    pub category: HarmCategory,
-    pub probability: HarmProbability,
+struct GoogleFunctionDeclaration {
+    name: String,
+    description: String,
+    parameters: serde_json::Value,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct CountTokensRequest {
-    pub generate_content_request: GenerateContentRequest,
+struct GoogleToolConfig {
+    function_calling_config: GoogleFunctionCallingConfig,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
-pub struct CountTokensResponse {
-    pub total_tokens: u64,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct FunctionCall {
-    pub name: String,
-    pub args: serde_json::Value,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct FunctionResponse {
-    pub name: String,
-    pub response: serde_json::Value,
+struct GoogleFunctionCallingConfig {
+    mode: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    allowed_function_names: Option<Vec<String>>,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Deserialize, Debug)]
 #[serde(rename_all = "camelCase")]
-pub struct Tool {
-    pub function_declarations: Vec<FunctionDeclaration>,
+struct GoogleStreamResponse {
+    #[serde(default)]
+    candidates: Vec<GoogleCandidate>,
+    #[serde(default)]
+    usage_metadata: Option<GoogleUsageMetadata>,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Deserialize, Debug)]
 #[serde(rename_all = "camelCase")]
-pub struct ToolConfig {
-    pub function_calling_config: FunctionCallingConfig,
+struct GoogleCandidate {
+    #[serde(default)]
+    content: Option<GoogleContent>,
+    #[serde(default)]
+    finish_reason: Option<String>,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Deserialize, Debug)]
 #[serde(rename_all = "camelCase")]
-pub struct FunctionCallingConfig {
-    pub mode: FunctionCallingMode,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub allowed_function_names: Option<Vec<String>>,
+struct GoogleUsageMetadata {
+    #[serde(default)]
+    prompt_token_count: u64,
+    #[serde(default)]
+    candidates_token_count: u64,
 }
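+
+// Sketch of the wire format these types produce for a hypothetical tool
+// named "read_file" taking a single string argument:
+//   {"functionDeclarations": [{"name": "read_file", "description": "...",
+//     "parameters": {"type": "object",
+//                    "properties": {"path": {"type": "string"}}}}]}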

-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "lowercase")]
-pub enum FunctionCallingMode {
-    Auto,
-    Any,
-    None,
-}
+fn convert_request(
+    model_id: &str,
+    request: &LlmCompletionRequest,
+) -> Result<(GoogleRequest, String), String> {
+    let real_model_id =
+        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;

-#[derive(Debug, Serialize, Deserialize)]
-pub struct FunctionDeclaration {
-    pub name: String,
-    pub description: String,
-    pub parameters: serde_json::Value,
-}
+    let supports_thinking = get_model_supports_thinking(model_id);

-#[derive(Debug, Default)]
-pub struct ModelName {
-    pub model_id: String,
-}
+    let mut contents: Vec<GoogleContent> = Vec::new();
+    let mut system_parts: Vec<GooglePart> = Vec::new();

-impl ModelName {
-    pub fn is_empty(&self) -> bool {
-        self.model_id.is_empty()
-    }
-}
+    for msg in &request.messages {
+        match msg.role {
+            LlmMessageRole::System => {
+                for content in &msg.content {
+                    if let LlmMessageContent::Text(text) = content {
+                        if !text.is_empty() {
+                            system_parts
+                                .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
+                        }
+                    }
+                }
+            }
+            LlmMessageRole::User => {
+                let mut parts: Vec<GooglePart> = Vec::new();
+
+                for content in &msg.content {
+                    match content {
+                        LlmMessageContent::Text(text) => {
+                            if !text.is_empty() {
+                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
+                            }
+                        }
+                        LlmMessageContent::Image(img) => {
+                            parts.push(GooglePart::InlineData(GoogleInlineDataPart {
+                                inline_data: GoogleBlob {
+                                    mime_type: "image/png".to_string(),
+                                    data: img.source.clone(),
+                                },
+                            }));
+                        }
+                        LlmMessageContent::ToolResult(result) => {
+                            let response_value = match &result.content {
+                                LlmToolResultContent::Text(t) => {
+                                    serde_json::json!({ "output": t })
+                                }
+                                LlmToolResultContent::Image(_) => {
+                                    serde_json::json!({ "output": "Tool responded with an image" })
+                                }
+                            };
+                            parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
+                                function_response: GoogleFunctionResponse {
+                                    name: result.tool_name.clone(),
+                                    response: response_value,
+                                },
+                            }));
+                        }
+                        _ => {}
+                    }
+                }

-const MODEL_NAME_PREFIX: &str = "models/";
+                if !parts.is_empty() {
+                    contents.push(GoogleContent {
+                        parts,
+                        role: Some("user".to_string()),
+                    });
+                }
+            }
+            LlmMessageRole::Assistant => {
+                let mut parts: Vec<GooglePart> = Vec::new();
+
+                for content in &msg.content {
+                    match content {
+                        LlmMessageContent::Text(text) => {
+                            if !text.is_empty() {
+                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
+                            }
+                        }
+                        LlmMessageContent::ToolUse(tool_use) => {
+                            let thought_signature =
+                                tool_use.thought_signature.clone().filter(|s| !s.is_empty());
+
+                            let args: serde_json::Value =
+                                serde_json::from_str(&tool_use.input).unwrap_or_default();
+
+                            parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
+                                function_call: GoogleFunctionCall {
+                                    name: tool_use.name.clone(),
+                                    args,
+                                },
+                                thought_signature,
+                            }));
+                        }
+                        LlmMessageContent::Thinking(thinking) => {
+                            if let Some(ref signature) = thinking.signature {
+                                if !signature.is_empty() {
+                                    parts.push(GooglePart::Thought(GoogleThoughtPart {
+                                        thought: true,
+                                        thought_signature: signature.clone(),
+                                    }));
+                                }
+                            }
+                        }
+                        _ => {}
+                    }
+                }

-impl Serialize for ModelName {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
+                if !parts.is_empty() {
+                    contents.push(GoogleContent {
+                        parts,
+                        role: Some("model".to_string()),
+                    });
+                }
+            }
+        }
     }
-}

-impl<'de> Deserialize<'de> for ModelName {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let string = String::deserialize(deserializer)?;
-        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
-            Ok(Self {
-                model_id: id.to_string(),
+    let system_instruction = if system_parts.is_empty() {
+        None
+    } else {
+        Some(GoogleSystemInstruction {
+            parts: system_parts,
+        })
+    };
+
+    let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
+        None
+    } else {
+        let declarations: Vec<GoogleFunctionDeclaration> = request
+            .tools
+            .iter()
+            .map(|t| {
+                let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
+                    .unwrap_or(serde_json::Value::Object(Default::default()));
+                adapt_schema_for_google(&mut parameters);
+                GoogleFunctionDeclaration {
+                    name: t.name.clone(),
+                    description: t.description.clone(),
+                    parameters,
+                }
             })
-        } else {
-            Err(serde::de::Error::custom(format!(
-                "Expected model name to begin with {}, got: {}",
-                MODEL_NAME_PREFIX, string
-            )))
+            .collect();
+        Some(vec![GoogleTool {
+            function_declarations: declarations,
+        }])
+    };
+
+    let tool_config = request.tool_choice.as_ref().map(|tc| {
+        let mode = match tc {
+            LlmToolChoice::Auto => "AUTO",
+            LlmToolChoice::Any => "ANY",
+            LlmToolChoice::None => "NONE",
+        };
+        GoogleToolConfig {
+            function_calling_config: GoogleFunctionCallingConfig {
+                mode: mode.to_string(),
+                allowed_function_names: None,
+            },
         }
-    }
-}
+    });

-#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
-pub enum Model {
-    #[serde(
-        rename = "gemini-2.5-flash-lite",
-        alias = "gemini-2.5-flash-lite-preview-06-17",
-        alias = "gemini-2.0-flash-lite-preview"
-    )]
-    Gemini25FlashLite,
-    #[serde(
-        rename = "gemini-2.5-flash",
-        alias = "gemini-2.0-flash-thinking-exp",
-        alias = "gemini-2.5-flash-preview-04-17",
-        alias = "gemini-2.5-flash-preview-05-20",
-        alias = "gemini-2.5-flash-preview-latest",
-        alias = "gemini-2.0-flash"
-    )]
-    #[default]
-    Gemini25Flash,
-    #[serde(
-        rename = "gemini-2.5-pro",
-        alias = "gemini-2.0-pro-exp",
-        alias = "gemini-2.5-pro-preview-latest",
-        alias = "gemini-2.5-pro-exp-03-25",
-        alias = "gemini-2.5-pro-preview-03-25",
-        alias = "gemini-2.5-pro-preview-05-06",
-        alias = "gemini-2.5-pro-preview-06-05"
-    )]
-    Gemini25Pro,
-    #[serde(rename = "gemini-3-pro-preview")]
-    Gemini3Pro,
-    #[serde(rename = "custom")]
-    Custom {
-        name: String,
-        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
-        display_name: Option<String>,
-        max_tokens: u64,
-        #[serde(default)]
-        mode: GoogleModelMode,
-    },
-}
+    let thinking_config = if supports_thinking && request.thinking_allowed {
+        Some(GoogleThinkingConfig {
+            thinking_budget: 8192,
+        })
+    } else {
+        None
+    };

-impl Model {
-    pub fn default_fast() -> Self {
-        Self::Gemini25FlashLite
+    let generation_config = Some(GoogleGenerationConfig {
+        candidate_count: Some(1),
+        stop_sequences: if request.stop_sequences.is_empty() {
+            None
+        } else {
+            Some(request.stop_sequences.clone())
+        },
+        max_output_tokens: None,
+        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
+        thinking_config,
+    });
+
+    Ok((
+        GoogleRequest {
+            contents,
+            system_instruction,
+            generation_config,
+            tools,
+            tool_config,
+        },
+        real_model_id.to_string(),
+    ))
+}
     }
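+
+// Sketch of the mapping convert_request performs: a conversation of
+//   [System "Be brief", User "Hi"]
+// becomes
+//   "systemInstruction": {"parts": [{"text": "Be brief"}]}
+//   "contents": [{"role": "user", "parts": [{"text": "Hi"}]}]
+// and assistant turns are sent back with role "model", as Google's API expects.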
-    pub fn id(&self) -> &str {
-        match self {
-            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
-            Self::Gemini25Flash => "gemini-2.5-flash",
-            Self::Gemini25Pro => "gemini-2.5-pro",
-            Self::Gemini3Pro => "gemini-3-pro-preview",
-            Self::Custom { name, .. } => name,
-        }
-    }
+fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
+    let trimmed = line.trim();
+    if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," {
+        return None;
     }

-    pub fn request_id(&self) -> &str {
-        match self {
-            Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
-            Self::Gemini25Flash => "gemini-2.5-flash",
-            Self::Gemini25Pro => "gemini-2.5-pro",
-            Self::Gemini3Pro => "gemini-3-pro-preview",
-            Self::Custom { name, .. } => name,
-        }
-    }
+    let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed);
+    let json_str = json_str.trim_start_matches(',').trim();

-    pub fn display_name(&self) -> &str {
-        match self {
-            Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite",
-            Self::Gemini25Flash => "Gemini 2.5 Flash",
-            Self::Gemini25Pro => "Gemini 2.5 Pro",
-            Self::Gemini3Pro => "Gemini 3 Pro",
-            Self::Custom {
-                name, display_name, ..
-            } => display_name.as_ref().unwrap_or(name),
-        }
+    if json_str.is_empty() {
+        return None;
     }

-    pub fn max_token_count(&self) -> u64 {
-        match self {
-            Self::Gemini25FlashLite => 1_048_576,
-            Self::Gemini25Flash => 1_048_576,
-            Self::Gemini25Pro => 1_048_576,
-            Self::Gemini3Pro => 1_048_576,
-            Self::Custom { max_tokens, .. } => *max_tokens,
+    serde_json::from_str(json_str).ok()
+}
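+
+// parse_stream_line accepts both SSE framing and raw JSON array elements;
+// for example, the line
+//   data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}
+// decodes into a GoogleStreamResponse with a single text part.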
+
+impl zed::Extension for GoogleAiProvider {
+    fn new() -> Self {
+        Self {
+            streams: Mutex::new(HashMap::new()),
+            next_stream_id: Mutex::new(0),
+        }
     }

-    pub fn max_output_tokens(&self) -> Option<u64> {
-        match self {
-            Model::Gemini25FlashLite => Some(65_536),
-            Model::Gemini25Flash => Some(65_536),
-            Model::Gemini25Pro => Some(65_536),
-            Model::Gemini3Pro => Some(65_536),
-            Model::Custom { .. } => None,
-        }
+    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
+        vec![LlmProviderInfo {
+            id: "google-ai".into(),
+            name: "Google AI".into(),
+            icon: Some("icons/google-ai.svg".into()),
+        }]
     }

-    pub fn supports_tools(&self) -> bool {
-        true
+    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
+        Ok(MODELS
+            .iter()
+            .map(|m| LlmModelInfo {
+                id: m.display_name.to_string(),
+                name: m.display_name.to_string(),
+                max_token_count: m.max_tokens,
+                max_output_tokens: m.max_output_tokens,
+                capabilities: LlmModelCapabilities {
+                    supports_images: m.supports_images,
+                    supports_tools: true,
+                    supports_tool_choice_auto: true,
+                    supports_tool_choice_any: true,
+                    supports_tool_choice_none: true,
+                    supports_thinking: m.supports_thinking,
+                    tool_input_format: LlmToolInputFormat::JsonSchema,
+                },
+                is_default: m.is_default,
+                is_default_fast: m.is_default_fast,
+            })
+            .collect())
     }

-    pub fn supports_images(&self) -> bool {
-        true
+    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
+        llm_get_credential("google-ai").is_some()
     }

-    pub fn mode(&self) -> GoogleModelMode {
-        match self {
-            Self::Gemini25FlashLite
-            | Self::Gemini25Flash
-            | Self::Gemini25Pro
-            | Self::Gemini3Pro => {
-                GoogleModelMode::Thinking {
-                    // By default these models are set to "auto", so we preserve that behavior
-                    // but indicate they are capable of thinking mode
-                    budget_tokens: None,
-                }
-            }
-            Self::Custom { mode, .. } => *mode,
-        }
+    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
+        Some(
+            "[Create an API key](https://aistudio.google.com/apikey) to use Google AI as your LLM provider.".to_string(),
+        )
     }
-}

-impl std::fmt::Display for Model {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.id())
+    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
+        llm_delete_credential("google-ai")
     }
-}
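+
+    // The API key itself is never persisted by the extension: it is read with
+    // llm_get_credential("google-ai") on each request and removed via
+    // llm_delete_credential("google-ai") when credentials are reset.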

-#[cfg(test)]
-mod tests {
-    use super::*;
-    use serde_json::json;
-
-    #[test]
-    fn test_function_call_part_with_signature_serializes_correctly() {
-        let part = FunctionCallPart {
-            function_call: FunctionCall {
-                name: "test_function".to_string(),
-                args: json!({"arg": "value"}),
-            },
-            thought_signature: Some("test_signature".to_string()),
+    fn llm_stream_completion_start(
+        &mut self,
+        _provider_id: &str,
+        model_id: &str,
+        request: &LlmCompletionRequest,
+    ) -> Result<String, String> {
+        let api_key = llm_get_credential("google-ai").ok_or_else(|| {
+            "No API key configured. Please add your Google AI API key in settings.".to_string()
+        })?;
+
+        let (google_request, real_model_id) = convert_request(model_id, request)?;
+
+        let body = serde_json::to_vec(&google_request)
+            .map_err(|e| format!("Failed to serialize request: {}", e))?;
+
+        let url = format!(
+            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
+            real_model_id, api_key
+        );
+
+        let http_request = HttpRequest {
+            method: HttpMethod::Post,
+            url,
+            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
+            body: Some(body),
+            redirect_policy: RedirectPolicy::FollowAll,
         };

-        let serialized = serde_json::to_value(&part).unwrap();
-
-        assert_eq!(serialized["functionCall"]["name"], "test_function");
-        assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
-        assert_eq!(serialized["thoughtSignature"], "test_signature");
-    }
+        let response_stream = http_request
+            .fetch_stream()
+            .map_err(|e| format!("HTTP request failed: {}", e))?;

-    #[test]
-    fn test_function_call_part_without_signature_omits_field() {
-        let part = FunctionCallPart {
-            function_call: FunctionCall {
-                name: "test_function".to_string(),
-                args: json!({"arg": "value"}),
-            },
-            thought_signature: None,
+        let stream_id = {
+            let mut id_counter = self.next_stream_id.lock().unwrap();
+            let id = format!("google-ai-stream-{}", *id_counter);
+            *id_counter += 1;
+            id
         };

-        let serialized = serde_json::to_value(&part).unwrap();
+        self.streams.lock().unwrap().insert(
+            stream_id.clone(),
+            StreamState {
+                response_stream: Some(response_stream),
+                buffer: String::new(),
+                started: false,
+                stop_reason: None,
+                wants_tool_use: false,
+            },
+        );

-        assert_eq!(serialized["functionCall"]["name"], "test_function");
-        assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
-        // thoughtSignature field should not be present when None
-        assert!(serialized.get("thoughtSignature").is_none());
+        Ok(stream_id)
     }

-    #[test]
-    fn test_function_call_part_deserializes_with_signature() {
-        let json = json!({
-            "functionCall": {
-                "name": "test_function",
-                "args": {"arg": "value"}
-            },
-            "thoughtSignature": "test_signature"
-        });
+    fn llm_stream_completion_next(
+        &mut self,
+        stream_id: &str,
+    ) -> Result<Option<LlmCompletionEvent>, String> {
+        let mut streams = self.streams.lock().unwrap();
+        let state = streams
+            .get_mut(stream_id)
+            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
+
+        if !state.started {
+            state.started = true;
+            return Ok(Some(LlmCompletionEvent::Started));
+        }

-        let part: FunctionCallPart = serde_json::from_value(json).unwrap();
+        let response_stream = state
+            .response_stream
+            .as_mut()
+            .ok_or_else(|| "Stream already closed".to_string())?;
+
+        loop {
+            if let Some(newline_pos) = state.buffer.find('\n') {
+                let line = state.buffer[..newline_pos].to_string();
+                state.buffer = state.buffer[newline_pos + 1..].to_string();
+
+                if let Some(response) = parse_stream_line(&line) {
+                    for candidate in response.candidates {
+                        if let Some(finish_reason) = &candidate.finish_reason {
+                            state.stop_reason = Some(match finish_reason.as_str() {
+                                "STOP" => {
+                                    if state.wants_tool_use {
+                                        LlmStopReason::ToolUse
+                                    } else {
+                                        LlmStopReason::EndTurn
+                                    }
+                                }
+                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
+                                "SAFETY" => LlmStopReason::Refusal,
+                                _ => LlmStopReason::EndTurn,
+                            });
+                        }

-        assert_eq!(part.function_call.name, "test_function");
-        assert_eq!(part.thought_signature, Some("test_signature".to_string()));
-    }
+                        if let Some(content) = candidate.content {
+                            for part in content.parts {
+                                match part {
GooglePart::Text(text_part) => { + if !text_part.text.is_empty() { + return Ok(Some(LlmCompletionEvent::Text( + text_part.text, + ))); + } + } + GooglePart::FunctionCall(fc_part) => { + state.wants_tool_use = true; + let next_tool_id = + TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst); + let id = format!( + "{}-{}", + fc_part.function_call.name, next_tool_id + ); + + let thought_signature = + fc_part.thought_signature.filter(|s| !s.is_empty()); + + return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse { + id, + name: fc_part.function_call.name, + input: fc_part.function_call.args.to_string(), + is_input_complete: true, + thought_signature, + }))); + } + GooglePart::Thought(thought_part) => { + return Ok(Some(LlmCompletionEvent::Thinking( + LlmThinkingContent { + text: "(Encrypted thought)".to_string(), + signature: Some(thought_part.thought_signature), + }, + ))); + } + _ => {} + } + } + } + } - #[test] - fn test_function_call_part_deserializes_without_signature() { - let json = json!({ - "functionCall": { - "name": "test_function", - "args": {"arg": "value"} - } - }); + if let Some(usage) = response.usage_metadata { + return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage { + input_tokens: usage.prompt_token_count, + output_tokens: usage.candidates_token_count, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }))); + } + } - let part: FunctionCallPart = serde_json::from_value(json).unwrap(); + continue; + } - assert_eq!(part.function_call.name, "test_function"); - assert_eq!(part.thought_signature, None); - } + match response_stream.next_chunk() { + Ok(Some(chunk)) => { + let text = String::from_utf8_lossy(&chunk); + state.buffer.push_str(&text); + } + Ok(None) => { + // Stream ended - check if we have a stop reason + if let Some(stop_reason) = state.stop_reason.take() { + return Ok(Some(LlmCompletionEvent::Stop(stop_reason))); + } - #[test] - fn test_function_call_part_round_trip() { - let original = FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value", "nested": {"key": "val"}}), - }, - thought_signature: Some("round_trip_signature".to_string()), - }; + // No stop reason - this is unexpected. 
Check if buffer contains error info + let mut error_msg = String::from("Stream ended unexpectedly."); - let serialized = serde_json::to_value(&original).unwrap(); - let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap(); + // Try to parse remaining buffer as potential error response + if !state.buffer.is_empty() { + error_msg.push_str(&format!( + "\nRemaining buffer: {}", + &state.buffer[..state.buffer.len().min(1000)] + )); + } - assert_eq!(deserialized.function_call.name, original.function_call.name); - assert_eq!(deserialized.function_call.args, original.function_call.args); - assert_eq!(deserialized.thought_signature, original.thought_signature); + return Err(error_msg); + } + Err(e) => { + return Err(format!("Stream error: {}", e)); + } + } + } } - #[test] - fn test_function_call_part_with_empty_signature_serializes() { - let part = FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("".to_string()), - }; - - let serialized = serde_json::to_value(&part).unwrap(); - - // Empty string should still be serialized (normalization happens at a higher level) - assert_eq!(serialized["thoughtSignature"], ""); + fn llm_stream_completion_close(&mut self, stream_id: &str) { + self.streams.lock().unwrap().remove(stream_id); } } + +zed::register_extension!(GoogleAiProvider);