diff --git a/extensions/google-ai/src/google_ai.rs b/extensions/google-ai/src/google_ai.rs index bad0bf43692bc6fbd4bbabe4de388d2a7f1b6fde..7fdd83e49e7234e8ec97730ddb85e8b607b6b698 100644 --- a/extensions/google-ai/src/google_ai.rs +++ b/extensions/google-ai/src/google_ai.rs @@ -1,136 +1,579 @@ +use std::collections::HashMap; -use std::mem; - -use anyhow::{Result, anyhow, bail}; -use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; -use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -pub use settings::ModelMode as GoogleModelMode; +use zed_extension_api::{ + self as zed, http_client::HttpMethod, http_client::HttpRequest, llm_get_env_var, + LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmMessageContent, + LlmMessageRole, LlmModelCapabilities, LlmModelInfo, LlmProviderInfo, LlmStopReason, + LlmThinkingContent, LlmTokenUsage, LlmToolInputFormat, LlmToolUse, +}; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; -pub async fn stream_generate_content( - client: &dyn HttpClient, - api_url: &str, - api_key: &str, - mut request: GenerateContentRequest, -) -> Result>> { - let api_key = api_key.trim(); - validate_generate_content_request(&request)?; - - // The `model` field is emptied as it is provided as a path parameter. - let model_id = mem::take(&mut request.model.model_id); - - let uri = - format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",); - - let request_builder = HttpRequest::builder() - .method(Method::POST) - .uri(uri) - .header("Content-Type", "application/json"); - - let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; - let mut response = client.send(request).await?; - if response.status().is_success() { - let reader = BufReader::new(response.into_body()); - Ok(reader - .lines() - .filter_map(|line| async move { - match line { - Ok(line) => { - if let Some(line) = line.strip_prefix("data: ") { - match serde_json::from_str(line) { - Ok(response) => Some(Ok(response)), - Err(error) => Some(Err(anyhow!(format!( - "Error parsing JSON: {error:?}\n{line:?}" - )))), - } - } else { - None - } - } - Err(error) => Some(Err(anyhow!(error))), - } - }) - .boxed()) - } else { - let mut text = String::new(); - response.body_mut().read_to_string(&mut text).await?; - Err(anyhow!( - "error during streamGenerateContent, status code: {:?}, body: {}", - response.status(), - text - )) - } -} +fn stream_generate_content( + model_id: &str, + request: &LlmCompletionRequest, + streams: &mut HashMap, + next_stream_id: &mut u64, +) -> Result { + let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?; -pub async fn count_tokens( - client: &dyn HttpClient, - api_url: &str, - api_key: &str, - request: CountTokensRequest, -) -> Result { - validate_generate_content_request(&request.generate_content_request)?; + let generate_content_request = build_generate_content_request(model_id, request)?; + validate_generate_content_request(&generate_content_request)?; let uri = format!( - "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}", - model_id = &request.generate_content_request.model.model_id, + "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}", + API_URL, model_id, api_key + ); + + let body = serde_json::to_vec(&generate_content_request) + .map_err(|e| format!("Failed to serialize request: {}", e))?; + + let http_request = HttpRequest::builder() + .method(HttpMethod::Post) + .url(&uri) + .header("Content-Type", "application/json") + .body(body) + .build()?; + + let response_stream = http_request.fetch_stream()?; + + let stream_id = format!("stream-{}", *next_stream_id); + *next_stream_id += 1; + + streams.insert( + stream_id.clone(), + StreamState { + response_stream, + buffer: String::new(), + usage: None, + }, ); - let request = serde_json::to_string(&request)?; - let request_builder = HttpRequest::builder() - .method(Method::POST) - .uri(&uri) - .header("Content-Type", "application/json"); - let http_request = request_builder.body(AsyncBody::from(request))?; - - let mut response = client.send(http_request).await?; - let mut text = String::new(); - response.body_mut().read_to_string(&mut text).await?; - anyhow::ensure!( - response.status().is_success(), - "error during countTokens, status code: {:?}, body: {}", - response.status(), - text + Ok(stream_id) +} + +fn count_tokens(model_id: &str, request: &LlmCompletionRequest) -> Result { + let api_key = get_api_key().ok_or_else(|| "API key not configured".to_string())?; + + let generate_content_request = build_generate_content_request(model_id, request)?; + validate_generate_content_request(&generate_content_request)?; + let count_request = CountTokensRequest { + generate_content_request, + }; + + let uri = format!( + "{}/v1beta/models/{}:countTokens?key={}", + API_URL, model_id, api_key ); - Ok(serde_json::from_str::(&text)?) + + let body = serde_json::to_vec(&count_request) + .map_err(|e| format!("Failed to serialize request: {}", e))?; + + let http_request = HttpRequest::builder() + .method(HttpMethod::Post) + .url(&uri) + .header("Content-Type", "application/json") + .body(body) + .build()?; + + let response = http_request.fetch()?; + let response_body: CountTokensResponse = serde_json::from_slice(&response.body) + .map_err(|e| format!("Failed to parse response: {}", e))?; + + Ok(response_body.total_tokens) } -pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> { +fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<(), String> { if request.model.is_empty() { - bail!("Model must be specified"); + return Err("Model must be specified".to_string()); } if request.contents.is_empty() { - bail!("Request must contain at least one content item"); + return Err("Request must contain at least one content item".to_string()); } if let Some(user_content) = request .contents .iter() .find(|content| content.role == Role::User) - && user_content.parts.is_empty() { - bail!("User content must contain at least one part"); + if user_content.parts.is_empty() { + return Err("User content must contain at least one part".to_string()); + } } Ok(()) } -#[derive(Debug, Serialize, Deserialize)] -pub enum Task { - #[serde(rename = "generateContent")] - GenerateContent, - #[serde(rename = "streamGenerateContent")] - StreamGenerateContent, - #[serde(rename = "countTokens")] - CountTokens, - #[serde(rename = "embedContent")] - EmbedContent, - #[serde(rename = "batchEmbedContents")] - BatchEmbedContents, +// Extension implementation + +const PROVIDER_ID: &str = "google-ai"; +const PROVIDER_NAME: &str = "Google AI"; + +struct GoogleAiExtension { + streams: HashMap, + next_stream_id: u64, +} + +struct StreamState { + response_stream: zed::http_client::HttpResponseStream, + buffer: String, + usage: Option, +} + +impl zed::Extension for GoogleAiExtension { + fn new() -> Self { + Self { + streams: HashMap::new(), + next_stream_id: 0, + } + } + + fn llm_providers(&self) -> Vec { + vec![LlmProviderInfo { + id: PROVIDER_ID.to_string(), + name: PROVIDER_NAME.to_string(), + icon: Some("icons/google-ai.svg".to_string()), + }] + } + + fn llm_provider_models(&self, provider_id: &str) -> Result, String> { + if provider_id != PROVIDER_ID { + return Err(format!("Unknown provider: {}", provider_id)); + } + Ok(get_models()) + } + + fn llm_provider_settings_markdown(&self, provider_id: &str) -> Option { + if provider_id != PROVIDER_ID { + return None; + } + + Some( + r#"## Google AI Setup + +To use Google AI models in Zed, you need a Gemini API key. + +1. Go to [Google AI Studio](https://aistudio.google.com/apikey) +2. Create or select a project +3. Generate an API key +4. Set the `GEMINI_API_KEY` or `GOOGLE_AI_API_KEY` environment variable + +You can set this in your shell profile or use a `.envrc` file with [direnv](https://direnv.net/). +"# + .to_string(), + ) + } + + fn llm_provider_is_authenticated(&self, provider_id: &str) -> bool { + if provider_id != PROVIDER_ID { + return false; + } + get_api_key().is_some() + } + + fn llm_provider_reset_credentials(&mut self, provider_id: &str) -> Result<(), String> { + if provider_id != PROVIDER_ID { + return Err(format!("Unknown provider: {}", provider_id)); + } + Ok(()) + } + + fn llm_count_tokens( + &self, + provider_id: &str, + model_id: &str, + request: &LlmCompletionRequest, + ) -> Result { + if provider_id != PROVIDER_ID { + return Err(format!("Unknown provider: {}", provider_id)); + } + count_tokens(model_id, request) + } + + fn llm_stream_completion_start( + &mut self, + provider_id: &str, + model_id: &str, + request: &LlmCompletionRequest, + ) -> Result { + if provider_id != PROVIDER_ID { + return Err(format!("Unknown provider: {}", provider_id)); + } + stream_generate_content(model_id, request, &mut self.streams, &mut self.next_stream_id) + } + + fn llm_stream_completion_next( + &mut self, + stream_id: &str, + ) -> Result, String> { + stream_generate_content_next(stream_id, &mut self.streams) + } + + fn llm_stream_completion_close(&mut self, stream_id: &str) { + self.streams.remove(stream_id); + } + + fn llm_cache_configuration( + &self, + provider_id: &str, + _model_id: &str, + ) -> Option { + if provider_id != PROVIDER_ID { + return None; + } + + Some(LlmCacheConfiguration { + max_cache_anchors: 1, + should_cache_tool_definitions: false, + min_total_token_count: 32768, + }) + } +} + +zed::register_extension!(GoogleAiExtension); + +// Helper functions + +fn get_api_key() -> Option { + llm_get_env_var("GEMINI_API_KEY").or_else(|| llm_get_env_var("GOOGLE_AI_API_KEY")) +} + +fn get_models() -> Vec { + vec![ + LlmModelInfo { + id: "gemini-2.5-flash-lite".to_string(), + name: "Gemini 2.5 Flash-Lite".to_string(), + max_token_count: 1_048_576, + max_output_tokens: Some(65_536), + capabilities: LlmModelCapabilities { + supports_images: true, + supports_tools: true, + supports_tool_choice_auto: true, + supports_tool_choice_any: true, + supports_tool_choice_none: true, + supports_thinking: true, + tool_input_format: LlmToolInputFormat::JsonSchema, + }, + is_default: false, + is_default_fast: true, + }, + LlmModelInfo { + id: "gemini-2.5-flash".to_string(), + name: "Gemini 2.5 Flash".to_string(), + max_token_count: 1_048_576, + max_output_tokens: Some(65_536), + capabilities: LlmModelCapabilities { + supports_images: true, + supports_tools: true, + supports_tool_choice_auto: true, + supports_tool_choice_any: true, + supports_tool_choice_none: true, + supports_thinking: true, + tool_input_format: LlmToolInputFormat::JsonSchema, + }, + is_default: true, + is_default_fast: false, + }, + LlmModelInfo { + id: "gemini-2.5-pro".to_string(), + name: "Gemini 2.5 Pro".to_string(), + max_token_count: 1_048_576, + max_output_tokens: Some(65_536), + capabilities: LlmModelCapabilities { + supports_images: true, + supports_tools: true, + supports_tool_choice_auto: true, + supports_tool_choice_any: true, + supports_tool_choice_none: true, + supports_thinking: true, + tool_input_format: LlmToolInputFormat::JsonSchema, + }, + is_default: false, + is_default_fast: false, + }, + LlmModelInfo { + id: "gemini-3-pro-preview".to_string(), + name: "Gemini 3 Pro".to_string(), + max_token_count: 1_048_576, + max_output_tokens: Some(65_536), + capabilities: LlmModelCapabilities { + supports_images: true, + supports_tools: true, + supports_tool_choice_auto: true, + supports_tool_choice_any: true, + supports_tool_choice_none: true, + supports_thinking: true, + tool_input_format: LlmToolInputFormat::JsonSchema, + }, + is_default: false, + is_default_fast: false, + }, + LlmModelInfo { + id: "gemini-3-flash-preview".to_string(), + name: "Gemini 3 Flash".to_string(), + max_token_count: 1_048_576, + max_output_tokens: Some(65_536), + capabilities: LlmModelCapabilities { + supports_images: true, + supports_tools: true, + supports_tool_choice_auto: true, + supports_tool_choice_any: true, + supports_tool_choice_none: true, + supports_thinking: true, + tool_input_format: LlmToolInputFormat::JsonSchema, + }, + is_default: false, + is_default_fast: false, + }, + ] +} + +fn stream_generate_content_next( + stream_id: &str, + streams: &mut HashMap, +) -> Result, String> { + let state = streams + .get_mut(stream_id) + .ok_or_else(|| format!("Unknown stream: {}", stream_id))?; + + loop { + if let Some(newline_pos) = state.buffer.find('\n') { + let line = state.buffer[..newline_pos].to_string(); + state.buffer = state.buffer[newline_pos + 1..].to_string(); + + if let Some(data) = line.strip_prefix("data: ") { + if data.trim().is_empty() { + continue; + } + + let response: GenerateContentResponse = serde_json::from_str(data) + .map_err(|e| format!("Failed to parse SSE data: {} - {}", e, data))?; + + if let Some(usage) = response.usage_metadata { + state.usage = Some(usage); + } + + if let Some(candidates) = response.candidates { + for candidate in candidates { + for part in candidate.content.parts { + match part { + Part::TextPart(text_part) => { + return Ok(Some(LlmCompletionEvent::Text(text_part.text))); + } + Part::ThoughtPart(thought_part) => { + return Ok(Some(LlmCompletionEvent::Thinking( + LlmThinkingContent { + text: String::new(), + signature: Some(thought_part.thought_signature), + }, + ))); + } + Part::FunctionCallPart(fc_part) => { + return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse { + id: fc_part.function_call.name.clone(), + name: fc_part.function_call.name, + input: serde_json::to_string(&fc_part.function_call.args) + .unwrap_or_default(), + is_input_complete: true, + thought_signature: fc_part.thought_signature, + }))); + } + _ => {} + } + } + + if let Some(finish_reason) = candidate.finish_reason { + let stop_reason = match finish_reason.as_str() { + "STOP" => LlmStopReason::EndTurn, + "MAX_TOKENS" => LlmStopReason::MaxTokens, + "TOOL_USE" | "FUNCTION_CALL" => LlmStopReason::ToolUse, + "SAFETY" | "RECITATION" | "OTHER" => LlmStopReason::Refusal, + _ => LlmStopReason::EndTurn, + }; + + if let Some(usage) = state.usage.take() { + return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage { + input_tokens: usage.prompt_token_count.unwrap_or(0), + output_tokens: usage.candidates_token_count.unwrap_or(0), + cache_creation_input_tokens: None, + cache_read_input_tokens: usage.cached_content_token_count, + }))); + } + + return Ok(Some(LlmCompletionEvent::Stop(stop_reason))); + } + } + } + } + + continue; + } + + match state.response_stream.next_chunk() { + Ok(Some(chunk)) => { + let chunk_str = String::from_utf8_lossy(&chunk); + state.buffer.push_str(&chunk_str); + } + Ok(None) => { + streams.remove(stream_id); + return Ok(None); + } + Err(e) => { + streams.remove(stream_id); + return Err(e); + } + } + } +} + +fn build_generate_content_request( + model_id: &str, + request: &LlmCompletionRequest, +) -> Result { + let mut contents: Vec = Vec::new(); + let mut system_instruction: Option = None; + + for message in &request.messages { + match message.role { + LlmMessageRole::System => { + let parts = convert_content_to_parts(&message.content)?; + system_instruction = Some(SystemInstruction { parts }); + } + LlmMessageRole::User | LlmMessageRole::Assistant => { + let role = match message.role { + LlmMessageRole::User => Role::User, + LlmMessageRole::Assistant => Role::Model, + _ => continue, + }; + let parts = convert_content_to_parts(&message.content)?; + contents.push(Content { parts, role }); + } + } + } + + let tools = if !request.tools.is_empty() { + Some(vec![Tool { + function_declarations: request + .tools + .iter() + .map(|t| FunctionDeclaration { + name: t.name.clone(), + description: t.description.clone(), + parameters: serde_json::from_str(&t.input_schema).unwrap_or_default(), + }) + .collect(), + }]) + } else { + None + }; + + let tool_config = request.tool_choice.as_ref().map(|choice| { + let mode = match choice { + zed::LlmToolChoice::Auto => FunctionCallingMode::Auto, + zed::LlmToolChoice::Any => FunctionCallingMode::Any, + zed::LlmToolChoice::None => FunctionCallingMode::None, + }; + ToolConfig { + function_calling_config: FunctionCallingConfig { + mode, + allowed_function_names: None, + }, + } + }); + + let generation_config = Some(GenerationConfig { + candidate_count: Some(1), + stop_sequences: if request.stop_sequences.is_empty() { + None + } else { + Some(request.stop_sequences.clone()) + }, + max_output_tokens: request.max_tokens.map(|t| t as usize), + temperature: request.temperature.map(|t| t as f64), + top_p: None, + top_k: None, + thinking_config: if request.thinking_allowed { + Some(ThinkingConfig { + thinking_budget: 8192, + }) + } else { + None + }, + }); + + Ok(GenerateContentRequest { + model: ModelName { + model_id: model_id.to_string(), + }, + contents, + system_instruction, + generation_config, + safety_settings: None, + tools, + tool_config, + }) +} + +fn convert_content_to_parts(content: &[LlmMessageContent]) -> Result, String> { + let mut parts = Vec::new(); + + for item in content { + match item { + LlmMessageContent::Text(text) => { + parts.push(Part::TextPart(TextPart { text: text.clone() })); + } + LlmMessageContent::Image(image) => { + parts.push(Part::InlineDataPart(InlineDataPart { + inline_data: GenerativeContentBlob { + mime_type: "image/png".to_string(), + data: image.source.clone(), + }, + })); + } + LlmMessageContent::ToolUse(tool_use) => { + parts.push(Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: tool_use.name.clone(), + args: serde_json::from_str(&tool_use.input).unwrap_or_default(), + }, + thought_signature: tool_use.thought_signature.clone(), + })); + } + LlmMessageContent::ToolResult(tool_result) => { + let response_value = match &tool_result.content { + zed::LlmToolResultContent::Text(text) => { + serde_json::json!({ "result": text }) + } + zed::LlmToolResultContent::Image(_) => { + serde_json::json!({ "error": "Image results not supported" }) + } + }; + parts.push(Part::FunctionResponsePart(FunctionResponsePart { + function_response: FunctionResponse { + name: tool_result.tool_name.clone(), + response: response_value, + }, + })); + } + LlmMessageContent::Thinking(thinking) => { + if let Some(signature) = &thinking.signature { + parts.push(Part::ThoughtPart(ThoughtPart { + thought: true, + thought_signature: signature.clone(), + })); + } + } + LlmMessageContent::RedactedThinking(_) => {} + } + } + + Ok(parts) } +// Data structures for Google AI API + #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentRequest { @@ -481,238 +924,3 @@ impl<'de> Deserialize<'de> for ModelName { } } } - -#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] -pub enum Model { - #[serde( - rename = "gemini-2.5-flash-lite", - alias = "gemini-2.5-flash-lite-preview-06-17", - alias = "gemini-2.0-flash-lite-preview" - )] - Gemini25FlashLite, - #[serde( - rename = "gemini-2.5-flash", - alias = "gemini-2.0-flash-thinking-exp", - alias = "gemini-2.5-flash-preview-04-17", - alias = "gemini-2.5-flash-preview-05-20", - alias = "gemini-2.5-flash-preview-latest", - alias = "gemini-2.0-flash" - )] - #[default] - Gemini25Flash, - #[serde( - rename = "gemini-2.5-pro", - alias = "gemini-2.0-pro-exp", - alias = "gemini-2.5-pro-preview-latest", - alias = "gemini-2.5-pro-exp-03-25", - alias = "gemini-2.5-pro-preview-03-25", - alias = "gemini-2.5-pro-preview-05-06", - alias = "gemini-2.5-pro-preview-06-05" - )] - Gemini25Pro, - #[serde(rename = "gemini-3-pro-preview")] - Gemini3Pro, - #[serde(rename = "custom")] - Custom { - name: String, - /// The name displayed in the UI, such as in the assistant panel model dropdown menu. - display_name: Option, - max_tokens: u64, - #[serde(default)] - mode: GoogleModelMode, - }, -} - -impl Model { - pub fn default_fast() -> Self { - Self::Gemini25FlashLite - } - - pub fn id(&self) -> &str { - match self { - Self::Gemini25FlashLite => "gemini-2.5-flash-lite", - Self::Gemini25Flash => "gemini-2.5-flash", - Self::Gemini25Pro => "gemini-2.5-pro", - Self::Gemini3Pro => "gemini-3-pro-preview", - Self::Custom { name, .. } => name, - } - } - pub fn request_id(&self) -> &str { - match self { - Self::Gemini25FlashLite => "gemini-2.5-flash-lite", - Self::Gemini25Flash => "gemini-2.5-flash", - Self::Gemini25Pro => "gemini-2.5-pro", - Self::Gemini3Pro => "gemini-3-pro-preview", - Self::Custom { name, .. } => name, - } - } - - pub fn display_name(&self) -> &str { - match self { - Self::Gemini25FlashLite => "Gemini 2.5 Flash-Lite", - Self::Gemini25Flash => "Gemini 2.5 Flash", - Self::Gemini25Pro => "Gemini 2.5 Pro", - Self::Gemini3Pro => "Gemini 3 Pro", - Self::Custom { - name, display_name, .. - } => display_name.as_ref().unwrap_or(name), - } - } - - pub fn max_token_count(&self) -> u64 { - match self { - Self::Gemini25FlashLite => 1_048_576, - Self::Gemini25Flash => 1_048_576, - Self::Gemini25Pro => 1_048_576, - Self::Gemini3Pro => 1_048_576, - Self::Custom { max_tokens, .. } => *max_tokens, - } - } - - pub fn max_output_tokens(&self) -> Option { - match self { - Model::Gemini25FlashLite => Some(65_536), - Model::Gemini25Flash => Some(65_536), - Model::Gemini25Pro => Some(65_536), - Model::Gemini3Pro => Some(65_536), - Model::Custom { .. } => None, - } - } - - pub fn supports_tools(&self) -> bool { - true - } - - pub fn supports_images(&self) -> bool { - true - } - - pub fn mode(&self) -> GoogleModelMode { - match self { - Self::Gemini25FlashLite - | Self::Gemini25Flash - | Self::Gemini25Pro - | Self::Gemini3Pro => { - GoogleModelMode::Thinking { - // By default these models are set to "auto", so we preserve that behavior - // but indicate they are capable of thinking mode - budget_tokens: None, - } - } - Self::Custom { mode, .. } => *mode, - } - } -} - -impl std::fmt::Display for Model { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.id()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn test_function_call_part_with_signature_serializes_correctly() { - let part = FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("test_signature".to_string()), - }; - - let serialized = serde_json::to_value(&part).unwrap(); - - assert_eq!(serialized["functionCall"]["name"], "test_function"); - assert_eq!(serialized["functionCall"]["args"]["arg"], "value"); - assert_eq!(serialized["thoughtSignature"], "test_signature"); - } - - #[test] - fn test_function_call_part_without_signature_omits_field() { - let part = FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: None, - }; - - let serialized = serde_json::to_value(&part).unwrap(); - - assert_eq!(serialized["functionCall"]["name"], "test_function"); - assert_eq!(serialized["functionCall"]["args"]["arg"], "value"); - // thoughtSignature field should not be present when None - assert!(serialized.get("thoughtSignature").is_none()); - } - - #[test] - fn test_function_call_part_deserializes_with_signature() { - let json = json!({ - "functionCall": { - "name": "test_function", - "args": {"arg": "value"} - }, - "thoughtSignature": "test_signature" - }); - - let part: FunctionCallPart = serde_json::from_value(json).unwrap(); - - assert_eq!(part.function_call.name, "test_function"); - assert_eq!(part.thought_signature, Some("test_signature".to_string())); - } - - #[test] - fn test_function_call_part_deserializes_without_signature() { - let json = json!({ - "functionCall": { - "name": "test_function", - "args": {"arg": "value"} - } - }); - - let part: FunctionCallPart = serde_json::from_value(json).unwrap(); - - assert_eq!(part.function_call.name, "test_function"); - assert_eq!(part.thought_signature, None); - } - - #[test] - fn test_function_call_part_round_trip() { - let original = FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value", "nested": {"key": "val"}}), - }, - thought_signature: Some("round_trip_signature".to_string()), - }; - - let serialized = serde_json::to_value(&original).unwrap(); - let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap(); - - assert_eq!(deserialized.function_call.name, original.function_call.name); - assert_eq!(deserialized.function_call.args, original.function_call.args); - assert_eq!(deserialized.thought_signature, original.thought_signature); - } - - #[test] - fn test_function_call_part_with_empty_signature_serializes() { - let part = FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("".to_string()), - }; - - let serialized = serde_json::to_value(&part).unwrap(); - - // Empty string should still be serialized (normalization happens at a higher level) - assert_eq!(serialized["thoughtSignature"], ""); - } -}