@@ -1,717 +1,798 @@
-use std::mem;
-
-use anyhow::{Result, anyhow, bail};
-use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
-pub use settings::ModelMode as GoogleModelMode;
-
-pub const API_URL: &str = "https://generativelanguage.googleapis.com";
-
-pub async fn stream_generate_content(
- client: &dyn HttpClient,
- api_url: &str,
- api_key: &str,
- mut request: GenerateContentRequest,
-) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
- let api_key = api_key.trim();
- validate_generate_content_request(&request)?;
-
- // The `model` field is emptied as it is provided as a path parameter.
- let model_id = mem::take(&mut request.model.model_id);
-
- let uri =
- format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
-
- let request_builder = HttpRequest::builder()
- .method(Method::POST)
- .uri(uri)
- .header("Content-Type", "application/json");
-
- let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
- let mut response = client.send(request).await?;
- if response.status().is_success() {
- let reader = BufReader::new(response.into_body());
- Ok(reader
- .lines()
- .filter_map(|line| async move {
- match line {
- Ok(line) => {
- if let Some(line) = line.strip_prefix("data: ") {
- match serde_json::from_str(line) {
- Ok(response) => Some(Ok(response)),
- Err(error) => Some(Err(anyhow!(format!(
- "Error parsing JSON: {error:?}\n{line:?}"
- )))),
- }
- } else {
- None
- }
- }
- Err(error) => Some(Err(anyhow!(error))),
- }
- })
- .boxed())
- } else {
- let mut text = String::new();
- response.body_mut().read_to_string(&mut text).await?;
- Err(anyhow!(
- "error during streamGenerateContent, status code: {:?}, body: {}",
- response.status(),
- text
- ))
- }
-}
+use std::collections::HashMap;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::Mutex;
+
+use serde::{Deserialize, Serialize};
+use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
+use zed_extension_api::{self as zed, *};
+
+static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+
+struct GoogleAiProvider {
+ streams: Mutex<HashMap<String, StreamState>>,
+ next_stream_id: Mutex<u64>,
+}
+
+struct StreamState {
+ response_stream: Option<HttpResponseStream>,
+ buffer: String,
+ started: bool,
+ stop_reason: Option<LlmStopReason>,
+ wants_tool_use: bool,
+}
+
+struct ModelDefinition {
+ real_id: &'static str,
+ display_name: &'static str,
+ max_tokens: u64,
+ max_output_tokens: Option<u64>,
+ supports_images: bool,
+ supports_thinking: bool,
+ is_default: bool,
+ is_default_fast: bool,
+}
+
+const MODELS: &[ModelDefinition] = &[
+ ModelDefinition {
+ real_id: "gemini-2.5-flash-lite",
+ display_name: "Gemini 2.5 Flash-Lite",
+ max_tokens: 1_048_576,
+ max_output_tokens: Some(65_536),
+ supports_images: true,
+ supports_thinking: true,
+ is_default: false,
+ is_default_fast: true,
+ },
+ ModelDefinition {
+ real_id: "gemini-2.5-flash",
+ display_name: "Gemini 2.5 Flash",
+ max_tokens: 1_048_576,
+ max_output_tokens: Some(65_536),
+ supports_images: true,
+ supports_thinking: true,
+ is_default: true,
+ is_default_fast: false,
+ },
+ ModelDefinition {
+ real_id: "gemini-2.5-pro",
+ display_name: "Gemini 2.5 Pro",
+ max_tokens: 1_048_576,
+ max_output_tokens: Some(65_536),
+ supports_images: true,
+ supports_thinking: true,
+ is_default: false,
+ is_default_fast: false,
+ },
+ ModelDefinition {
+ real_id: "gemini-3-pro-preview",
+ display_name: "Gemini 3 Pro",
+ max_tokens: 1_048_576,
+ max_output_tokens: Some(65_536),
+ supports_images: true,
+ supports_thinking: true,
+ is_default: false,
+ is_default_fast: false,
+ },
+];
-pub async fn count_tokens(
- client: &dyn HttpClient,
- api_url: &str,
- api_key: &str,
- request: CountTokensRequest,
-) -> Result<CountTokensResponse> {
- validate_generate_content_request(&request.generate_content_request)?;
-
- let uri = format!(
- "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
- model_id = &request.generate_content_request.model.model_id,
- );
-
- let request = serde_json::to_string(&request)?;
- let request_builder = HttpRequest::builder()
- .method(Method::POST)
- .uri(&uri)
- .header("Content-Type", "application/json");
- let http_request = request_builder.body(AsyncBody::from(request))?;
-
- let mut response = client.send(http_request).await?;
- let mut text = String::new();
- response.body_mut().read_to_string(&mut text).await?;
- anyhow::ensure!(
- response.status().is_success(),
- "error during countTokens, status code: {:?}, body: {}",
- response.status(),
- text
- );
- Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
+fn get_real_model_id(display_name: &str) -> Option<&'static str> {
+ MODELS
+ .iter()
+ .find(|m| m.display_name == display_name)
+ .map(|m| m.real_id)
}
-pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
- if request.model.is_empty() {
- bail!("Model must be specified");
- }
-
- if request.contents.is_empty() {
- bail!("Request must contain at least one content item");
- }
-
- if let Some(user_content) = request
- .contents
+fn get_model_supports_thinking(display_name: &str) -> bool {
+ MODELS
.iter()
- .find(|content| content.role == Role::User)
- && user_content.parts.is_empty()
- {
- bail!("User content must contain at least one part");
- }
-
- Ok(())
-}
+ .find(|m| m.display_name == display_name)
+ .map(|m| m.supports_thinking)
+ .unwrap_or(false)
+}
+
+/// Adapts a JSON schema to be compatible with Google's API subset.
+/// Google only supports a specific subset of JSON Schema fields.
+/// See: https://ai.google.dev/api/caching#Schema
+fn adapt_schema_for_google(json: &mut serde_json::Value) {
+ adapt_schema_for_google_impl(json, true);
+}
+
+fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) {
+ if let serde_json::Value::Object(obj) = json {
+ // Google's Schema only supports these fields:
+ // type, format, title, description, nullable, enum, maxItems, minItems,
+ // properties, required, minProperties, maxProperties, minLength, maxLength,
+ // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
+ const ALLOWED_KEYS: &[&str] = &[
+ "type",
+ "format",
+ "title",
+ "description",
+ "nullable",
+ "enum",
+ "maxItems",
+ "minItems",
+ "properties",
+ "required",
+ "minProperties",
+ "maxProperties",
+ "minLength",
+ "maxLength",
+ "pattern",
+ "example",
+ "anyOf",
+ "propertyOrdering",
+ "default",
+ "items",
+ "minimum",
+ "maximum",
+ ];
+
+ // Convert oneOf to anyOf before filtering keys
+ if let Some(one_of) = obj.remove("oneOf") {
+ obj.insert("anyOf".to_string(), one_of);
+ }
-#[derive(Debug, Serialize, Deserialize)]
-pub enum Task {
- #[serde(rename = "generateContent")]
- GenerateContent,
- #[serde(rename = "streamGenerateContent")]
- StreamGenerateContent,
- #[serde(rename = "countTokens")]
- CountTokens,
- #[serde(rename = "embedContent")]
- EmbedContent,
- #[serde(rename = "batchEmbedContents")]
- BatchEmbedContents,
-}
+ // If type is an array (e.g., ["string", "null"]), take just the first type
+ if let Some(type_field) = obj.get_mut("type") {
+ if let serde_json::Value::Array(types) = type_field {
+ if let Some(first_type) = types.first().cloned() {
+ *type_field = first_type;
+ }
+ }
+ }
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct GenerateContentRequest {
- #[serde(default, skip_serializing_if = "ModelName::is_empty")]
- pub model: ModelName,
- pub contents: Vec<Content>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub system_instruction: Option<SystemInstruction>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub generation_config: Option<GenerationConfig>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub safety_settings: Option<Vec<SafetySetting>>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub tools: Option<Vec<Tool>>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub tool_config: Option<ToolConfig>,
-}
+ // Only filter keys if this is a schema object, not a properties map
+ if is_schema {
+ obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));
+ }
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct GenerateContentResponse {
- #[serde(skip_serializing_if = "Option::is_none")]
- pub candidates: Option<Vec<GenerateContentCandidate>>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub prompt_feedback: Option<PromptFeedback>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub usage_metadata: Option<UsageMetadata>,
+ // Recursively process nested values
+ // "properties" contains a map of property names -> schemas
+ // "items" and "anyOf" contain schemas directly
+ for (key, value) in obj.iter_mut() {
+ if key == "properties" {
+ // properties is a map of property_name -> schema
+ if let serde_json::Value::Object(props) = value {
+ for (_, prop_schema) in props.iter_mut() {
+ adapt_schema_for_google_impl(prop_schema, true);
+ }
+ }
+ } else if key == "items" {
+ // items is a schema
+ adapt_schema_for_google_impl(value, true);
+ } else if key == "anyOf" {
+ // anyOf is an array of schemas
+ if let serde_json::Value::Array(arr) = value {
+ for item in arr.iter_mut() {
+ adapt_schema_for_google_impl(item, true);
+ }
+ }
+ }
+ }
+ } else if let serde_json::Value::Array(arr) = json {
+ for item in arr.iter_mut() {
+ adapt_schema_for_google_impl(item, true);
+ }
+ }
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
-pub struct GenerateContentCandidate {
- #[serde(skip_serializing_if = "Option::is_none")]
- pub index: Option<usize>,
- pub content: Content,
+struct GoogleRequest {
+ contents: Vec<GoogleContent>,
#[serde(skip_serializing_if = "Option::is_none")]
- pub finish_reason: Option<String>,
+ system_instruction: Option<GoogleSystemInstruction>,
#[serde(skip_serializing_if = "Option::is_none")]
- pub finish_message: Option<String>,
+ generation_config: Option<GoogleGenerationConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
- pub safety_ratings: Option<Vec<SafetyRating>>,
+ tools: Option<Vec<GoogleTool>>,
#[serde(skip_serializing_if = "Option::is_none")]
- pub citation_metadata: Option<CitationMetadata>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct Content {
- #[serde(default)]
- pub parts: Vec<Part>,
- pub role: Role,
+ tool_config: Option<GoogleToolConfig>,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
-pub struct SystemInstruction {
- pub parts: Vec<Part>,
+struct GoogleSystemInstruction {
+ parts: Vec<GooglePart>,
}
-#[derive(Debug, PartialEq, Deserialize, Serialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
-pub enum Role {
- User,
- Model,
+struct GoogleContent {
+ parts: Vec<GooglePart>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ role: Option<String>,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
-pub enum Part {
- TextPart(TextPart),
- InlineDataPart(InlineDataPart),
- FunctionCallPart(FunctionCallPart),
- FunctionResponsePart(FunctionResponsePart),
- ThoughtPart(ThoughtPart),
+enum GooglePart {
+ Text(GoogleTextPart),
+ InlineData(GoogleInlineDataPart),
+ FunctionCall(GoogleFunctionCallPart),
+ FunctionResponse(GoogleFunctionResponsePart),
+ Thought(GoogleThoughtPart),
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
-pub struct TextPart {
- pub text: String,
+struct GoogleTextPart {
+ text: String,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
-pub struct InlineDataPart {
- pub inline_data: GenerativeContentBlob,
+struct GoogleInlineDataPart {
+ inline_data: GoogleBlob,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
-pub struct GenerativeContentBlob {
- pub mime_type: String,
- pub data: String,
+struct GoogleBlob {
+ mime_type: String,
+ data: String,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
-pub struct FunctionCallPart {
- pub function_call: FunctionCall,
- /// Thought signature returned by the model for function calls.
- /// Only present on the first function call in parallel call scenarios.
+struct GoogleFunctionCallPart {
+ function_call: GoogleFunctionCall,
#[serde(skip_serializing_if = "Option::is_none")]
- pub thought_signature: Option<String>,
+ thought_signature: Option<String>,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
-pub struct FunctionResponsePart {
- pub function_response: FunctionResponse,
+struct GoogleFunctionCall {
+ name: String,
+ args: serde_json::Value,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
-pub struct ThoughtPart {
- pub thought: bool,
- pub thought_signature: String,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct CitationSource {
- #[serde(skip_serializing_if = "Option::is_none")]
- pub start_index: Option<usize>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub end_index: Option<usize>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub uri: Option<String>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub license: Option<String>,
+struct GoogleFunctionResponsePart {
+ function_response: GoogleFunctionResponse,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
-pub struct CitationMetadata {
- pub citation_sources: Vec<CitationSource>,
+struct GoogleFunctionResponse {
+ name: String,
+ response: serde_json::Value,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
-pub struct PromptFeedback {
- #[serde(skip_serializing_if = "Option::is_none")]
- pub block_reason: Option<String>,
- pub safety_ratings: Option<Vec<SafetyRating>>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub block_reason_message: Option<String>,
+struct GoogleThoughtPart {
+ thought: bool,
+ thought_signature: String,
}
-#[derive(Debug, Serialize, Deserialize, Default)]
+#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
-pub struct UsageMetadata {
- #[serde(skip_serializing_if = "Option::is_none")]
- pub prompt_token_count: Option<u64>,
+struct GoogleGenerationConfig {
#[serde(skip_serializing_if = "Option::is_none")]
- pub cached_content_token_count: Option<u64>,
+ candidate_count: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
- pub candidates_token_count: Option<u64>,
+ stop_sequences: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
- pub tool_use_prompt_token_count: Option<u64>,
+ max_output_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
- pub thoughts_token_count: Option<u64>,
+ temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
- pub total_token_count: Option<u64>,
+ thinking_config: Option<GoogleThinkingConfig>,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
-pub struct ThinkingConfig {
- pub thinking_budget: u32,
-}
-
-#[derive(Debug, Deserialize, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub struct GenerationConfig {
- #[serde(skip_serializing_if = "Option::is_none")]
- pub candidate_count: Option<usize>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub stop_sequences: Option<Vec<String>>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub max_output_tokens: Option<usize>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub temperature: Option<f64>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub top_p: Option<f64>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub top_k: Option<usize>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub thinking_config: Option<ThinkingConfig>,
+struct GoogleThinkingConfig {
+ thinking_budget: u32,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
-pub struct SafetySetting {
- pub category: HarmCategory,
- pub threshold: HarmBlockThreshold,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub enum HarmCategory {
- #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
- Unspecified,
- #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
- Derogatory,
- #[serde(rename = "HARM_CATEGORY_TOXICITY")]
- Toxicity,
- #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
- Violence,
- #[serde(rename = "HARM_CATEGORY_SEXUAL")]
- Sexual,
- #[serde(rename = "HARM_CATEGORY_MEDICAL")]
- Medical,
- #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
- Dangerous,
- #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
- Harassment,
- #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
- HateSpeech,
- #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
- SexuallyExplicit,
- #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
- DangerousContent,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
-pub enum HarmBlockThreshold {
- #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
- Unspecified,
- BlockLowAndAbove,
- BlockMediumAndAbove,
- BlockOnlyHigh,
- BlockNone,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
-pub enum HarmProbability {
- #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
- Unspecified,
- Negligible,
- Low,
- Medium,
- High,
+struct GoogleTool {
+ function_declarations: Vec<GoogleFunctionDeclaration>,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
-pub struct SafetyRating {
- pub category: HarmCategory,
- pub probability: HarmProbability,
+struct GoogleFunctionDeclaration {
+ name: String,
+ description: String,
+ parameters: serde_json::Value,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
-pub struct CountTokensRequest {
- pub generate_content_request: GenerateContentRequest,
+struct GoogleToolConfig {
+ function_calling_config: GoogleFunctionCallingConfig,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
-pub struct CountTokensResponse {
- pub total_tokens: u64,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct FunctionCall {
- pub name: String,
- pub args: serde_json::Value,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct FunctionResponse {
- pub name: String,
- pub response: serde_json::Value,
+struct GoogleFunctionCallingConfig {
+ mode: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ allowed_function_names: Option<Vec<String>>,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
-pub struct Tool {
- pub function_declarations: Vec<FunctionDeclaration>,
+struct GoogleStreamResponse {
+ #[serde(default)]
+ candidates: Vec<GoogleCandidate>,
+ #[serde(default)]
+ usage_metadata: Option<GoogleUsageMetadata>,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
-pub struct ToolConfig {
- pub function_calling_config: FunctionCallingConfig,
+struct GoogleCandidate {
+ #[serde(default)]
+ content: Option<GoogleContent>,
+ #[serde(default)]
+ finish_reason: Option<String>,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
-pub struct FunctionCallingConfig {
- pub mode: FunctionCallingMode,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub allowed_function_names: Option<Vec<String>>,
+struct GoogleUsageMetadata {
+ #[serde(default)]
+ prompt_token_count: u64,
+ #[serde(default)]
+ candidates_token_count: u64,
}
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "lowercase")]
-pub enum FunctionCallingMode {
- Auto,
- Any,
- None,
-}
+fn convert_request(
+ model_id: &str,
+ request: &LlmCompletionRequest,
+) -> Result<(GoogleRequest, String), String> {
+ let real_model_id =
+ get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;
-#[derive(Debug, Serialize, Deserialize)]
-pub struct FunctionDeclaration {
- pub name: String,
- pub description: String,
- pub parameters: serde_json::Value,
-}
+ let supports_thinking = get_model_supports_thinking(model_id);
-#[derive(Debug, Default)]
-pub struct ModelName {
- pub model_id: String,
-}
+ let mut contents: Vec<GoogleContent> = Vec::new();
+ let mut system_parts: Vec<GooglePart> = Vec::new();
-impl ModelName {
- pub fn is_empty(&self) -> bool {
- self.model_id.is_empty()
- }
-}
+ for msg in &request.messages {
+ match msg.role {
+ LlmMessageRole::System => {
+ for content in &msg.content {
+ if let LlmMessageContent::Text(text) = content {
+ if !text.is_empty() {
+ system_parts
+ .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
+ }
+ }
+ }
+ }
+ LlmMessageRole::User => {
+ let mut parts: Vec<GooglePart> = Vec::new();
+
+ for content in &msg.content {
+ match content {
+ LlmMessageContent::Text(text) => {
+ if !text.is_empty() {
+ parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
+ }
+ }
+ LlmMessageContent::Image(img) => {
+ parts.push(GooglePart::InlineData(GoogleInlineDataPart {
+ inline_data: GoogleBlob {
+ mime_type: "image/png".to_string(),
+ data: img.source.clone(),
+ },
+ }));
+ }
+ LlmMessageContent::ToolResult(result) => {
+ let response_value = match &result.content {
+ LlmToolResultContent::Text(t) => {
+ serde_json::json!({ "output": t })
+ }
+ LlmToolResultContent::Image(_) => {
+ serde_json::json!({ "output": "Tool responded with an image" })
+ }
+ };
+ parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
+ function_response: GoogleFunctionResponse {
+ name: result.tool_name.clone(),
+ response: response_value,
+ },
+ }));
+ }
+ _ => {}
+ }
+ }
-const MODEL_NAME_PREFIX: &str = "models/";
+ if !parts.is_empty() {
+ contents.push(GoogleContent {
+ parts,
+ role: Some("user".to_string()),
+ });
+ }
+ }
+ LlmMessageRole::Assistant => {
+ let mut parts: Vec<GooglePart> = Vec::new();
+
+ for content in &msg.content {
+ match content {
+ LlmMessageContent::Text(text) => {
+ if !text.is_empty() {
+ parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
+ }
+ }
+ LlmMessageContent::ToolUse(tool_use) => {
+ let thought_signature =
+ tool_use.thought_signature.clone().filter(|s| !s.is_empty());
+
+ let args: serde_json::Value =
+ serde_json::from_str(&tool_use.input).unwrap_or_default();
+
+ parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
+ function_call: GoogleFunctionCall {
+ name: tool_use.name.clone(),
+ args,
+ },
+ thought_signature,
+ }));
+ }
+ LlmMessageContent::Thinking(thinking) => {
+ if let Some(ref signature) = thinking.signature {
+ if !signature.is_empty() {
+ parts.push(GooglePart::Thought(GoogleThoughtPart {
+ thought: true,
+ thought_signature: signature.clone(),
+ }));
+ }
+ }
+ }
+ _ => {}
+ }
+ }
-impl Serialize for ModelName {
- fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
- where
- S: Serializer,
- {
- serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
+ if !parts.is_empty() {
+ contents.push(GoogleContent {
+ parts,
+ role: Some("model".to_string()),
+ });
+ }
+ }
+ }
}
-}
-impl<'de> Deserialize<'de> for ModelName {
- fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
- where
- D: Deserializer<'de>,
- {
- let string = String::deserialize(deserializer)?;
- if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
- Ok(Self {
- model_id: id.to_string(),
+ let system_instruction = if system_parts.is_empty() {
+ None
+ } else {
+ Some(GoogleSystemInstruction {
+ parts: system_parts,
+ })
+ };
+
+ let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
+ None
+ } else {
+ let declarations: Vec<GoogleFunctionDeclaration> = request
+ .tools
+ .iter()
+ .map(|t| {
+ let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
+ .unwrap_or(serde_json::Value::Object(Default::default()));
+ adapt_schema_for_google(&mut parameters);
+ GoogleFunctionDeclaration {
+ name: t.name.clone(),
+ description: t.description.clone(),
+ parameters,
+ }
})
- } else {
- Err(serde::de::Error::custom(format!(
- "Expected model name to begin with {}, got: {}",
- MODEL_NAME_PREFIX, string
- )))
+ .collect();
+ Some(vec![GoogleTool {
+ function_declarations: declarations,
+ }])
+ };
+
+ let tool_config = request.tool_choice.as_ref().map(|tc| {
+ let mode = match tc {
+ LlmToolChoice::Auto => "AUTO",
+ LlmToolChoice::Any => "ANY",
+ LlmToolChoice::None => "NONE",
+ };
+ GoogleToolConfig {
+ function_calling_config: GoogleFunctionCallingConfig {
+ mode: mode.to_string(),
+ allowed_function_names: None,
+ },
}
- }
-}
+ });
-#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
-pub enum Model {
- #[serde(
- rename = "gemini-2.5-flash-lite",
- alias = "gemini-2.5-flash-lite-preview-06-17",
- alias = "gemini-2.0-flash-lite-preview"
- )]
- Gemini25FlashLite,
- #[serde(
- rename = "gemini-2.5-flash",
- alias = "gemini-2.0-flash-thinking-exp",
- alias = "gemini-2.5-flash-preview-04-17",
- alias = "gemini-2.5-flash-preview-05-20",
- alias = "gemini-2.5-flash-preview-latest",
- alias = "gemini-2.0-flash"
- )]
- #[default]
- Gemini25Flash,
- #[serde(
- rename = "gemini-2.5-pro",
- alias = "gemini-2.0-pro-exp",
- alias = "gemini-2.5-pro-preview-latest",
- alias = "gemini-2.5-pro-exp-03-25",
- alias = "gemini-2.5-pro-preview-03-25",
- alias = "gemini-2.5-pro-preview-05-06",
- alias = "gemini-2.5-pro-preview-06-05"
- )]
- Gemini25Pro,
- #[serde(rename = "gemini-3-pro-preview")]
- Gemini3Pro,
- #[serde(rename = "custom")]
- Custom {
- name: String,
- /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
- display_name: Option<String>,
- max_tokens: u64,
- #[serde(default)]
- mode: GoogleModelMode,
- },
-}
+ let thinking_config = if supports_thinking && request.thinking_allowed {
+ Some(GoogleThinkingConfig {
+ thinking_budget: 8192,
+ })
+ } else {
+ None
+ };
-impl Model {
- pub fn default_fast() -> Self {
- Self::Gemini25FlashLite
+ let generation_config = Some(GoogleGenerationConfig {
+ candidate_count: Some(1),
+ stop_sequences: if request.stop_sequences.is_empty() {
+ None
+ } else {
+ Some(request.stop_sequences.clone())
+ },
+ max_output_tokens: None,
+ temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
+ thinking_config,
+ });
+
+ Ok((
+ GoogleRequest {
+ contents,
+ system_instruction,
+ generation_config,
+ tools,
+ tool_config,
+ },
+ real_model_id.to_string(),
+ ))
+}
+
+fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
+ let trimmed = line.trim();
+ if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," {
+ return None;
}
- pub fn id(&self) -> &str {
- match self {
- Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
- Self::Gemini25Flash => "gemini-2.5-flash",
- Self::Gemini25Pro => "gemini-2.5-pro",
- Self::Gemini3Pro => "gemini-3-pro-preview",
- Self::Custom { name, .. } => name,
- }
- }
- pub fn request_id(&self) -> &str {
- match self {
- Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
- Self::Gemini25Flash => "gemini-2.5-flash",
- Self::Gemini25Pro => "gemini-2.5-pro",
- Self::Gemini3Pro => "gemini-3-pro-preview",
- Self::Custom { name, .. } => name,
- }
- }
+ let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed);
+ let json_str = json_str.trim_start_matches(',').trim();
- pub fn display_name(&self) -> &str {
- match self {
- Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite",
- Self::Gemini25Flash => "Gemini 2.5 Flash",
- Self::Gemini25Pro => "Gemini 2.5 Pro",
- Self::Gemini3Pro => "Gemini 3 Pro",
- Self::Custom {
- name, display_name, ..
- } => display_name.as_ref().unwrap_or(name),
- }
+ if json_str.is_empty() {
+ return None;
}
- pub fn max_token_count(&self) -> u64 {
- match self {
- Self::Gemini25FlashLite => 1_048_576,
- Self::Gemini25Flash => 1_048_576,
- Self::Gemini25Pro => 1_048_576,
- Self::Gemini3Pro => 1_048_576,
- Self::Custom { max_tokens, .. } => *max_tokens,
+ serde_json::from_str(json_str).ok()
+}
+
+impl zed::Extension for GoogleAiProvider {
+ fn new() -> Self {
+ Self {
+ streams: Mutex::new(HashMap::new()),
+ next_stream_id: Mutex::new(0),
}
}
- pub fn max_output_tokens(&self) -> Option<u64> {
- match self {
- Model::Gemini25FlashLite => Some(65_536),
- Model::Gemini25Flash => Some(65_536),
- Model::Gemini25Pro => Some(65_536),
- Model::Gemini3Pro => Some(65_536),
- Model::Custom { .. } => None,
- }
+ fn llm_providers(&self) -> Vec<LlmProviderInfo> {
+ vec![LlmProviderInfo {
+ id: "google-ai".into(),
+ name: "Google AI".into(),
+ icon: Some("icons/google-ai.svg".into()),
+ }]
}
- pub fn supports_tools(&self) -> bool {
- true
+ fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
+ Ok(MODELS
+ .iter()
+ .map(|m| LlmModelInfo {
+ id: m.display_name.to_string(),
+ name: m.display_name.to_string(),
+ max_token_count: m.max_tokens,
+ max_output_tokens: m.max_output_tokens,
+ capabilities: LlmModelCapabilities {
+ supports_images: m.supports_images,
+ supports_tools: true,
+ supports_tool_choice_auto: true,
+ supports_tool_choice_any: true,
+ supports_tool_choice_none: true,
+ supports_thinking: m.supports_thinking,
+ tool_input_format: LlmToolInputFormat::JsonSchema,
+ },
+ is_default: m.is_default,
+ is_default_fast: m.is_default_fast,
+ })
+ .collect())
}
- pub fn supports_images(&self) -> bool {
- true
+ fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
+ llm_get_credential("google-ai").is_some()
}
- pub fn mode(&self) -> GoogleModelMode {
- match self {
- Self::Gemini25FlashLite
- | Self::Gemini25Flash
- | Self::Gemini25Pro
- | Self::Gemini3Pro => {
- GoogleModelMode::Thinking {
- // By default these models are set to "auto", so we preserve that behavior
- // but indicate they are capable of thinking mode
- budget_tokens: None,
- }
- }
- Self::Custom { mode, .. } => *mode,
- }
+ fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
+ Some(
+ "[Create an API key](https://aistudio.google.com/apikey) to use Google AI as your LLM provider.".to_string(),
+ )
}
-}
-impl std::fmt::Display for Model {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.id())
+ fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
+ llm_delete_credential("google-ai")
}
-}
-#[cfg(test)]
-mod tests {
- use super::*;
- use serde_json::json;
-
- #[test]
- fn test_function_call_part_with_signature_serializes_correctly() {
- let part = FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
- },
- thought_signature: Some("test_signature".to_string()),
+ // Begin a streaming completion: convert the generic request to Google's
+ // wire format, open the SSE HTTP stream, and register per-stream state
+ // under a freshly minted stream id (returned to the caller).
+ fn llm_stream_completion_start(
+ &mut self,
+ _provider_id: &str,
+ model_id: &str,
+ request: &LlmCompletionRequest,
+ ) -> Result<String, String> {
+ let api_key = llm_get_credential("google-ai").ok_or_else(|| {
+ "No API key configured. Please add your Google AI API key in settings.".to_string()
+ })?;
+
+ let (google_request, real_model_id) = convert_request(model_id, request)?;
+
+ let body = serde_json::to_vec(&google_request)
+ .map_err(|e| format!("Failed to serialize request: {}", e))?;
+
+ // `alt=sse` requests server-sent-events framing. NOTE(review): the API
+ // key travels as a URL query parameter, so it may leak into request logs.
+ let url = format!(
+ "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
+ real_model_id, api_key
+ );
+
+ let http_request = HttpRequest {
+ method: HttpMethod::Post,
+ url,
+ headers: vec![("Content-Type".to_string(), "application/json".to_string())],
+ body: Some(body),
+ redirect_policy: RedirectPolicy::FollowAll,
 };
- let serialized = serde_json::to_value(&part).unwrap();
-
- assert_eq!(serialized["functionCall"]["name"], "test_function");
- assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
- assert_eq!(serialized["thoughtSignature"], "test_signature");
- }
+ let response_stream = http_request
+ .fetch_stream()
+ .map_err(|e| format!("HTTP request failed: {}", e))?;
- #[test]
- fn test_function_call_part_without_signature_omits_field() {
- let part = FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
- },
- thought_signature: None,
+ // Mint a unique id from the shared counter; the lock guard is dropped at
+ // the end of this block, before `self.streams` is locked below.
+ let stream_id = {
+ let mut id_counter = self.next_stream_id.lock().unwrap();
+ let id = format!("google-ai-stream-{}", *id_counter);
+ *id_counter += 1;
+ id
 };
- let serialized = serde_json::to_value(&part).unwrap();
+ // Fresh state: `started` is false so the first poll emits Started.
+ self.streams.lock().unwrap().insert(
+ stream_id.clone(),
+ StreamState {
+ response_stream: Some(response_stream),
+ buffer: String::new(),
+ started: false,
+ stop_reason: None,
+ wants_tool_use: false,
+ },
+ );
- assert_eq!(serialized["functionCall"]["name"], "test_function");
- assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
- // thoughtSignature field should not be present when None
- assert!(serialized.get("thoughtSignature").is_none());
+ Ok(stream_id)
 }
- #[test]
- fn test_function_call_part_deserializes_with_signature() {
- let json = json!({
- "functionCall": {
- "name": "test_function",
- "args": {"arg": "value"}
- },
- "thoughtSignature": "test_signature"
- });
+ // Pull the next event for `stream_id`. Buffered SSE lines are parsed
+ // first; only when no complete line is buffered do we block on the next
+ // HTTP chunk. Returns one event per call; Err tears the stream down.
+ fn llm_stream_completion_next(
+ &mut self,
+ stream_id: &str,
+ ) -> Result<Option<LlmCompletionEvent>, String> {
+ let mut streams = self.streams.lock().unwrap();
+ let state = streams
+ .get_mut(stream_id)
+ .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
+
+ // The very first poll reports Started without touching the network.
+ if !state.started {
+ state.started = true;
+ return Ok(Some(LlmCompletionEvent::Started));
+ }
- let part: FunctionCallPart = serde_json::from_value(json).unwrap();
+ let response_stream = state
+ .response_stream
+ .as_mut()
+ .ok_or_else(|| "Stream already closed".to_string())?;
+
+ loop {
+ // Consume one complete line from the buffer, if present.
+ if let Some(newline_pos) = state.buffer.find('\n') {
+ let line = state.buffer[..newline_pos].to_string();
+ state.buffer = state.buffer[newline_pos + 1..].to_string();
+
+ if let Some(response) = parse_stream_line(&line) {
+ for candidate in response.candidates {
+ // Record the finish reason now; it is emitted as a Stop
+ // event only once the HTTP stream ends (Ok(None) below).
+ if let Some(finish_reason) = &candidate.finish_reason {
+ state.stop_reason = Some(match finish_reason.as_str() {
+ "STOP" => {
+ if state.wants_tool_use {
+ LlmStopReason::ToolUse
+ } else {
+ LlmStopReason::EndTurn
+ }
+ }
+ "MAX_TOKENS" => LlmStopReason::MaxTokens,
+ "SAFETY" => LlmStopReason::Refusal,
+ _ => LlmStopReason::EndTurn,
+ });
+ }
- assert_eq!(part.function_call.name, "test_function");
- assert_eq!(part.thought_signature, Some("test_signature".to_string()));
- }
+ // NOTE(review): each `return` below has already removed this
+ // line from the buffer, so any parts after the first emitted
+ // one in the same SSE line are silently dropped — confirm
+ // Gemini never packs text and functionCall in one chunk.
+ if let Some(content) = candidate.content {
+ for part in content.parts {
+ match part {
+ GooglePart::Text(text_part) => {
+ if !text_part.text.is_empty() {
+ return Ok(Some(LlmCompletionEvent::Text(
+ text_part.text,
+ )));
+ }
+ }
+ GooglePart::FunctionCall(fc_part) => {
+ state.wants_tool_use = true;
+ // Google supplies no tool-call id; synthesize
+ // one from a process-wide atomic counter.
+ let next_tool_id =
+ TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
+ let id = format!(
+ "{}-{}",
+ fc_part.function_call.name, next_tool_id
+ );
+
+ // Empty signatures are normalized to None.
+ let thought_signature =
+ fc_part.thought_signature.filter(|s| !s.is_empty());
+
+ return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
+ id,
+ name: fc_part.function_call.name,
+ input: fc_part.function_call.args.to_string(),
+ is_input_complete: true,
+ thought_signature,
+ })));
+ }
+ GooglePart::Thought(thought_part) => {
+ // Only an opaque signature is available here; a
+ // placeholder string stands in for the text.
+ return Ok(Some(LlmCompletionEvent::Thinking(
+ LlmThinkingContent {
+ text: "(Encrypted thought)".to_string(),
+ signature: Some(thought_part.thought_signature),
+ },
+ )));
+ }
+ _ => {}
+ }
+ }
+ }
+ }
- #[test]
- fn test_function_call_part_deserializes_without_signature() {
- let json = json!({
- "functionCall": {
- "name": "test_function",
- "args": {"arg": "value"}
- }
- });
+ // Token accounting; cache token fields are not provided by this
+ // endpoint, so they are reported as None.
+ if let Some(usage) = response.usage_metadata {
+ return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
+ input_tokens: usage.prompt_token_count,
+ output_tokens: usage.candidates_token_count,
+ cache_creation_input_tokens: None,
+ cache_read_input_tokens: None,
+ })));
+ }
+ }
- let part: FunctionCallPart = serde_json::from_value(json).unwrap();
+ // Line did not parse (keep-alives, blank separators): try the next.
+ continue;
+ }
- assert_eq!(part.function_call.name, "test_function");
- assert_eq!(part.thought_signature, None);
- }
+ // No complete line buffered: block for the next chunk of bytes.
+ match response_stream.next_chunk() {
+ Ok(Some(chunk)) => {
+ let text = String::from_utf8_lossy(&chunk);
+ state.buffer.push_str(&text);
+ }
+ Ok(None) => {
+ // Stream ended - check if we have a stop reason
+ if let Some(stop_reason) = state.stop_reason.take() {
+ return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
+ }
- #[test]
- fn test_function_call_part_round_trip() {
- let original = FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value", "nested": {"key": "val"}}),
- },
- thought_signature: Some("round_trip_signature".to_string()),
- };
+ // No stop reason - this is unexpected. Check if buffer contains error info
+ let mut error_msg = String::from("Stream ended unexpectedly.");
- let serialized = serde_json::to_value(&original).unwrap();
- let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap();
+ // Try to parse remaining buffer as potential error response
+ if !state.buffer.is_empty() {
+ error_msg.push_str(&format!(
+ "\nRemaining buffer: {}",
+ &state.buffer[..state.buffer.len().min(1000)]
+ ));
+ }
- assert_eq!(deserialized.function_call.name, original.function_call.name);
- assert_eq!(deserialized.function_call.args, original.function_call.args);
- assert_eq!(deserialized.thought_signature, original.thought_signature);
+ return Err(error_msg);
+ }
+ Err(e) => {
+ return Err(format!("Stream error: {}", e));
+ }
+ }
+ }
 }
- #[test]
- fn test_function_call_part_with_empty_signature_serializes() {
- let part = FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
- },
- thought_signature: Some("".to_string()),
- };
-
- let serialized = serde_json::to_value(&part).unwrap();
-
- // Empty string should still be serialized (normalization happens at a higher level)
- assert_eq!(serialized["thoughtSignature"], "");
+ // Drop all state for a finished or cancelled stream; closing an unknown
+ // id is a harmless no-op (HashMap::remove of a missing key).
+ fn llm_stream_completion_close(&mut self, stream_id: &str) {
+ self.streams.lock().unwrap().remove(stream_id);
+ }
 }
+
+// Entry point: expose GoogleAiProvider as the extension implementation.
+zed::register_extension!(GoogleAiProvider);