diff --git a/extensions/google-ai/src/google_ai.rs b/extensions/google-ai/src/google_ai.rs index 846ad348b18bfd329c563be965a26419f8103f8b..bad0bf43692bc6fbd4bbabe4de388d2a7f1b6fde 100644 --- a/extensions/google-ai/src/google_ai.rs +++ b/extensions/google-ai/src/google_ai.rs @@ -1,798 +1,718 @@ -use std::collections::HashMap; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Mutex; - -use serde::{Deserialize, Serialize}; -use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy}; -use zed_extension_api::{self as zed, *}; - -static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); - -struct GoogleAiProvider { - streams: Mutex>, - next_stream_id: Mutex, -} - -struct StreamState { - response_stream: Option, - buffer: String, - started: bool, - stop_reason: Option, - wants_tool_use: bool, -} - -struct ModelDefinition { - real_id: &'static str, - display_name: &'static str, - max_tokens: u64, - max_output_tokens: Option, - supports_images: bool, - supports_thinking: bool, - is_default: bool, - is_default_fast: bool, -} - -const MODELS: &[ModelDefinition] = &[ - ModelDefinition { - real_id: "gemini-2.5-flash-lite", - display_name: "Gemini 2.5 Flash-Lite", - max_tokens: 1_048_576, - max_output_tokens: Some(65_536), - supports_images: true, - supports_thinking: true, - is_default: false, - is_default_fast: true, - }, - ModelDefinition { - real_id: "gemini-2.5-flash", - display_name: "Gemini 2.5 Flash", - max_tokens: 1_048_576, - max_output_tokens: Some(65_536), - supports_images: true, - supports_thinking: true, - is_default: true, - is_default_fast: false, - }, - ModelDefinition { - real_id: "gemini-2.5-pro", - display_name: "Gemini 2.5 Pro", - max_tokens: 1_048_576, - max_output_tokens: Some(65_536), - supports_images: true, - supports_thinking: true, - is_default: false, - is_default_fast: false, - }, - ModelDefinition { - real_id: "gemini-3-pro-preview", - display_name: "Gemini 3 Pro", - max_tokens: 1_048_576, - max_output_tokens: Some(65_536), - supports_images: true, - supports_thinking: true, - is_default: false, - is_default_fast: false, - }, -]; -fn get_real_model_id(display_name: &str) -> Option<&'static str> { - MODELS - .iter() - .find(|m| m.display_name == display_name) - .map(|m| m.real_id) +use std::mem; + +use anyhow::{Result, anyhow, bail}; +use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +pub use settings::ModelMode as GoogleModelMode; + +pub const API_URL: &str = "https://generativelanguage.googleapis.com"; + +pub async fn stream_generate_content( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + mut request: GenerateContentRequest, +) -> Result>> { + let api_key = api_key.trim(); + validate_generate_content_request(&request)?; + + // The `model` field is emptied as it is provided as a path parameter. + let model_id = mem::take(&mut request.model.model_id); + + let uri = + format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",); + + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json"); + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let mut response = client.send(request).await?; + if response.status().is_success() { + let reader = BufReader::new(response.into_body()); + Ok(reader + .lines() + .filter_map(|line| async move { + match line { + Ok(line) => { + if let Some(line) = line.strip_prefix("data: ") { + match serde_json::from_str(line) { + Ok(response) => Some(Ok(response)), + Err(error) => Some(Err(anyhow!(format!( + "Error parsing JSON: {error:?}\n{line:?}" + )))), + } + } else { + None + } + } + Err(error) => Some(Err(anyhow!(error))), + } + }) + .boxed()) + } else { + let mut text = String::new(); + response.body_mut().read_to_string(&mut text).await?; + Err(anyhow!( + "error during streamGenerateContent, status code: {:?}, body: {}", + response.status(), + text + )) + } } -fn get_model_supports_thinking(display_name: &str) -> bool { - MODELS - .iter() - .find(|m| m.display_name == display_name) - .map(|m| m.supports_thinking) - .unwrap_or(false) -} - -/// Adapts a JSON schema to be compatible with Google's API subset. -/// Google only supports a specific subset of JSON Schema fields. -/// See: https://ai.google.dev/api/caching#Schema -fn adapt_schema_for_google(json: &mut serde_json::Value) { - adapt_schema_for_google_impl(json, true); -} - -fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) { - if let serde_json::Value::Object(obj) = json { - // Google's Schema only supports these fields: - // type, format, title, description, nullable, enum, maxItems, minItems, - // properties, required, minProperties, maxProperties, minLength, maxLength, - // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum - const ALLOWED_KEYS: &[&str] = &[ - "type", - "format", - "title", - "description", - "nullable", - "enum", - "maxItems", - "minItems", - "properties", - "required", - "minProperties", - "maxProperties", - "minLength", - "maxLength", - "pattern", - "example", - "anyOf", - "propertyOrdering", - "default", - "items", - "minimum", - "maximum", - ]; - - // Convert oneOf to anyOf before filtering keys - if let Some(one_of) = obj.remove("oneOf") { - obj.insert("anyOf".to_string(), one_of); - } +pub async fn count_tokens( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: CountTokensRequest, +) -> Result { + validate_generate_content_request(&request.generate_content_request)?; + + let uri = format!( + "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}", + model_id = &request.generate_content_request.model.model_id, + ); + + let request = serde_json::to_string(&request)?; + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(&uri) + .header("Content-Type", "application/json"); + let http_request = request_builder.body(AsyncBody::from(request))?; + + let mut response = client.send(http_request).await?; + let mut text = String::new(); + response.body_mut().read_to_string(&mut text).await?; + anyhow::ensure!( + response.status().is_success(), + "error during countTokens, status code: {:?}, body: {}", + response.status(), + text + ); + Ok(serde_json::from_str::(&text)?) +} - // If type is an array (e.g., ["string", "null"]), take just the first type - if let Some(type_field) = obj.get_mut("type") { - if let serde_json::Value::Array(types) = type_field { - if let Some(first_type) = types.first().cloned() { - *type_field = first_type; - } - } - } +pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> { + if request.model.is_empty() { + bail!("Model must be specified"); + } - // Only filter keys if this is a schema object, not a properties map - if is_schema { - obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str())); - } + if request.contents.is_empty() { + bail!("Request must contain at least one content item"); + } - // Recursively process nested values - // "properties" contains a map of property names -> schemas - // "items" and "anyOf" contain schemas directly - for (key, value) in obj.iter_mut() { - if key == "properties" { - // properties is a map of property_name -> schema - if let serde_json::Value::Object(props) = value { - for (_, prop_schema) in props.iter_mut() { - adapt_schema_for_google_impl(prop_schema, true); - } - } - } else if key == "items" { - // items is a schema - adapt_schema_for_google_impl(value, true); - } else if key == "anyOf" { - // anyOf is an array of schemas - if let serde_json::Value::Array(arr) = value { - for item in arr.iter_mut() { - adapt_schema_for_google_impl(item, true); - } - } - } - } - } else if let serde_json::Value::Array(arr) = json { - for item in arr.iter_mut() { - adapt_schema_for_google_impl(item, true); - } + if let Some(user_content) = request + .contents + .iter() + .find(|content| content.role == Role::User) + && user_content.parts.is_empty() + { + bail!("User content must contain at least one part"); } + + Ok(()) } -#[derive(Serialize)] +#[derive(Debug, Serialize, Deserialize)] +pub enum Task { + #[serde(rename = "generateContent")] + GenerateContent, + #[serde(rename = "streamGenerateContent")] + StreamGenerateContent, + #[serde(rename = "countTokens")] + CountTokens, + #[serde(rename = "embedContent")] + EmbedContent, + #[serde(rename = "batchEmbedContents")] + BatchEmbedContents, +} + +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleRequest { - contents: Vec, +pub struct GenerateContentRequest { + #[serde(default, skip_serializing_if = "ModelName::is_empty")] + pub model: ModelName, + pub contents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_instruction: Option, #[serde(skip_serializing_if = "Option::is_none")] - system_instruction: Option, + pub generation_config: Option, #[serde(skip_serializing_if = "Option::is_none")] - generation_config: Option, + pub safety_settings: Option>, #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, + pub tools: Option>, #[serde(skip_serializing_if = "Option::is_none")] - tool_config: Option, + pub tool_config: Option, } -#[derive(Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleSystemInstruction { - parts: Vec, +pub struct GenerateContentResponse { + #[serde(skip_serializing_if = "Option::is_none")] + pub candidates: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_feedback: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage_metadata: Option, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleContent { - parts: Vec, +pub struct GenerateContentCandidate { + #[serde(skip_serializing_if = "Option::is_none")] + pub index: Option, + pub content: Content, + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_message: Option, #[serde(skip_serializing_if = "Option::is_none")] - role: Option, + pub safety_ratings: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub citation_metadata: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Content { + #[serde(default)] + pub parts: Vec, + pub role: Role, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SystemInstruction { + pub parts: Vec, +} + +#[derive(Debug, PartialEq, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub enum Role { + User, + Model, +} + +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] -enum GooglePart { - Text(GoogleTextPart), - InlineData(GoogleInlineDataPart), - FunctionCall(GoogleFunctionCallPart), - FunctionResponse(GoogleFunctionResponsePart), - Thought(GoogleThoughtPart), +pub enum Part { + TextPart(TextPart), + InlineDataPart(InlineDataPart), + FunctionCallPart(FunctionCallPart), + FunctionResponsePart(FunctionResponsePart), + ThoughtPart(ThoughtPart), } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleTextPart { - text: String, +pub struct TextPart { + pub text: String, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleInlineDataPart { - inline_data: GoogleBlob, +pub struct InlineDataPart { + pub inline_data: GenerativeContentBlob, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleBlob { - mime_type: String, - data: String, +pub struct GenerativeContentBlob { + pub mime_type: String, + pub data: String, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleFunctionCallPart { - function_call: GoogleFunctionCall, +pub struct FunctionCallPart { + pub function_call: FunctionCall, + /// Thought signature returned by the model for function calls. + /// Only present on the first function call in parallel call scenarios. #[serde(skip_serializing_if = "Option::is_none")] - thought_signature: Option, + pub thought_signature: Option, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleFunctionCall { - name: String, - args: serde_json::Value, +pub struct FunctionResponsePart { + pub function_response: FunctionResponse, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleFunctionResponsePart { - function_response: GoogleFunctionResponse, +pub struct ThoughtPart { + pub thought: bool, + pub thought_signature: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CitationSource { + #[serde(skip_serializing_if = "Option::is_none")] + pub start_index: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub end_index: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub uri: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub license: Option, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleFunctionResponse { - name: String, - response: serde_json::Value, +pub struct CitationMetadata { + pub citation_sources: Vec, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleThoughtPart { - thought: bool, - thought_signature: String, +pub struct PromptFeedback { + #[serde(skip_serializing_if = "Option::is_none")] + pub block_reason: Option, + pub safety_ratings: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub block_reason_message: Option, } -#[derive(Serialize)] +#[derive(Debug, Serialize, Deserialize, Default)] #[serde(rename_all = "camelCase")] -struct GoogleGenerationConfig { +pub struct UsageMetadata { + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_token_count: Option, #[serde(skip_serializing_if = "Option::is_none")] - candidate_count: Option, + pub cached_content_token_count: Option, #[serde(skip_serializing_if = "Option::is_none")] - stop_sequences: Option>, + pub candidates_token_count: Option, #[serde(skip_serializing_if = "Option::is_none")] - max_output_tokens: Option, + pub tool_use_prompt_token_count: Option, #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, + pub thoughts_token_count: Option, #[serde(skip_serializing_if = "Option::is_none")] - thinking_config: Option, + pub total_token_count: Option, } -#[derive(Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleThinkingConfig { - thinking_budget: u32, +pub struct ThinkingConfig { + pub thinking_budget: u32, } -#[derive(Serialize)] +#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] -struct GoogleTool { - function_declarations: Vec, +pub struct GenerationConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub candidate_count: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_sequences: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub thinking_config: Option, } -#[derive(Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleFunctionDeclaration { - name: String, - description: String, - parameters: serde_json::Value, +pub struct SafetySetting { + pub category: HarmCategory, + pub threshold: HarmBlockThreshold, } -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct GoogleToolConfig { - function_calling_config: GoogleFunctionCallingConfig, +#[derive(Debug, Serialize, Deserialize)] +pub enum HarmCategory { + #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")] + Unspecified, + #[serde(rename = "HARM_CATEGORY_DEROGATORY")] + Derogatory, + #[serde(rename = "HARM_CATEGORY_TOXICITY")] + Toxicity, + #[serde(rename = "HARM_CATEGORY_VIOLENCE")] + Violence, + #[serde(rename = "HARM_CATEGORY_SEXUAL")] + Sexual, + #[serde(rename = "HARM_CATEGORY_MEDICAL")] + Medical, + #[serde(rename = "HARM_CATEGORY_DANGEROUS")] + Dangerous, + #[serde(rename = "HARM_CATEGORY_HARASSMENT")] + Harassment, + #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")] + HateSpeech, + #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")] + SexuallyExplicit, + #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")] + DangerousContent, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum HarmBlockThreshold { + #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")] + Unspecified, + BlockLowAndAbove, + BlockMediumAndAbove, + BlockOnlyHigh, + BlockNone, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum HarmProbability { + #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")] + Unspecified, + Negligible, + Low, + Medium, + High, } -#[derive(Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleFunctionCallingConfig { - mode: String, - #[serde(skip_serializing_if = "Option::is_none")] - allowed_function_names: Option>, +pub struct SafetyRating { + pub category: HarmCategory, + pub probability: HarmProbability, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleStreamResponse { - #[serde(default)] - candidates: Vec, - #[serde(default)] - usage_metadata: Option, +pub struct CountTokensRequest { + pub generate_content_request: GenerateContentRequest, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleCandidate { - #[serde(default)] - content: Option, - #[serde(default)] - finish_reason: Option, +pub struct CountTokensResponse { + pub total_tokens: u64, } -#[derive(Deserialize, Debug)] +#[derive(Debug, Serialize, Deserialize)] +pub struct FunctionCall { + pub name: String, + pub args: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct FunctionResponse { + pub name: String, + pub response: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -struct GoogleUsageMetadata { - #[serde(default)] - prompt_token_count: u64, - #[serde(default)] - candidates_token_count: u64, +pub struct Tool { + pub function_declarations: Vec, } -fn convert_request( - model_id: &str, - request: &LlmCompletionRequest, -) -> Result<(GoogleRequest, String), String> { - let real_model_id = - get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?; +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolConfig { + pub function_calling_config: FunctionCallingConfig, +} - let supports_thinking = get_model_supports_thinking(model_id); +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionCallingConfig { + pub mode: FunctionCallingMode, + #[serde(skip_serializing_if = "Option::is_none")] + pub allowed_function_names: Option>, +} - let mut contents: Vec = Vec::new(); - let mut system_parts: Vec = Vec::new(); +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum FunctionCallingMode { + Auto, + Any, + None, +} - for msg in &request.messages { - match msg.role { - LlmMessageRole::System => { - for content in &msg.content { - if let LlmMessageContent::Text(text) = content { - if !text.is_empty() { - system_parts - .push(GooglePart::Text(GoogleTextPart { text: text.clone() })); - } - } - } - } - LlmMessageRole::User => { - let mut parts: Vec = Vec::new(); - - for content in &msg.content { - match content { - LlmMessageContent::Text(text) => { - if !text.is_empty() { - parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() })); - } - } - LlmMessageContent::Image(img) => { - parts.push(GooglePart::InlineData(GoogleInlineDataPart { - inline_data: GoogleBlob { - mime_type: "image/png".to_string(), - data: img.source.clone(), - }, - })); - } - LlmMessageContent::ToolResult(result) => { - let response_value = match &result.content { - LlmToolResultContent::Text(t) => { - serde_json::json!({ "output": t }) - } - LlmToolResultContent::Image(_) => { - serde_json::json!({ "output": "Tool responded with an image" }) - } - }; - parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart { - function_response: GoogleFunctionResponse { - name: result.tool_name.clone(), - response: response_value, - }, - })); - } - _ => {} - } - } +#[derive(Debug, Serialize, Deserialize)] +pub struct FunctionDeclaration { + pub name: String, + pub description: String, + pub parameters: serde_json::Value, +} - if !parts.is_empty() { - contents.push(GoogleContent { - parts, - role: Some("user".to_string()), - }); - } - } - LlmMessageRole::Assistant => { - let mut parts: Vec = Vec::new(); - - for content in &msg.content { - match content { - LlmMessageContent::Text(text) => { - if !text.is_empty() { - parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() })); - } - } - LlmMessageContent::ToolUse(tool_use) => { - let thought_signature = - tool_use.thought_signature.clone().filter(|s| !s.is_empty()); - - let args: serde_json::Value = - serde_json::from_str(&tool_use.input).unwrap_or_default(); - - parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart { - function_call: GoogleFunctionCall { - name: tool_use.name.clone(), - args, - }, - thought_signature, - })); - } - LlmMessageContent::Thinking(thinking) => { - if let Some(ref signature) = thinking.signature { - if !signature.is_empty() { - parts.push(GooglePart::Thought(GoogleThoughtPart { - thought: true, - thought_signature: signature.clone(), - })); - } - } - } - _ => {} - } - } +#[derive(Debug, Default)] +pub struct ModelName { + pub model_id: String, +} - if !parts.is_empty() { - contents.push(GoogleContent { - parts, - role: Some("model".to_string()), - }); - } - } - } +impl ModelName { + pub fn is_empty(&self) -> bool { + self.model_id.is_empty() } +} - let system_instruction = if system_parts.is_empty() { - None - } else { - Some(GoogleSystemInstruction { - parts: system_parts, - }) - }; +const MODEL_NAME_PREFIX: &str = "models/"; - let tools: Option> = if request.tools.is_empty() { - None - } else { - let declarations: Vec = request - .tools - .iter() - .map(|t| { - let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema) - .unwrap_or(serde_json::Value::Object(Default::default())); - adapt_schema_for_google(&mut parameters); - GoogleFunctionDeclaration { - name: t.name.clone(), - description: t.description.clone(), - parameters, - } +impl Serialize for ModelName { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id)) + } +} + +impl<'de> Deserialize<'de> for ModelName { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let string = String::deserialize(deserializer)?; + if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) { + Ok(Self { + model_id: id.to_string(), }) - .collect(); - Some(vec![GoogleTool { - function_declarations: declarations, - }]) - }; - - let tool_config = request.tool_choice.as_ref().map(|tc| { - let mode = match tc { - LlmToolChoice::Auto => "AUTO", - LlmToolChoice::Any => "ANY", - LlmToolChoice::None => "NONE", - }; - GoogleToolConfig { - function_calling_config: GoogleFunctionCallingConfig { - mode: mode.to_string(), - allowed_function_names: None, - }, + } else { + Err(serde::de::Error::custom(format!( + "Expected model name to begin with {}, got: {}", + MODEL_NAME_PREFIX, string + ))) } - }); + } +} - let thinking_config = if supports_thinking && request.thinking_allowed { - Some(GoogleThinkingConfig { - thinking_budget: 8192, - }) - } else { - None - }; +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] +pub enum Model { + #[serde( + rename = "gemini-2.5-flash-lite", + alias = "gemini-2.5-flash-lite-preview-06-17", + alias = "gemini-2.0-flash-lite-preview" + )] + Gemini25FlashLite, + #[serde( + rename = "gemini-2.5-flash", + alias = "gemini-2.0-flash-thinking-exp", + alias = "gemini-2.5-flash-preview-04-17", + alias = "gemini-2.5-flash-preview-05-20", + alias = "gemini-2.5-flash-preview-latest", + alias = "gemini-2.0-flash" + )] + #[default] + Gemini25Flash, + #[serde( + rename = "gemini-2.5-pro", + alias = "gemini-2.0-pro-exp", + alias = "gemini-2.5-pro-preview-latest", + alias = "gemini-2.5-pro-exp-03-25", + alias = "gemini-2.5-pro-preview-03-25", + alias = "gemini-2.5-pro-preview-05-06", + alias = "gemini-2.5-pro-preview-06-05" + )] + Gemini25Pro, + #[serde(rename = "gemini-3-pro-preview")] + Gemini3Pro, + #[serde(rename = "custom")] + Custom { + name: String, + /// The name displayed in the UI, such as in the assistant panel model dropdown menu. + display_name: Option, + max_tokens: u64, + #[serde(default)] + mode: GoogleModelMode, + }, +} - let generation_config = Some(GoogleGenerationConfig { - candidate_count: Some(1), - stop_sequences: if request.stop_sequences.is_empty() { - None - } else { - Some(request.stop_sequences.clone()) - }, - max_output_tokens: None, - temperature: request.temperature.map(|t| t as f64).or(Some(1.0)), - thinking_config, - }); - - Ok(( - GoogleRequest { - contents, - system_instruction, - generation_config, - tools, - tool_config, - }, - real_model_id.to_string(), - )) -} - -fn parse_stream_line(line: &str) -> Option { - let trimmed = line.trim(); - if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," { - return None; +impl Model { + pub fn default_fast() -> Self { + Self::Gemini25FlashLite } - let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed); - let json_str = json_str.trim_start_matches(',').trim(); - - if json_str.is_empty() { - return None; + pub fn id(&self) -> &str { + match self { + Self::Gemini25FlashLite => "gemini-2.5-flash-lite", + Self::Gemini25Flash => "gemini-2.5-flash", + Self::Gemini25Pro => "gemini-2.5-pro", + Self::Gemini3Pro => "gemini-3-pro-preview", + Self::Custom { name, .. } => name, + } + } + pub fn request_id(&self) -> &str { + match self { + Self::Gemini25FlashLite => "gemini-2.5-flash-lite", + Self::Gemini25Flash => "gemini-2.5-flash", + Self::Gemini25Pro => "gemini-2.5-pro", + Self::Gemini3Pro => "gemini-3-pro-preview", + Self::Custom { name, .. } => name, + } } - serde_json::from_str(json_str).ok() -} + pub fn display_name(&self) -> &str { + match self { + Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite", + Self::Gemini25Flash => "Gemini 2.5 Flash", + Self::Gemini25Pro => "Gemini 2.5 Pro", + Self::Gemini3Pro => "Gemini 3 Pro", + Self::Custom { + name, display_name, .. + } => display_name.as_ref().unwrap_or(name), + } + } -impl zed::Extension for GoogleAiProvider { - fn new() -> Self { - Self { - streams: Mutex::new(HashMap::new()), - next_stream_id: Mutex::new(0), + pub fn max_token_count(&self) -> u64 { + match self { + Self::Gemini25FlashLite => 1_048_576, + Self::Gemini25Flash => 1_048_576, + Self::Gemini25Pro => 1_048_576, + Self::Gemini3Pro => 1_048_576, + Self::Custom { max_tokens, .. } => *max_tokens, } } - fn llm_providers(&self) -> Vec { - vec![LlmProviderInfo { - id: "google-ai".into(), - name: "Google AI".into(), - icon: Some("icons/google-ai.svg".into()), - }] + pub fn max_output_tokens(&self) -> Option { + match self { + Model::Gemini25FlashLite => Some(65_536), + Model::Gemini25Flash => Some(65_536), + Model::Gemini25Pro => Some(65_536), + Model::Gemini3Pro => Some(65_536), + Model::Custom { .. } => None, + } } - fn llm_provider_models(&self, _provider_id: &str) -> Result, String> { - Ok(MODELS - .iter() - .map(|m| LlmModelInfo { - id: m.display_name.to_string(), - name: m.display_name.to_string(), - max_token_count: m.max_tokens, - max_output_tokens: m.max_output_tokens, - capabilities: LlmModelCapabilities { - supports_images: m.supports_images, - supports_tools: true, - supports_tool_choice_auto: true, - supports_tool_choice_any: true, - supports_tool_choice_none: true, - supports_thinking: m.supports_thinking, - tool_input_format: LlmToolInputFormat::JsonSchema, - }, - is_default: m.is_default, - is_default_fast: m.is_default_fast, - }) - .collect()) + pub fn supports_tools(&self) -> bool { + true } - fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool { - llm_get_credential("google-ai").is_some() + pub fn supports_images(&self) -> bool { + true } - fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option { - Some( - "[Create an API key](https://aistudio.google.com/apikey) to use Google AI as your LLM provider.".to_string(), - ) + pub fn mode(&self) -> GoogleModelMode { + match self { + Self::Gemini25FlashLite + | Self::Gemini25Flash + | Self::Gemini25Pro + | Self::Gemini3Pro => { + GoogleModelMode::Thinking { + // By default these models are set to "auto", so we preserve that behavior + // but indicate they are capable of thinking mode + budget_tokens: None, + } + } + Self::Custom { mode, .. } => *mode, + } } +} - fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> { - llm_delete_credential("google-ai") +impl std::fmt::Display for Model { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.id()) } +} - fn llm_stream_completion_start( - &mut self, - _provider_id: &str, - model_id: &str, - request: &LlmCompletionRequest, - ) -> Result { - let api_key = llm_get_credential("google-ai").ok_or_else(|| { - "No API key configured. Please add your Google AI API key in settings.".to_string() - })?; - - let (google_request, real_model_id) = convert_request(model_id, request)?; - - let body = serde_json::to_vec(&google_request) - .map_err(|e| format!("Failed to serialize request: {}", e))?; - - let url = format!( - "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}", - real_model_id, api_key - ); - - let http_request = HttpRequest { - method: HttpMethod::Post, - url, - headers: vec![("Content-Type".to_string(), "application/json".to_string())], - body: Some(body), - redirect_policy: RedirectPolicy::FollowAll, +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_function_call_part_with_signature_serializes_correctly() { + let part = FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("test_signature".to_string()), }; - let response_stream = http_request - .fetch_stream() - .map_err(|e| format!("HTTP request failed: {}", e))?; + let serialized = serde_json::to_value(&part).unwrap(); - let stream_id = { - let mut id_counter = self.next_stream_id.lock().unwrap(); - let id = format!("google-ai-stream-{}", *id_counter); - *id_counter += 1; - id - }; + assert_eq!(serialized["functionCall"]["name"], "test_function"); + assert_eq!(serialized["functionCall"]["args"]["arg"], "value"); + assert_eq!(serialized["thoughtSignature"], "test_signature"); + } - self.streams.lock().unwrap().insert( - stream_id.clone(), - StreamState { - response_stream: Some(response_stream), - buffer: String::new(), - started: false, - stop_reason: None, - wants_tool_use: false, + #[test] + fn test_function_call_part_without_signature_omits_field() { + let part = FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), }, - ); + thought_signature: None, + }; + + let serialized = serde_json::to_value(&part).unwrap(); - Ok(stream_id) + assert_eq!(serialized["functionCall"]["name"], "test_function"); + assert_eq!(serialized["functionCall"]["args"]["arg"], "value"); + // thoughtSignature field should not be present when None + assert!(serialized.get("thoughtSignature").is_none()); } - fn llm_stream_completion_next( - &mut self, - stream_id: &str, - ) -> Result, String> { - let mut streams = self.streams.lock().unwrap(); - let state = streams - .get_mut(stream_id) - .ok_or_else(|| format!("Unknown stream: {}", stream_id))?; - - if !state.started { - state.started = true; - return Ok(Some(LlmCompletionEvent::Started)); - } + #[test] + fn test_function_call_part_deserializes_with_signature() { + let json = json!({ + "functionCall": { + "name": "test_function", + "args": {"arg": "value"} + }, + "thoughtSignature": "test_signature" + }); - let response_stream = state - .response_stream - .as_mut() - .ok_or_else(|| "Stream already closed".to_string())?; - - loop { - if let Some(newline_pos) = state.buffer.find('\n') { - let line = state.buffer[..newline_pos].to_string(); - state.buffer = state.buffer[newline_pos + 1..].to_string(); - - if let Some(response) = parse_stream_line(&line) { - for candidate in response.candidates { - if let Some(finish_reason) = &candidate.finish_reason { - state.stop_reason = Some(match finish_reason.as_str() { - "STOP" => { - if state.wants_tool_use { - LlmStopReason::ToolUse - } else { - LlmStopReason::EndTurn - } - } - "MAX_TOKENS" => LlmStopReason::MaxTokens, - "SAFETY" => LlmStopReason::Refusal, - _ => LlmStopReason::EndTurn, - }); - } + let part: FunctionCallPart = serde_json::from_value(json).unwrap(); - if let Some(content) = candidate.content { - for part in content.parts { - match part { - GooglePart::Text(text_part) => { - if !text_part.text.is_empty() { - return Ok(Some(LlmCompletionEvent::Text( - text_part.text, - ))); - } - } - GooglePart::FunctionCall(fc_part) => { - state.wants_tool_use = true; - let next_tool_id = - TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst); - let id = format!( - "{}-{}", - fc_part.function_call.name, next_tool_id - ); - - let thought_signature = - fc_part.thought_signature.filter(|s| !s.is_empty()); - - return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse { - id, - name: fc_part.function_call.name, - input: fc_part.function_call.args.to_string(), - is_input_complete: true, - thought_signature, - }))); - } - GooglePart::Thought(thought_part) => { - return Ok(Some(LlmCompletionEvent::Thinking( - LlmThinkingContent { - text: "(Encrypted thought)".to_string(), - signature: Some(thought_part.thought_signature), - }, - ))); - } - _ => {} - } - } - } - } - - if let Some(usage) = response.usage_metadata { - return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage { - input_tokens: usage.prompt_token_count, - output_tokens: usage.candidates_token_count, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }))); - } - } + assert_eq!(part.function_call.name, "test_function"); + assert_eq!(part.thought_signature, Some("test_signature".to_string())); + } - continue; + #[test] + fn test_function_call_part_deserializes_without_signature() { + let json = json!({ + "functionCall": { + "name": "test_function", + "args": {"arg": "value"} } + }); - match response_stream.next_chunk() { - Ok(Some(chunk)) => { - let text = String::from_utf8_lossy(&chunk); - state.buffer.push_str(&text); - } - Ok(None) => { - // Stream ended - check if we have a stop reason - if let Some(stop_reason) = state.stop_reason.take() { - return Ok(Some(LlmCompletionEvent::Stop(stop_reason))); - } + let part: FunctionCallPart = serde_json::from_value(json).unwrap(); - // No stop reason - this is unexpected. Check if buffer contains error info - let mut error_msg = String::from("Stream ended unexpectedly."); + assert_eq!(part.function_call.name, "test_function"); + assert_eq!(part.thought_signature, None); + } - // Try to parse remaining buffer as potential error response - if !state.buffer.is_empty() { - error_msg.push_str(&format!( - "\nRemaining buffer: {}", - &state.buffer[..state.buffer.len().min(1000)] - )); - } + #[test] + fn test_function_call_part_round_trip() { + let original = FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value", "nested": {"key": "val"}}), + }, + thought_signature: Some("round_trip_signature".to_string()), + }; - return Err(error_msg); - } - Err(e) => { - return Err(format!("Stream error: {}", e)); - } - } - } + let serialized = serde_json::to_value(&original).unwrap(); + let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap(); + + assert_eq!(deserialized.function_call.name, original.function_call.name); + assert_eq!(deserialized.function_call.args, original.function_call.args); + assert_eq!(deserialized.thought_signature, original.thought_signature); } - fn llm_stream_completion_close(&mut self, stream_id: &str) { - self.streams.lock().unwrap().remove(stream_id); + #[test] + fn test_function_call_part_with_empty_signature_serializes() { + let part = FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("".to_string()), + }; + + let serialized = serde_json::to_value(&part).unwrap(); + + // Empty string should still be serialized (normalization happens at a higher level) + assert_eq!(serialized["thoughtSignature"], ""); } } - -zed::register_extension!(GoogleAiProvider);