@@ -1,798 +1,718 @@
-use std::collections::HashMap;
-use std::sync::atomic::{AtomicU64, Ordering};
-use std::sync::Mutex;
-
-use serde::{Deserialize, Serialize};
-use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
-use zed_extension_api::{self as zed, *};
-
-static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
-
-struct GoogleAiProvider {
- streams: Mutex<HashMap<String, StreamState>>,
- next_stream_id: Mutex<u64>,
-}
-
-struct StreamState {
- response_stream: Option<HttpResponseStream>,
- buffer: String,
- started: bool,
- stop_reason: Option<LlmStopReason>,
- wants_tool_use: bool,
-}
-
-struct ModelDefinition {
- real_id: &'static str,
- display_name: &'static str,
- max_tokens: u64,
- max_output_tokens: Option<u64>,
- supports_images: bool,
- supports_thinking: bool,
- is_default: bool,
- is_default_fast: bool,
-}
-
-const MODELS: &[ModelDefinition] = &[
- ModelDefinition {
- real_id: "gemini-2.5-flash-lite",
- display_name: "Gemini 2.5 Flash-Lite",
- max_tokens: 1_048_576,
- max_output_tokens: Some(65_536),
- supports_images: true,
- supports_thinking: true,
- is_default: false,
- is_default_fast: true,
- },
- ModelDefinition {
- real_id: "gemini-2.5-flash",
- display_name: "Gemini 2.5 Flash",
- max_tokens: 1_048_576,
- max_output_tokens: Some(65_536),
- supports_images: true,
- supports_thinking: true,
- is_default: true,
- is_default_fast: false,
- },
- ModelDefinition {
- real_id: "gemini-2.5-pro",
- display_name: "Gemini 2.5 Pro",
- max_tokens: 1_048_576,
- max_output_tokens: Some(65_536),
- supports_images: true,
- supports_thinking: true,
- is_default: false,
- is_default_fast: false,
- },
- ModelDefinition {
- real_id: "gemini-3-pro-preview",
- display_name: "Gemini 3 Pro",
- max_tokens: 1_048_576,
- max_output_tokens: Some(65_536),
- supports_images: true,
- supports_thinking: true,
- is_default: false,
- is_default_fast: false,
- },
-];
-fn get_real_model_id(display_name: &str) -> Option<&'static str> {
- MODELS
- .iter()
- .find(|m| m.display_name == display_name)
- .map(|m| m.real_id)
+use std::mem;
+
+use anyhow::{Result, anyhow, bail};
+use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+pub use settings::ModelMode as GoogleModelMode;
+
+pub const API_URL: &str = "https://generativelanguage.googleapis.com";
+
+pub async fn stream_generate_content(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ mut request: GenerateContentRequest,
+) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
+ let api_key = api_key.trim();
+ validate_generate_content_request(&request)?;
+
+    // Empty the request's `model` field; the model ID is sent as a path parameter instead.
+ let model_id = mem::take(&mut request.model.model_id);
+
+    let uri =
+        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}");
+
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json");
+
+ let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ let mut response = client.send(request).await?;
+ if response.status().is_success() {
+ let reader = BufReader::new(response.into_body());
+ Ok(reader
+ .lines()
+ .filter_map(|line| async move {
+ match line {
+ Ok(line) => {
+ if let Some(line) = line.strip_prefix("data: ") {
+ match serde_json::from_str(line) {
+ Ok(response) => Some(Ok(response)),
+                                Err(error) => Some(Err(anyhow!(
+                                    "Error parsing JSON: {error:?}\n{line:?}"
+                                ))),
+ }
+ } else {
+ None
+ }
+ }
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ })
+ .boxed())
+ } else {
+ let mut text = String::new();
+ response.body_mut().read_to_string(&mut text).await?;
+ Err(anyhow!(
+ "error during streamGenerateContent, status code: {:?}, body: {}",
+ response.status(),
+ text
+ ))
+ }
}
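+// A minimal usage sketch (illustrative, not part of this crate's API): the
+// stream yields one `GenerateContentResponse` per SSE `data:` line. `client`
+// and `request` are assumed to exist in the caller's scope.
+//
+//     let mut events = stream_generate_content(&client, API_URL, api_key, request).await?;
+//     while let Some(event) = events.next().await {
+//         for candidate in event?.candidates.unwrap_or_default() {
+//             // inspect candidate.content.parts
+//         }
+//     }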
-fn get_model_supports_thinking(display_name: &str) -> bool {
- MODELS
- .iter()
- .find(|m| m.display_name == display_name)
- .map(|m| m.supports_thinking)
- .unwrap_or(false)
-}
-
-/// Adapts a JSON schema to be compatible with Google's API subset.
-/// Google only supports a specific subset of JSON Schema fields.
-/// See: https://ai.google.dev/api/caching#Schema
-fn adapt_schema_for_google(json: &mut serde_json::Value) {
- adapt_schema_for_google_impl(json, true);
-}
-
-fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) {
- if let serde_json::Value::Object(obj) = json {
- // Google's Schema only supports these fields:
- // type, format, title, description, nullable, enum, maxItems, minItems,
- // properties, required, minProperties, maxProperties, minLength, maxLength,
- // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
- const ALLOWED_KEYS: &[&str] = &[
- "type",
- "format",
- "title",
- "description",
- "nullable",
- "enum",
- "maxItems",
- "minItems",
- "properties",
- "required",
- "minProperties",
- "maxProperties",
- "minLength",
- "maxLength",
- "pattern",
- "example",
- "anyOf",
- "propertyOrdering",
- "default",
- "items",
- "minimum",
- "maximum",
- ];
-
- // Convert oneOf to anyOf before filtering keys
- if let Some(one_of) = obj.remove("oneOf") {
- obj.insert("anyOf".to_string(), one_of);
- }
+pub async fn count_tokens(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: CountTokensRequest,
+) -> Result<CountTokensResponse> {
+    // Trim the key for parity with `stream_generate_content`.
+    let api_key = api_key.trim();
+    validate_generate_content_request(&request.generate_content_request)?;
+
+ let uri = format!(
+ "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
+ model_id = &request.generate_content_request.model.model_id,
+ );
+
+ let request = serde_json::to_string(&request)?;
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(&uri)
+ .header("Content-Type", "application/json");
+ let http_request = request_builder.body(AsyncBody::from(request))?;
+
+ let mut response = client.send(http_request).await?;
+ let mut text = String::new();
+ response.body_mut().read_to_string(&mut text).await?;
+ anyhow::ensure!(
+ response.status().is_success(),
+ "error during countTokens, status code: {:?}, body: {}",
+ response.status(),
+ text
+ );
+ Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
+}
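+// Sketch: counting prompt tokens before sending a request. The wrapper struct
+// mirrors `CountTokensRequest` defined later in this module; `client` and
+// `request` are assumptions.
+//
+//     let total = count_tokens(&client, API_URL, api_key, CountTokensRequest {
+//         generate_content_request: request,
+//     })
+//     .await?
+//     .total_tokens;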
- // If type is an array (e.g., ["string", "null"]), take just the first type
- if let Some(type_field) = obj.get_mut("type") {
- if let serde_json::Value::Array(types) = type_field {
- if let Some(first_type) = types.first().cloned() {
- *type_field = first_type;
- }
- }
- }
+pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
+ if request.model.is_empty() {
+ bail!("Model must be specified");
+ }
- // Only filter keys if this is a schema object, not a properties map
- if is_schema {
- obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));
- }
+ if request.contents.is_empty() {
+ bail!("Request must contain at least one content item");
+ }
- // Recursively process nested values
- // "properties" contains a map of property names -> schemas
- // "items" and "anyOf" contain schemas directly
- for (key, value) in obj.iter_mut() {
- if key == "properties" {
- // properties is a map of property_name -> schema
- if let serde_json::Value::Object(props) = value {
- for (_, prop_schema) in props.iter_mut() {
- adapt_schema_for_google_impl(prop_schema, true);
- }
- }
- } else if key == "items" {
- // items is a schema
- adapt_schema_for_google_impl(value, true);
- } else if key == "anyOf" {
- // anyOf is an array of schemas
- if let serde_json::Value::Array(arr) = value {
- for item in arr.iter_mut() {
- adapt_schema_for_google_impl(item, true);
- }
- }
- }
- }
- } else if let serde_json::Value::Array(arr) = json {
- for item in arr.iter_mut() {
- adapt_schema_for_google_impl(item, true);
- }
+ if let Some(user_content) = request
+ .contents
+ .iter()
+ .find(|content| content.role == Role::User)
+ && user_content.parts.is_empty()
+ {
+ bail!("User content must contain at least one part");
}
+
+ Ok(())
}
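+// A minimal request that passes validation, as a sketch (the model ID and
+// text are placeholders):
+//
+//     let request = GenerateContentRequest {
+//         model: ModelName { model_id: "gemini-2.5-flash".into() },
+//         contents: vec![Content {
+//             role: Role::User,
+//             parts: vec![Part::TextPart(TextPart { text: "Hello".into() })],
+//         }],
+//         system_instruction: None,
+//         generation_config: None,
+//         safety_settings: None,
+//         tools: None,
+//         tool_config: None,
+//     };
+//     assert!(validate_generate_content_request(&request).is_ok());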
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub enum Task {
+    GenerateContent,
+    StreamGenerateContent,
+    CountTokens,
+    EmbedContent,
+    BatchEmbedContents,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleRequest {
- contents: Vec<GoogleContent>,
+pub struct GenerateContentRequest {
+ #[serde(default, skip_serializing_if = "ModelName::is_empty")]
+ pub model: ModelName,
+ pub contents: Vec<Content>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub system_instruction: Option<SystemInstruction>,
#[serde(skip_serializing_if = "Option::is_none")]
- system_instruction: Option<GoogleSystemInstruction>,
+ pub generation_config: Option<GenerationConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
- generation_config: Option<GoogleGenerationConfig>,
+ pub safety_settings: Option<Vec<SafetySetting>>,
#[serde(skip_serializing_if = "Option::is_none")]
- tools: Option<Vec<GoogleTool>>,
+ pub tools: Option<Vec<Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
- tool_config: Option<GoogleToolConfig>,
+ pub tool_config: Option<ToolConfig>,
}
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleSystemInstruction {
- parts: Vec<GooglePart>,
+pub struct GenerateContentResponse {
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub candidates: Option<Vec<GenerateContentCandidate>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub prompt_feedback: Option<PromptFeedback>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub usage_metadata: Option<UsageMetadata>,
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleContent {
- parts: Vec<GooglePart>,
+pub struct GenerateContentCandidate {
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub index: Option<usize>,
+ pub content: Content,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub finish_reason: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub finish_message: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
- role: Option<String>,
+ pub safety_ratings: Option<Vec<SafetyRating>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub citation_metadata: Option<CitationMetadata>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Content {
+ #[serde(default)]
+ pub parts: Vec<Part>,
+ pub role: Role,
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SystemInstruction {
+ pub parts: Vec<Part>,
+}
+
+#[derive(Debug, PartialEq, Deserialize, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub enum Role {
+ User,
+ Model,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
-enum GooglePart {
- Text(GoogleTextPart),
- InlineData(GoogleInlineDataPart),
- FunctionCall(GoogleFunctionCallPart),
- FunctionResponse(GoogleFunctionResponsePart),
- Thought(GoogleThoughtPart),
+pub enum Part {
+ TextPart(TextPart),
+ InlineDataPart(InlineDataPart),
+ FunctionCallPart(FunctionCallPart),
+ FunctionResponsePart(FunctionResponsePart),
+ ThoughtPart(ThoughtPart),
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleTextPart {
- text: String,
+pub struct TextPart {
+ pub text: String,
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleInlineDataPart {
- inline_data: GoogleBlob,
+pub struct InlineDataPart {
+ pub inline_data: GenerativeContentBlob,
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleBlob {
- mime_type: String,
- data: String,
+pub struct GenerativeContentBlob {
+ pub mime_type: String,
+ pub data: String,
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleFunctionCallPart {
- function_call: GoogleFunctionCall,
+pub struct FunctionCallPart {
+ pub function_call: FunctionCall,
+ /// Thought signature returned by the model for function calls.
+ /// Only present on the first function call in parallel call scenarios.
#[serde(skip_serializing_if = "Option::is_none")]
- thought_signature: Option<String>,
+ pub thought_signature: Option<String>,
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleFunctionCall {
- name: String,
- args: serde_json::Value,
+pub struct FunctionResponsePart {
+ pub function_response: FunctionResponse,
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleFunctionResponsePart {
- function_response: GoogleFunctionResponse,
+pub struct ThoughtPart {
+ pub thought: bool,
+ pub thought_signature: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CitationSource {
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub start_index: Option<usize>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub end_index: Option<usize>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub uri: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub license: Option<String>,
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleFunctionResponse {
- name: String,
- response: serde_json::Value,
+pub struct CitationMetadata {
+ pub citation_sources: Vec<CitationSource>,
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleThoughtPart {
- thought: bool,
- thought_signature: String,
+pub struct PromptFeedback {
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub block_reason: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub safety_ratings: Option<Vec<SafetyRating>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub block_reason_message: Option<String>,
}
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
-struct GoogleGenerationConfig {
+pub struct UsageMetadata {
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub prompt_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
- candidate_count: Option<usize>,
+ pub cached_content_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
- stop_sequences: Option<Vec<String>>,
+ pub candidates_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
- max_output_tokens: Option<usize>,
+ pub tool_use_prompt_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
- temperature: Option<f64>,
+ pub thoughts_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
- thinking_config: Option<GoogleThinkingConfig>,
+ pub total_token_count: Option<u64>,
}
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleThinkingConfig {
- thinking_budget: u32,
+pub struct ThinkingConfig {
+ pub thinking_budget: u32,
}
-#[derive(Serialize)]
+#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleTool {
- function_declarations: Vec<GoogleFunctionDeclaration>,
+pub struct GenerationConfig {
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub candidate_count: Option<usize>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub stop_sequences: Option<Vec<String>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub max_output_tokens: Option<usize>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub temperature: Option<f64>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub top_p: Option<f64>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub top_k: Option<usize>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub thinking_config: Option<ThinkingConfig>,
}
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleFunctionDeclaration {
- name: String,
- description: String,
- parameters: serde_json::Value,
+pub struct SafetySetting {
+ pub category: HarmCategory,
+ pub threshold: HarmBlockThreshold,
}
-#[derive(Serialize)]
-#[serde(rename_all = "camelCase")]
-struct GoogleToolConfig {
- function_calling_config: GoogleFunctionCallingConfig,
+#[derive(Debug, Serialize, Deserialize)]
+pub enum HarmCategory {
+ #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
+ Unspecified,
+ #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
+ Derogatory,
+ #[serde(rename = "HARM_CATEGORY_TOXICITY")]
+ Toxicity,
+ #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
+ Violence,
+ #[serde(rename = "HARM_CATEGORY_SEXUAL")]
+ Sexual,
+ #[serde(rename = "HARM_CATEGORY_MEDICAL")]
+ Medical,
+ #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
+ Dangerous,
+ #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
+ Harassment,
+ #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
+ HateSpeech,
+ #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
+ SexuallyExplicit,
+ #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
+ DangerousContent,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum HarmBlockThreshold {
+ #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
+ Unspecified,
+ BlockLowAndAbove,
+ BlockMediumAndAbove,
+ BlockOnlyHigh,
+ BlockNone,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum HarmProbability {
+ #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
+ Unspecified,
+ Negligible,
+ Low,
+ Medium,
+ High,
}
-#[derive(Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleFunctionCallingConfig {
- mode: String,
- #[serde(skip_serializing_if = "Option::is_none")]
- allowed_function_names: Option<Vec<String>>,
+pub struct SafetyRating {
+ pub category: HarmCategory,
+ pub probability: HarmProbability,
}
-#[derive(Deserialize, Debug)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleStreamResponse {
- #[serde(default)]
- candidates: Vec<GoogleCandidate>,
- #[serde(default)]
- usage_metadata: Option<GoogleUsageMetadata>,
+pub struct CountTokensRequest {
+ pub generate_content_request: GenerateContentRequest,
}
-#[derive(Deserialize, Debug)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleCandidate {
- #[serde(default)]
- content: Option<GoogleContent>,
- #[serde(default)]
- finish_reason: Option<String>,
+pub struct CountTokensResponse {
+ pub total_tokens: u64,
}
-#[derive(Deserialize, Debug)]
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FunctionCall {
+ pub name: String,
+ pub args: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FunctionResponse {
+ pub name: String,
+ pub response: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
-struct GoogleUsageMetadata {
- #[serde(default)]
- prompt_token_count: u64,
- #[serde(default)]
- candidates_token_count: u64,
+pub struct Tool {
+ pub function_declarations: Vec<FunctionDeclaration>,
}
-fn convert_request(
- model_id: &str,
- request: &LlmCompletionRequest,
-) -> Result<(GoogleRequest, String), String> {
- let real_model_id =
- get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolConfig {
+ pub function_calling_config: FunctionCallingConfig,
+}
- let supports_thinking = get_model_supports_thinking(model_id);
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct FunctionCallingConfig {
+ pub mode: FunctionCallingMode,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub allowed_function_names: Option<Vec<String>>,
+}
- let mut contents: Vec<GoogleContent> = Vec::new();
- let mut system_parts: Vec<GooglePart> = Vec::new();
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum FunctionCallingMode {
+ Auto,
+ Any,
+ None,
+}
- for msg in &request.messages {
- match msg.role {
- LlmMessageRole::System => {
- for content in &msg.content {
- if let LlmMessageContent::Text(text) = content {
- if !text.is_empty() {
- system_parts
- .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
- }
- }
- }
- }
- LlmMessageRole::User => {
- let mut parts: Vec<GooglePart> = Vec::new();
-
- for content in &msg.content {
- match content {
- LlmMessageContent::Text(text) => {
- if !text.is_empty() {
- parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
- }
- }
- LlmMessageContent::Image(img) => {
- parts.push(GooglePart::InlineData(GoogleInlineDataPart {
- inline_data: GoogleBlob {
- mime_type: "image/png".to_string(),
- data: img.source.clone(),
- },
- }));
- }
- LlmMessageContent::ToolResult(result) => {
- let response_value = match &result.content {
- LlmToolResultContent::Text(t) => {
- serde_json::json!({ "output": t })
- }
- LlmToolResultContent::Image(_) => {
- serde_json::json!({ "output": "Tool responded with an image" })
- }
- };
- parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
- function_response: GoogleFunctionResponse {
- name: result.tool_name.clone(),
- response: response_value,
- },
- }));
- }
- _ => {}
- }
- }
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FunctionDeclaration {
+ pub name: String,
+ pub description: String,
+ pub parameters: serde_json::Value,
+}
- if !parts.is_empty() {
- contents.push(GoogleContent {
- parts,
- role: Some("user".to_string()),
- });
- }
- }
- LlmMessageRole::Assistant => {
- let mut parts: Vec<GooglePart> = Vec::new();
-
- for content in &msg.content {
- match content {
- LlmMessageContent::Text(text) => {
- if !text.is_empty() {
- parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
- }
- }
- LlmMessageContent::ToolUse(tool_use) => {
- let thought_signature =
- tool_use.thought_signature.clone().filter(|s| !s.is_empty());
-
- let args: serde_json::Value =
- serde_json::from_str(&tool_use.input).unwrap_or_default();
-
- parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
- function_call: GoogleFunctionCall {
- name: tool_use.name.clone(),
- args,
- },
- thought_signature,
- }));
- }
- LlmMessageContent::Thinking(thinking) => {
- if let Some(ref signature) = thinking.signature {
- if !signature.is_empty() {
- parts.push(GooglePart::Thought(GoogleThoughtPart {
- thought: true,
- thought_signature: signature.clone(),
- }));
- }
- }
- }
- _ => {}
- }
- }
+#[derive(Debug, Default)]
+pub struct ModelName {
+ pub model_id: String,
+}
- if !parts.is_empty() {
- contents.push(GoogleContent {
- parts,
- role: Some("model".to_string()),
- });
- }
- }
- }
+impl ModelName {
+ pub fn is_empty(&self) -> bool {
+ self.model_id.is_empty()
}
+}
- let system_instruction = if system_parts.is_empty() {
- None
- } else {
- Some(GoogleSystemInstruction {
- parts: system_parts,
- })
- };
+const MODEL_NAME_PREFIX: &str = "models/";
- let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
- None
- } else {
- let declarations: Vec<GoogleFunctionDeclaration> = request
- .tools
- .iter()
- .map(|t| {
- let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
- .unwrap_or(serde_json::Value::Object(Default::default()));
- adapt_schema_for_google(&mut parameters);
- GoogleFunctionDeclaration {
- name: t.name.clone(),
- description: t.description.clone(),
- parameters,
- }
+impl Serialize for ModelName {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
+ }
+}
+
+impl<'de> Deserialize<'de> for ModelName {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ let string = String::deserialize(deserializer)?;
+ if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
+ Ok(Self {
+ model_id: id.to_string(),
})
- .collect();
- Some(vec![GoogleTool {
- function_declarations: declarations,
- }])
- };
-
- let tool_config = request.tool_choice.as_ref().map(|tc| {
- let mode = match tc {
- LlmToolChoice::Auto => "AUTO",
- LlmToolChoice::Any => "ANY",
- LlmToolChoice::None => "NONE",
- };
- GoogleToolConfig {
- function_calling_config: GoogleFunctionCallingConfig {
- mode: mode.to_string(),
- allowed_function_names: None,
- },
+ } else {
+ Err(serde::de::Error::custom(format!(
+ "Expected model name to begin with {}, got: {}",
+ MODEL_NAME_PREFIX, string
+ )))
}
- });
+ }
+}
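+// Round-trip sketch: `ModelName { model_id: "gemini-2.5-flash".into() }`
+// serializes to the JSON string "models/gemini-2.5-flash"; deserializing a
+// string without the "models/" prefix fails with the error above.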
- let thinking_config = if supports_thinking && request.thinking_allowed {
- Some(GoogleThinkingConfig {
- thinking_budget: 8192,
- })
- } else {
- None
- };
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
+pub enum Model {
+ #[serde(
+ rename = "gemini-2.5-flash-lite",
+ alias = "gemini-2.5-flash-lite-preview-06-17",
+ alias = "gemini-2.0-flash-lite-preview"
+ )]
+ Gemini25FlashLite,
+ #[serde(
+ rename = "gemini-2.5-flash",
+ alias = "gemini-2.0-flash-thinking-exp",
+ alias = "gemini-2.5-flash-preview-04-17",
+ alias = "gemini-2.5-flash-preview-05-20",
+ alias = "gemini-2.5-flash-preview-latest",
+ alias = "gemini-2.0-flash"
+ )]
+ #[default]
+ Gemini25Flash,
+ #[serde(
+ rename = "gemini-2.5-pro",
+ alias = "gemini-2.0-pro-exp",
+ alias = "gemini-2.5-pro-preview-latest",
+ alias = "gemini-2.5-pro-exp-03-25",
+ alias = "gemini-2.5-pro-preview-03-25",
+ alias = "gemini-2.5-pro-preview-05-06",
+ alias = "gemini-2.5-pro-preview-06-05"
+ )]
+ Gemini25Pro,
+ #[serde(rename = "gemini-3-pro-preview")]
+ Gemini3Pro,
+ #[serde(rename = "custom")]
+ Custom {
+ name: String,
+ /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
+ display_name: Option<String>,
+ max_tokens: u64,
+ #[serde(default)]
+ mode: GoogleModelMode,
+ },
+}
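+// Sketch of how a custom model might round-trip through serde (field names
+// come from the variant above; the values are placeholders):
+//
+//     {"custom": {"name": "my-tuned-gemini", "max_tokens": 200000}}
+//
+// `display_name` may be omitted (it deserializes to `None`), and `mode` falls
+// back to `GoogleModelMode::default()` via `#[serde(default)]`.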
- let generation_config = Some(GoogleGenerationConfig {
- candidate_count: Some(1),
- stop_sequences: if request.stop_sequences.is_empty() {
- None
- } else {
- Some(request.stop_sequences.clone())
- },
- max_output_tokens: None,
- temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
- thinking_config,
- });
-
- Ok((
- GoogleRequest {
- contents,
- system_instruction,
- generation_config,
- tools,
- tool_config,
- },
- real_model_id.to_string(),
- ))
-}
-
-fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
- let trimmed = line.trim();
- if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," {
- return None;
+impl Model {
+ pub fn default_fast() -> Self {
+ Self::Gemini25FlashLite
}
- let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed);
- let json_str = json_str.trim_start_matches(',').trim();
-
- if json_str.is_empty() {
- return None;
+ pub fn id(&self) -> &str {
+ match self {
+ Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
+ Self::Gemini25Flash => "gemini-2.5-flash",
+ Self::Gemini25Pro => "gemini-2.5-pro",
+ Self::Gemini3Pro => "gemini-3-pro-preview",
+ Self::Custom { name, .. } => name,
+ }
+ }
+    /// The ID used in request paths; currently identical to [`Model::id`].
+    pub fn request_id(&self) -> &str {
+        self.id()
}
- serde_json::from_str(json_str).ok()
-}
+ pub fn display_name(&self) -> &str {
+ match self {
+ Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite",
+ Self::Gemini25Flash => "Gemini 2.5 Flash",
+ Self::Gemini25Pro => "Gemini 2.5 Pro",
+ Self::Gemini3Pro => "Gemini 3 Pro",
+ Self::Custom {
+ name, display_name, ..
+ } => display_name.as_ref().unwrap_or(name),
+ }
+ }
-impl zed::Extension for GoogleAiProvider {
- fn new() -> Self {
- Self {
- streams: Mutex::new(HashMap::new()),
- next_stream_id: Mutex::new(0),
+ pub fn max_token_count(&self) -> u64 {
+ match self {
+ Self::Gemini25FlashLite => 1_048_576,
+ Self::Gemini25Flash => 1_048_576,
+ Self::Gemini25Pro => 1_048_576,
+ Self::Gemini3Pro => 1_048_576,
+ Self::Custom { max_tokens, .. } => *max_tokens,
}
}
- fn llm_providers(&self) -> Vec<LlmProviderInfo> {
- vec![LlmProviderInfo {
- id: "google-ai".into(),
- name: "Google AI".into(),
- icon: Some("icons/google-ai.svg".into()),
- }]
+ pub fn max_output_tokens(&self) -> Option<u64> {
+ match self {
+            Self::Gemini25FlashLite => Some(65_536),
+            Self::Gemini25Flash => Some(65_536),
+            Self::Gemini25Pro => Some(65_536),
+            Self::Gemini3Pro => Some(65_536),
+            Self::Custom { .. } => None,
+ }
}
- fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
- Ok(MODELS
- .iter()
- .map(|m| LlmModelInfo {
- id: m.display_name.to_string(),
- name: m.display_name.to_string(),
- max_token_count: m.max_tokens,
- max_output_tokens: m.max_output_tokens,
- capabilities: LlmModelCapabilities {
- supports_images: m.supports_images,
- supports_tools: true,
- supports_tool_choice_auto: true,
- supports_tool_choice_any: true,
- supports_tool_choice_none: true,
- supports_thinking: m.supports_thinking,
- tool_input_format: LlmToolInputFormat::JsonSchema,
- },
- is_default: m.is_default,
- is_default_fast: m.is_default_fast,
- })
- .collect())
+ pub fn supports_tools(&self) -> bool {
+ true
}
- fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
- llm_get_credential("google-ai").is_some()
+ pub fn supports_images(&self) -> bool {
+ true
}
- fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
- Some(
- "[Create an API key](https://aistudio.google.com/apikey) to use Google AI as your LLM provider.".to_string(),
- )
+ pub fn mode(&self) -> GoogleModelMode {
+ match self {
+ Self::Gemini25FlashLite
+ | Self::Gemini25Flash
+ | Self::Gemini25Pro
+ | Self::Gemini3Pro => {
+ GoogleModelMode::Thinking {
+                    // These models default to "auto" thinking; leaving the budget
+                    // unset preserves that behavior while still marking the model
+                    // as thinking-capable.
+ budget_tokens: None,
+ }
+ }
+ Self::Custom { mode, .. } => *mode,
+ }
}
+}
- fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
- llm_delete_credential("google-ai")
+impl std::fmt::Display for Model {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.id())
}
+}
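+// Note: `model.to_string()` therefore yields the wire ID (e.g. "gemini-2.5-pro"
+// for built-in models, or the custom `name`), matching the serde renames above.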
- fn llm_stream_completion_start(
- &mut self,
- _provider_id: &str,
- model_id: &str,
- request: &LlmCompletionRequest,
- ) -> Result<String, String> {
- let api_key = llm_get_credential("google-ai").ok_or_else(|| {
- "No API key configured. Please add your Google AI API key in settings.".to_string()
- })?;
-
- let (google_request, real_model_id) = convert_request(model_id, request)?;
-
- let body = serde_json::to_vec(&google_request)
- .map_err(|e| format!("Failed to serialize request: {}", e))?;
-
- let url = format!(
- "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
- real_model_id, api_key
- );
-
- let http_request = HttpRequest {
- method: HttpMethod::Post,
- url,
- headers: vec![("Content-Type".to_string(), "application/json".to_string())],
- body: Some(body),
- redirect_policy: RedirectPolicy::FollowAll,
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use serde_json::json;
+
+ #[test]
+ fn test_function_call_part_with_signature_serializes_correctly() {
+ let part = FunctionCallPart {
+ function_call: FunctionCall {
+ name: "test_function".to_string(),
+ args: json!({"arg": "value"}),
+ },
+ thought_signature: Some("test_signature".to_string()),
};
- let response_stream = http_request
- .fetch_stream()
- .map_err(|e| format!("HTTP request failed: {}", e))?;
+ let serialized = serde_json::to_value(&part).unwrap();
- let stream_id = {
- let mut id_counter = self.next_stream_id.lock().unwrap();
- let id = format!("google-ai-stream-{}", *id_counter);
- *id_counter += 1;
- id
- };
+ assert_eq!(serialized["functionCall"]["name"], "test_function");
+ assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
+ assert_eq!(serialized["thoughtSignature"], "test_signature");
+ }
- self.streams.lock().unwrap().insert(
- stream_id.clone(),
- StreamState {
- response_stream: Some(response_stream),
- buffer: String::new(),
- started: false,
- stop_reason: None,
- wants_tool_use: false,
+ #[test]
+ fn test_function_call_part_without_signature_omits_field() {
+ let part = FunctionCallPart {
+ function_call: FunctionCall {
+ name: "test_function".to_string(),
+ args: json!({"arg": "value"}),
},
- );
+ thought_signature: None,
+ };
+
+ let serialized = serde_json::to_value(&part).unwrap();
- Ok(stream_id)
+ assert_eq!(serialized["functionCall"]["name"], "test_function");
+ assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
+ // thoughtSignature field should not be present when None
+ assert!(serialized.get("thoughtSignature").is_none());
}
- fn llm_stream_completion_next(
- &mut self,
- stream_id: &str,
- ) -> Result<Option<LlmCompletionEvent>, String> {
- let mut streams = self.streams.lock().unwrap();
- let state = streams
- .get_mut(stream_id)
- .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
-
- if !state.started {
- state.started = true;
- return Ok(Some(LlmCompletionEvent::Started));
- }
+ #[test]
+ fn test_function_call_part_deserializes_with_signature() {
+ let json = json!({
+ "functionCall": {
+ "name": "test_function",
+ "args": {"arg": "value"}
+ },
+ "thoughtSignature": "test_signature"
+ });
- let response_stream = state
- .response_stream
- .as_mut()
- .ok_or_else(|| "Stream already closed".to_string())?;
-
- loop {
- if let Some(newline_pos) = state.buffer.find('\n') {
- let line = state.buffer[..newline_pos].to_string();
- state.buffer = state.buffer[newline_pos + 1..].to_string();
-
- if let Some(response) = parse_stream_line(&line) {
- for candidate in response.candidates {
- if let Some(finish_reason) = &candidate.finish_reason {
- state.stop_reason = Some(match finish_reason.as_str() {
- "STOP" => {
- if state.wants_tool_use {
- LlmStopReason::ToolUse
- } else {
- LlmStopReason::EndTurn
- }
- }
- "MAX_TOKENS" => LlmStopReason::MaxTokens,
- "SAFETY" => LlmStopReason::Refusal,
- _ => LlmStopReason::EndTurn,
- });
- }
+ let part: FunctionCallPart = serde_json::from_value(json).unwrap();
- if let Some(content) = candidate.content {
- for part in content.parts {
- match part {
- GooglePart::Text(text_part) => {
- if !text_part.text.is_empty() {
- return Ok(Some(LlmCompletionEvent::Text(
- text_part.text,
- )));
- }
- }
- GooglePart::FunctionCall(fc_part) => {
- state.wants_tool_use = true;
- let next_tool_id =
- TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
- let id = format!(
- "{}-{}",
- fc_part.function_call.name, next_tool_id
- );
-
- let thought_signature =
- fc_part.thought_signature.filter(|s| !s.is_empty());
-
- return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
- id,
- name: fc_part.function_call.name,
- input: fc_part.function_call.args.to_string(),
- is_input_complete: true,
- thought_signature,
- })));
- }
- GooglePart::Thought(thought_part) => {
- return Ok(Some(LlmCompletionEvent::Thinking(
- LlmThinkingContent {
- text: "(Encrypted thought)".to_string(),
- signature: Some(thought_part.thought_signature),
- },
- )));
- }
- _ => {}
- }
- }
- }
- }
-
- if let Some(usage) = response.usage_metadata {
- return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
- input_tokens: usage.prompt_token_count,
- output_tokens: usage.candidates_token_count,
- cache_creation_input_tokens: None,
- cache_read_input_tokens: None,
- })));
- }
- }
+ assert_eq!(part.function_call.name, "test_function");
+ assert_eq!(part.thought_signature, Some("test_signature".to_string()));
+ }
- continue;
+ #[test]
+ fn test_function_call_part_deserializes_without_signature() {
+ let json = json!({
+ "functionCall": {
+ "name": "test_function",
+ "args": {"arg": "value"}
}
+ });
- match response_stream.next_chunk() {
- Ok(Some(chunk)) => {
- let text = String::from_utf8_lossy(&chunk);
- state.buffer.push_str(&text);
- }
- Ok(None) => {
- // Stream ended - check if we have a stop reason
- if let Some(stop_reason) = state.stop_reason.take() {
- return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
- }
+ let part: FunctionCallPart = serde_json::from_value(json).unwrap();
- // No stop reason - this is unexpected. Check if buffer contains error info
- let mut error_msg = String::from("Stream ended unexpectedly.");
+ assert_eq!(part.function_call.name, "test_function");
+ assert_eq!(part.thought_signature, None);
+ }
- // Try to parse remaining buffer as potential error response
- if !state.buffer.is_empty() {
- error_msg.push_str(&format!(
- "\nRemaining buffer: {}",
- &state.buffer[..state.buffer.len().min(1000)]
- ));
- }
+ #[test]
+ fn test_function_call_part_round_trip() {
+ let original = FunctionCallPart {
+ function_call: FunctionCall {
+ name: "test_function".to_string(),
+ args: json!({"arg": "value", "nested": {"key": "val"}}),
+ },
+ thought_signature: Some("round_trip_signature".to_string()),
+ };
- return Err(error_msg);
- }
- Err(e) => {
- return Err(format!("Stream error: {}", e));
- }
- }
- }
+ let serialized = serde_json::to_value(&original).unwrap();
+ let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap();
+
+ assert_eq!(deserialized.function_call.name, original.function_call.name);
+ assert_eq!(deserialized.function_call.args, original.function_call.args);
+ assert_eq!(deserialized.thought_signature, original.thought_signature);
}
- fn llm_stream_completion_close(&mut self, stream_id: &str) {
- self.streams.lock().unwrap().remove(stream_id);
+ #[test]
+ fn test_function_call_part_with_empty_signature_serializes() {
+ let part = FunctionCallPart {
+ function_call: FunctionCall {
+ name: "test_function".to_string(),
+ args: json!({"arg": "value"}),
+ },
+ thought_signature: Some("".to_string()),
+ };
+
+ let serialized = serde_json::to_value(&part).unwrap();
+
+ // Empty string should still be serialized (normalization happens at a higher level)
+ assert_eq!(serialized["thoughtSignature"], "");
}
}
-
-zed::register_extension!(GoogleAiProvider);