@@ -1,136 +1,579 @@
+use std::collections::HashMap;
-use std::mem;
-
-use anyhow::{Result, anyhow, bail};
-use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
-pub use settings::ModelMode as GoogleModelMode;
+use zed_extension_api::{
+ self as zed, http_client::HttpMethod, http_client::HttpRequest, llm_get_env_var,
+ LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmMessageContent,
+ LlmMessageRole, LlmModelCapabilities, LlmModelInfo, LlmProviderInfo, LlmStopReason,
+ LlmThinkingContent, LlmTokenUsage, LlmToolInputFormat, LlmToolUse,
+};
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
-pub async fn stream_generate_content(
- client: &dyn HttpClient,
- api_url: &str,
- api_key: &str,
- mut request: GenerateContentRequest,
-) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
- let api_key = api_key.trim();
- validate_generate_content_request(&request)?;
-
- // The `model` field is emptied as it is provided as a path parameter.
- let model_id = mem::take(&mut request.model.model_id);
-
- let uri =
- format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
-
- let request_builder = HttpRequest::builder()
- .method(Method::POST)
- .uri(uri)
- .header("Content-Type", "application/json");
-
- let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
- let mut response = client.send(request).await?;
- if response.status().is_success() {
- let reader = BufReader::new(response.into_body());
- Ok(reader
- .lines()
- .filter_map(|line| async move {
- match line {
- Ok(line) => {
- if let Some(line) = line.strip_prefix("data: ") {
- match serde_json::from_str(line) {
- Ok(response) => Some(Ok(response)),
- Err(error) => Some(Err(anyhow!(format!(
- "Error parsing JSON: {error:?}\n{line:?}"
- )))),
- }
- } else {
- None
- }
- }
- Err(error) => Some(Err(anyhow!(error))),
- }
- })
- .boxed())
- } else {
- let mut text = String::new();
- response.body_mut().read_to_string(&mut text).await?;
- Err(anyhow!(
- "error during streamGenerateContent, status code: {:?}, body: {}",
- response.status(),
- text
- ))
- }
-}
+fn stream_generate_content(
+ model_id: &str,
+ request: &LlmCompletionRequest,
+ streams: &mut HashMap<String, StreamState>,
+ next_stream_id: &mut u64,
+) -> Result<String, String> {
+ let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;
-pub async fn count_tokens(
- client: &dyn HttpClient,
- api_url: &str,
- api_key: &str,
- request: CountTokensRequest,
-) -> Result<CountTokensResponse> {
- validate_generate_content_request(&request.generate_content_request)?;
+ let generate_content_request = build_generate_content_request(model_id, request)?;
+ validate_generate_content_request(&generate_content_request)?;
let uri = format!(
- "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
- model_id = &request.generate_content_request.model.model_id,
+ "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
+ API_URL, model_id, api_key
+ );
+
+ let body = serde_json::to_vec(&generate_content_request)
+ .map_err(|e| format!("Failed to serialize request: {}", e))?;
+
+ let http_request = HttpRequest::builder()
+ .method(HttpMethod::Post)
+ .url(&uri)
+ .header("Content-Type", "application/json")
+ .body(body)
+ .build()?;
+
+ let response_stream = http_request.fetch_stream()?;
+
+ let stream_id = format!("stream-{}", *next_stream_id);
+ *next_stream_id += 1;
+
+ streams.insert(
+ stream_id.clone(),
+ StreamState {
+ response_stream,
+ buffer: String::new(),
+ usage: None,
+ },
);
- let request = serde_json::to_string(&request)?;
- let request_builder = HttpRequest::builder()
- .method(Method::POST)
- .uri(&uri)
- .header("Content-Type", "application/json");
- let http_request = request_builder.body(AsyncBody::from(request))?;
-
- let mut response = client.send(http_request).await?;
- let mut text = String::new();
- response.body_mut().read_to_string(&mut text).await?;
- anyhow::ensure!(
- response.status().is_success(),
- "error during countTokens, status code: {:?}, body: {}",
- response.status(),
- text
+ Ok(stream_id)
+}
+
+fn count_tokens(model_id: &str, request: &LlmCompletionRequest) -> Result<u64, String> {
+ let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?;
+
+ let generate_content_request = build_generate_content_request(model_id, request)?;
+ validate_generate_content_request(&generate_content_request)?;
+ let count_request = CountTokensRequest {
+ generate_content_request,
+ };
+
+ let uri = format!(
+ "{}/v1beta/models/{}:countTokens?key={}",
+ API_URL, model_id, api_key
);
- Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
+
+ let body = serde_json::to_vec(&count_request)
+ .map_err(|e| format!("Failed to serialize request: {}", e))?;
+
+ let http_request = HttpRequest::builder()
+ .method(HttpMethod::Post)
+ .url(&uri)
+ .header("Content-Type", "application/json")
+ .body(body)
+ .build()?;
+
+ let response = http_request.fetch()?;
+ let response_body: CountTokensResponse = serde_json::from_slice(&response.body)
+ .map_err(|e| format!("Failed to parse response: {}", e))?;
+
+ Ok(response_body.total_tokens)
}
-pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
+fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<(), String> {
if request.model.is_empty() {
- bail!("Model must be specified");
+ return Err("Model must be specified".to_string());
}
if request.contents.is_empty() {
- bail!("Request must contain at least one content item");
+ return Err("Request must contain at least one content item".to_string());
}
if let Some(user_content) = request
.contents
.iter()
.find(|content| content.role == Role::User)
- && user_content.parts.is_empty()
{
- bail!("User content must contain at least one part");
+ if user_content.parts.is_empty() {
+ return Err("User content must contain at least one part".to_string());
+ }
}
Ok(())
}
-#[derive(Debug, Serialize, Deserialize)]
-pub enum Task {
- #[serde(rename = "generateContent")]
- GenerateContent,
- #[serde(rename = "streamGenerateContent")]
- StreamGenerateContent,
- #[serde(rename = "countTokens")]
- CountTokens,
- #[serde(rename = "embedContent")]
- EmbedContent,
- #[serde(rename = "batchEmbedContents")]
- BatchEmbedContents,
+// Extension implementation
+
+const PROVIDER_ID: &str = "google-ai";
+const PROVIDER_NAME: &str = "Google AI";
+
+struct GoogleAiExtension {
+ streams: HashMap<String, StreamState>,
+ next_stream_id: u64,
+}
+
+struct StreamState {
+ response_stream: zed::http_client::HttpResponseStream,
+ buffer: String,
+ usage: Option<UsageMetadata>,
+}
+
+impl zed::Extension for GoogleAiExtension {
+ fn new() -> Self {
+ Self {
+ streams: HashMap::new(),
+ next_stream_id: 0,
+ }
+ }
+
+ fn llm_providers(&self) -> Vec<LlmProviderInfo> {
+ vec![LlmProviderInfo {
+ id: PROVIDER_ID.to_string(),
+ name: PROVIDER_NAME.to_string(),
+ icon: Some("icons/google-ai.svg".to_string()),
+ }]
+ }
+
+ fn llm_provider_models(&self, provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
+ if provider_id != PROVIDER_ID {
+ return Err(format!("Unknown provider: {}", provider_id));
+ }
+ Ok(get_models())
+ }
+
+ fn llm_provider_settings_markdown(&self, provider_id: &str) -> Option<String> {
+ if provider_id != PROVIDER_ID {
+ return None;
+ }
+
+ Some(
+ r#"## Google AI Setup
+
+To use Google AI models in Zed, you need a Gemini API key.
+
+1. Go to [Google AI Studio](https://aistudio.google.com/apikey)
+2. Create or select a project
+3. Generate an API key
+4. Set the `GEMINI_API_KEY` or `GOOGLE_AI_API_KEY` environment variable
+
+You can set this in your shell profile or use a `.envrc` file with [direnv](https://direnv.net/).
+"#
+ .to_string(),
+ )
+ }
+
+ fn llm_provider_is_authenticated(&self, provider_id: &str) -> bool {
+ if provider_id != PROVIDER_ID {
+ return false;
+ }
+ get_api_key().is_some()
+ }
+
+ fn llm_provider_reset_credentials(&mut self, provider_id: &str) -> Result<(), String> {
+ if provider_id != PROVIDER_ID {
+ return Err(format!("Unknown provider: {}", provider_id));
+ }
+ Ok(())
+ }
+
+ fn llm_count_tokens(
+ &self,
+ provider_id: &str,
+ model_id: &str,
+ request: &LlmCompletionRequest,
+ ) -> Result<u64, String> {
+ if provider_id != PROVIDER_ID {
+ return Err(format!("Unknown provider: {}", provider_id));
+ }
+ count_tokens(model_id, request)
+ }
+
+ fn llm_stream_completion_start(
+ &mut self,
+ provider_id: &str,
+ model_id: &str,
+ request: &LlmCompletionRequest,
+ ) -> Result<String, String> {
+ if provider_id != PROVIDER_ID {
+ return Err(format!("Unknown provider: {}", provider_id));
+ }
+ stream_generate_content(model_id, request, &mut self.streams, &mut self.next_stream_id)
+ }
+
+ fn llm_stream_completion_next(
+ &mut self,
+ stream_id: &str,
+ ) -> Result<Option<LlmCompletionEvent>, String> {
+ stream_generate_content_next(stream_id, &mut self.streams)
+ }
+
+ fn llm_stream_completion_close(&mut self, stream_id: &str) {
+ self.streams.remove(stream_id);
+ }
+
+ fn llm_cache_configuration(
+ &self,
+ provider_id: &str,
+ _model_id: &str,
+ ) -> Option<LlmCacheConfiguration> {
+ if provider_id != PROVIDER_ID {
+ return None;
+ }
+
+ Some(LlmCacheConfiguration {
+ max_cache_anchors: 1,
+ should_cache_tool_definitions: false,
+ min_total_token_count: 32768,
+ })
+ }
+}
+
+zed::register_extension!(GoogleAiExtension);
+
+// Helper functions
+
+fn get_api_key() -> Option<String> {
+ llm_get_env_var("GEMINI_API_KEY").or_else(|| llm_get_env_var("GOOGLE_AI_API_KEY"))
+}
+
+fn get_models() -> Vec<LlmModelInfo> {
+ vec![
+ LlmModelInfo {
+ id: "gemini-2.5-flash-lite".to_string(),
+ name: "Gemini 2.5 Flash-Lite".to_string(),
+ max_token_count: 1_048_576,
+ max_output_tokens: Some(65_536),
+ capabilities: LlmModelCapabilities {
+ supports_images: true,
+ supports_tools: true,
+ supports_tool_choice_auto: true,
+ supports_tool_choice_any: true,
+ supports_tool_choice_none: true,
+ supports_thinking: true,
+ tool_input_format: LlmToolInputFormat::JsonSchema,
+ },
+ is_default: false,
+ is_default_fast: true,
+ },
+ LlmModelInfo {
+ id: "gemini-2.5-flash".to_string(),
+ name: "Gemini 2.5 Flash".to_string(),
+ max_token_count: 1_048_576,
+ max_output_tokens: Some(65_536),
+ capabilities: LlmModelCapabilities {
+ supports_images: true,
+ supports_tools: true,
+ supports_tool_choice_auto: true,
+ supports_tool_choice_any: true,
+ supports_tool_choice_none: true,
+ supports_thinking: true,
+ tool_input_format: LlmToolInputFormat::JsonSchema,
+ },
+ is_default: true,
+ is_default_fast: false,
+ },
+ LlmModelInfo {
+ id: "gemini-2.5-pro".to_string(),
+ name: "Gemini 2.5 Pro".to_string(),
+ max_token_count: 1_048_576,
+ max_output_tokens: Some(65_536),
+ capabilities: LlmModelCapabilities {
+ supports_images: true,
+ supports_tools: true,
+ supports_tool_choice_auto: true,
+ supports_tool_choice_any: true,
+ supports_tool_choice_none: true,
+ supports_thinking: true,
+ tool_input_format: LlmToolInputFormat::JsonSchema,
+ },
+ is_default: false,
+ is_default_fast: false,
+ },
+ LlmModelInfo {
+ id: "gemini-3-pro-preview".to_string(),
+ name: "Gemini 3 Pro".to_string(),
+ max_token_count: 1_048_576,
+ max_output_tokens: Some(65_536),
+ capabilities: LlmModelCapabilities {
+ supports_images: true,
+ supports_tools: true,
+ supports_tool_choice_auto: true,
+ supports_tool_choice_any: true,
+ supports_tool_choice_none: true,
+ supports_thinking: true,
+ tool_input_format: LlmToolInputFormat::JsonSchema,
+ },
+ is_default: false,
+ is_default_fast: false,
+ },
+ LlmModelInfo {
+ id: "gemini-3-flash-preview".to_string(),
+ name: "Gemini 3 Flash".to_string(),
+ max_token_count: 1_048_576,
+ max_output_tokens: Some(65_536),
+ capabilities: LlmModelCapabilities {
+ supports_images: true,
+ supports_tools: true,
+ supports_tool_choice_auto: true,
+ supports_tool_choice_any: true,
+ supports_tool_choice_none: true,
+ supports_thinking: true,
+ tool_input_format: LlmToolInputFormat::JsonSchema,
+ },
+ is_default: false,
+ is_default_fast: false,
+ },
+ ]
+}
+
+fn stream_generate_content_next(
+ stream_id: &str,
+ streams: &mut HashMap<String, StreamState>,
+) -> Result<Option<LlmCompletionEvent>, String> {
+ let state = streams
+ .get_mut(stream_id)
+ .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
+
+ loop {
+ if let Some(newline_pos) = state.buffer.find('\n') {
+ let line = state.buffer[..newline_pos].to_string();
+ state.buffer = state.buffer[newline_pos + 1..].to_string();
+
+ if let Some(data) = line.strip_prefix("data: ") {
+ if data.trim().is_empty() {
+ continue;
+ }
+
+ let response: GenerateContentResponse = serde_json::from_str(data)
+ .map_err(|e| format!("Failed to parse SSE data: {} - {}", e, data))?;
+
+ if let Some(usage) = response.usage_metadata {
+ state.usage = Some(usage);
+ }
+
+ if let Some(candidates) = response.candidates {
+ for candidate in candidates {
+ for part in candidate.content.parts {
+ match part {
+ Part::TextPart(text_part) => {
+ return Ok(Some(LlmCompletionEvent::Text(text_part.text)));
+ }
+ Part::ThoughtPart(thought_part) => {
+ return Ok(Some(LlmCompletionEvent::Thinking(
+ LlmThinkingContent {
+ text: String::new(),
+ signature: Some(thought_part.thought_signature),
+ },
+ )));
+ }
+ Part::FunctionCallPart(fc_part) => {
+ return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
+ id: fc_part.function_call.name.clone(),
+ name: fc_part.function_call.name,
+ input: serde_json::to_string(&fc_part.function_call.args)
+ .unwrap_or_default(),
+ is_input_complete: true,
+ thought_signature: fc_part.thought_signature,
+ })));
+ }
+ _ => {}
+ }
+ }
+
+ if let Some(finish_reason) = candidate.finish_reason {
+ let stop_reason = match finish_reason.as_str() {
+ "STOP" => LlmStopReason::EndTurn,
+ "MAX_TOKENS" => LlmStopReason::MaxTokens,
+ "TOOL_USE" | "FUNCTION_CALL" => LlmStopReason::ToolUse,
+ "SAFETY" | "RECITATION" | "OTHER" => LlmStopReason::Refusal,
+ _ => LlmStopReason::EndTurn,
+ };
+
+ if let Some(usage) = state.usage.take() {
+ return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
+ input_tokens: usage.prompt_token_count.unwrap_or(0),
+ output_tokens: usage.candidates_token_count.unwrap_or(0),
+ cache_creation_input_tokens: None,
+ cache_read_input_tokens: usage.cached_content_token_count,
+ })));
+ }
+
+ return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
+ }
+ }
+ }
+ }
+
+ continue;
+ }
+
+ match state.response_stream.next_chunk() {
+ Ok(Some(chunk)) => {
+ let chunk_str = String::from_utf8_lossy(&chunk);
+ state.buffer.push_str(&chunk_str);
+ }
+ Ok(None) => {
+ streams.remove(stream_id);
+ return Ok(None);
+ }
+ Err(e) => {
+ streams.remove(stream_id);
+ return Err(e);
+ }
+ }
+ }
+}
+
+fn build_generate_content_request(
+ model_id: &str,
+ request: &LlmCompletionRequest,
+) -> Result<GenerateContentRequest, String> {
+ let mut contents: Vec<Content> = Vec::new();
+ let mut system_instruction: Option<SystemInstruction> = None;
+
+ for message in &request.messages {
+ match message.role {
+ LlmMessageRole::System => {
+ let parts = convert_content_to_parts(&message.content)?;
+ system_instruction = Some(SystemInstruction { parts });
+ }
+ LlmMessageRole::User | LlmMessageRole::Assistant => {
+ let role = match message.role {
+ LlmMessageRole::User => Role::User,
+ LlmMessageRole::Assistant => Role::Model,
+ _ => continue,
+ };
+ let parts = convert_content_to_parts(&message.content)?;
+ contents.push(Content { parts, role });
+ }
+ }
+ }
+
+ let tools = if !request.tools.is_empty() {
+ Some(vec![Tool {
+ function_declarations: request
+ .tools
+ .iter()
+ .map(|t| FunctionDeclaration {
+ name: t.name.clone(),
+ description: t.description.clone(),
+ parameters: serde_json::from_str(&t.input_schema).unwrap_or_default(),
+ })
+ .collect(),
+ }])
+ } else {
+ None
+ };
+
+ let tool_config = request.tool_choice.as_ref().map(|choice| {
+ let mode = match choice {
+ zed::LlmToolChoice::Auto => FunctionCallingMode::Auto,
+ zed::LlmToolChoice::Any => FunctionCallingMode::Any,
+ zed::LlmToolChoice::None => FunctionCallingMode::None,
+ };
+ ToolConfig {
+ function_calling_config: FunctionCallingConfig {
+ mode,
+ allowed_function_names: None,
+ },
+ }
+ });
+
+ let generation_config = Some(GenerationConfig {
+ candidate_count: Some(1),
+ stop_sequences: if request.stop_sequences.is_empty() {
+ None
+ } else {
+ Some(request.stop_sequences.clone())
+ },
+ max_output_tokens: request.max_tokens.map(|t| t as usize),
+ temperature: request.temperature.map(|t| t as f64),
+ top_p: None,
+ top_k: None,
+ thinking_config: if request.thinking_allowed {
+ Some(ThinkingConfig {
+ thinking_budget: 8192,
+ })
+ } else {
+ None
+ },
+ });
+
+ Ok(GenerateContentRequest {
+ model: ModelName {
+ model_id: model_id.to_string(),
+ },
+ contents,
+ system_instruction,
+ generation_config,
+ safety_settings: None,
+ tools,
+ tool_config,
+ })
+}
+
+fn convert_content_to_parts(content: &[LlmMessageContent]) -> Result<Vec<Part>, String> {
+ let mut parts = Vec::new();
+
+ for item in content {
+ match item {
+ LlmMessageContent::Text(text) => {
+ parts.push(Part::TextPart(TextPart { text: text.clone() }));
+ }
+ LlmMessageContent::Image(image) => {
+ parts.push(Part::InlineDataPart(InlineDataPart {
+ inline_data: GenerativeContentBlob {
+ mime_type: "image/png".to_string(),
+ data: image.source.clone(),
+ },
+ }));
+ }
+ LlmMessageContent::ToolUse(tool_use) => {
+ parts.push(Part::FunctionCallPart(FunctionCallPart {
+ function_call: FunctionCall {
+ name: tool_use.name.clone(),
+ args: serde_json::from_str(&tool_use.input).unwrap_or_default(),
+ },
+ thought_signature: tool_use.thought_signature.clone(),
+ }));
+ }
+ LlmMessageContent::ToolResult(tool_result) => {
+ let response_value = match &tool_result.content {
+ zed::LlmToolResultContent::Text(text) => {
+ serde_json::json!({ "result": text })
+ }
+ zed::LlmToolResultContent::Image(_) => {
+ serde_json::json!({ "error": "Image results not supported" })
+ }
+ };
+ parts.push(Part::FunctionResponsePart(FunctionResponsePart {
+ function_response: FunctionResponse {
+ name: tool_result.tool_name.clone(),
+ response: response_value,
+ },
+ }));
+ }
+ LlmMessageContent::Thinking(thinking) => {
+ if let Some(signature) = &thinking.signature {
+ parts.push(Part::ThoughtPart(ThoughtPart {
+ thought: true,
+ thought_signature: signature.clone(),
+ }));
+ }
+ }
+ LlmMessageContent::RedactedThinking(_) => {}
+ }
+ }
+
+ Ok(parts)
}
+// Data structures for Google AI API
+
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
@@ -481,238 +924,3 @@ impl<'de> Deserialize<'de> for ModelName {
}
}
}
-
-#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
-pub enum Model {
- #[serde(
- rename = "gemini-2.5-flash-lite",
- alias = "gemini-2.5-flash-lite-preview-06-17",
- alias = "gemini-2.0-flash-lite-preview"
- )]
- Gemini25FlashLite,
- #[serde(
- rename = "gemini-2.5-flash",
- alias = "gemini-2.0-flash-thinking-exp",
- alias = "gemini-2.5-flash-preview-04-17",
- alias = "gemini-2.5-flash-preview-05-20",
- alias = "gemini-2.5-flash-preview-latest",
- alias = "gemini-2.0-flash"
- )]
- #[default]
- Gemini25Flash,
- #[serde(
- rename = "gemini-2.5-pro",
- alias = "gemini-2.0-pro-exp",
- alias = "gemini-2.5-pro-preview-latest",
- alias = "gemini-2.5-pro-exp-03-25",
- alias = "gemini-2.5-pro-preview-03-25",
- alias = "gemini-2.5-pro-preview-05-06",
- alias = "gemini-2.5-pro-preview-06-05"
- )]
- Gemini25Pro,
- #[serde(rename = "gemini-3-pro-preview")]
- Gemini3Pro,
- #[serde(rename = "custom")]
- Custom {
- name: String,
- /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
- display_name: Option<String>,
- max_tokens: u64,
- #[serde(default)]
- mode: GoogleModelMode,
- },
-}
-
-impl Model {
- pub fn default_fast() -> Self {
- Self::Gemini25FlashLite
- }
-
- pub fn id(&self) -> &str {
- match self {
- Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
- Self::Gemini25Flash => "gemini-2.5-flash",
- Self::Gemini25Pro => "gemini-2.5-pro",
- Self::Gemini3Pro => "gemini-3-pro-preview",
- Self::Custom { name, .. } => name,
- }
- }
- pub fn request_id(&self) -> &str {
- match self {
- Self::Gemini25FlashLite => "gemini-2.5-flash-lite",
- Self::Gemini25Flash => "gemini-2.5-flash",
- Self::Gemini25Pro => "gemini-2.5-pro",
- Self::Gemini3Pro => "gemini-3-pro-preview",
- Self::Custom { name, .. } => name,
- }
- }
-
- pub fn display_name(&self) -> &str {
- match self {
- Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite",
- Self::Gemini25Flash => "Gemini 2.5 Flash",
- Self::Gemini25Pro => "Gemini 2.5 Pro",
- Self::Gemini3Pro => "Gemini 3 Pro",
- Self::Custom {
- name, display_name, ..
- } => display_name.as_ref().unwrap_or(name),
- }
- }
-
- pub fn max_token_count(&self) -> u64 {
- match self {
- Self::Gemini25FlashLite => 1_048_576,
- Self::Gemini25Flash => 1_048_576,
- Self::Gemini25Pro => 1_048_576,
- Self::Gemini3Pro => 1_048_576,
- Self::Custom { max_tokens, .. } => *max_tokens,
- }
- }
-
- pub fn max_output_tokens(&self) -> Option<u64> {
- match self {
- Model::Gemini25FlashLite => Some(65_536),
- Model::Gemini25Flash => Some(65_536),
- Model::Gemini25Pro => Some(65_536),
- Model::Gemini3Pro => Some(65_536),
- Model::Custom { .. } => None,
- }
- }
-
- pub fn supports_tools(&self) -> bool {
- true
- }
-
- pub fn supports_images(&self) -> bool {
- true
- }
-
- pub fn mode(&self) -> GoogleModelMode {
- match self {
- Self::Gemini25FlashLite
- | Self::Gemini25Flash
- | Self::Gemini25Pro
- | Self::Gemini3Pro => {
- GoogleModelMode::Thinking {
- // By default these models are set to "auto", so we preserve that behavior
- // but indicate they are capable of thinking mode
- budget_tokens: None,
- }
- }
- Self::Custom { mode, .. } => *mode,
- }
- }
-}
-
-impl std::fmt::Display for Model {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.id())
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use serde_json::json;
-
- #[test]
- fn test_function_call_part_with_signature_serializes_correctly() {
- let part = FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
- },
- thought_signature: Some("test_signature".to_string()),
- };
-
- let serialized = serde_json::to_value(&part).unwrap();
-
- assert_eq!(serialized["functionCall"]["name"], "test_function");
- assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
- assert_eq!(serialized["thoughtSignature"], "test_signature");
- }
-
- #[test]
- fn test_function_call_part_without_signature_omits_field() {
- let part = FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
- },
- thought_signature: None,
- };
-
- let serialized = serde_json::to_value(&part).unwrap();
-
- assert_eq!(serialized["functionCall"]["name"], "test_function");
- assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
- // thoughtSignature field should not be present when None
- assert!(serialized.get("thoughtSignature").is_none());
- }
-
- #[test]
- fn test_function_call_part_deserializes_with_signature() {
- let json = json!({
- "functionCall": {
- "name": "test_function",
- "args": {"arg": "value"}
- },
- "thoughtSignature": "test_signature"
- });
-
- let part: FunctionCallPart = serde_json::from_value(json).unwrap();
-
- assert_eq!(part.function_call.name, "test_function");
- assert_eq!(part.thought_signature, Some("test_signature".to_string()));
- }
-
- #[test]
- fn test_function_call_part_deserializes_without_signature() {
- let json = json!({
- "functionCall": {
- "name": "test_function",
- "args": {"arg": "value"}
- }
- });
-
- let part: FunctionCallPart = serde_json::from_value(json).unwrap();
-
- assert_eq!(part.function_call.name, "test_function");
- assert_eq!(part.thought_signature, None);
- }
-
- #[test]
- fn test_function_call_part_round_trip() {
- let original = FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value", "nested": {"key": "val"}}),
- },
- thought_signature: Some("round_trip_signature".to_string()),
- };
-
- let serialized = serde_json::to_value(&original).unwrap();
- let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap();
-
- assert_eq!(deserialized.function_call.name, original.function_call.name);
- assert_eq!(deserialized.function_call.args, original.function_call.args);
- assert_eq!(deserialized.thought_signature, original.thought_signature);
- }
-
- #[test]
- fn test_function_call_part_with_empty_signature_serializes() {
- let part = FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
- },
- thought_signature: Some("".to_string()),
- };
-
- let serialized = serde_json::to_value(&part).unwrap();
-
- // Empty string should still be serialized (normalization happens at a higher level)
- assert_eq!(serialized["thoughtSignature"], "");
- }
-}