From a51b99216d4265c1e3e7d8099ec266758af29d33 Mon Sep 17 00:00:00 2001
From: Richard Feldman
Date: Wed, 17 Dec 2025 15:14:45 -0500
Subject: [PATCH] Revise Google AI extension

---
 crates/extension_api/src/extension_api.rs |  27 +-
 .../wit/since_v0.8.0/llm-provider.wit      |  27 ++
 .../src/wasm_host/llm_provider.rs          |  14 +-
 .../src/wasm_host/wit/since_v0_8_0.rs      |  77 ++++-
 extensions/google-ai/src/google_ai.rs      | 286 +++++++++++++++---
 5 files changed, 369 insertions(+), 62 deletions(-)

diff --git a/crates/extension_api/src/extension_api.rs b/crates/extension_api/src/extension_api.rs
index c1bd322b5d9da1ef941a8346f5d243aa2a5eec83..d888bde26d337b1e024480e6d764473c1c38e855 100644
--- a/crates/extension_api/src/extension_api.rs
+++ b/crates/extension_api/src/extension_api.rs
@@ -31,21 +31,24 @@ pub use wit::{
     },
     zed::extension::llm_provider::{
         CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
-        CompletionRequest as LlmCompletionRequest, DeviceFlowPromptInfo as LlmDeviceFlowPromptInfo,
-        ImageData as LlmImageData, MessageContent as LlmMessageContent,
-        MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
-        ModelInfo as LlmModelInfo, OauthWebAuthConfig as LlmOauthWebAuthConfig,
+        CompletionRequest as LlmCompletionRequest, CustomModelConfig as LlmCustomModelConfig,
+        DeviceFlowPromptInfo as LlmDeviceFlowPromptInfo, ImageData as LlmImageData,
+        MessageContent as LlmMessageContent, MessageRole as LlmMessageRole,
+        ModelCapabilities as LlmModelCapabilities, ModelInfo as LlmModelInfo,
+        OauthWebAuthConfig as LlmOauthWebAuthConfig,
         OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo,
-        RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
-        ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
-        ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
-        ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
-        ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
-        ToolUseJsonParseError as LlmToolUseJsonParseError,
+        ProviderSettings as LlmProviderSettings, RequestMessage as LlmRequestMessage,
+        StopReason as LlmStopReason, ThinkingContent as LlmThinkingContent,
+        TokenUsage as LlmTokenUsage, ToolChoice as LlmToolChoice,
+        ToolDefinition as LlmToolDefinition, ToolInputFormat as LlmToolInputFormat,
+        ToolResult as LlmToolResult, ToolResultContent as LlmToolResultContent,
+        ToolUse as LlmToolUse, ToolUseJsonParseError as LlmToolUseJsonParseError,
         delete_credential as llm_delete_credential, get_credential as llm_get_credential,
-        get_env_var as llm_get_env_var, oauth_open_browser as llm_oauth_open_browser,
+        get_env_var as llm_get_env_var, get_provider_settings as llm_get_provider_settings,
+        oauth_open_browser as llm_oauth_open_browser,
         oauth_send_http_request as llm_oauth_send_http_request,
-        oauth_start_web_auth as llm_oauth_start_web_auth, store_credential as llm_store_credential,
+        oauth_start_web_auth as llm_oauth_start_web_auth,
+        store_credential as llm_store_credential,
     },
     zed::extension::nodejs::{
         node_binary_path, npm_install_package, npm_package_installed_version,
diff --git a/crates/extension_api/wit/since_v0.8.0/llm-provider.wit b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit
index 1c9ce7d8ca8f22624135549594b6adc45dccda68..0d5dea54c388d7e31b2b6d6e71faa90bb12c74c1 100644
--- a/crates/extension_api/wit/since_v0.8.0/llm-provider.wit
+++ b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit
@@ -312,6 +312,33 @@ interface llm-provider {
     /// callback differently (e.g., polling-based flows).
    oauth-open-browser: func(url: string) -> result<_, string>;

+    /// Provider settings from user configuration.
+    /// Extensions can use this to allow custom API URLs, custom models, etc.
+    record provider-settings {
+        /// Custom API URL override (if configured by the user).
+        api-url: option<string>,
+        /// Custom models configured by the user.
+        available-models: list<custom-model-config>,
+    }
+
+    /// Configuration for a custom model defined by the user.
+    record custom-model-config {
+        /// The model's API identifier.
+        name: string,
+        /// Display name for the UI.
+        display-name: option<string>,
+        /// Maximum input token count.
+        max-tokens: u64,
+        /// Maximum output tokens (optional).
+        max-output-tokens: option<u64>,
+        /// Thinking budget for models that support extended thinking (none = auto).
+        thinking-budget: option<u32>,
+    }
+
+    /// Get provider-specific settings configured by the user.
+    /// Returns settings like custom API URLs and custom model configurations.
+    get-provider-settings: func(provider-id: string) -> option<provider-settings>;
+
     /// Information needed to display the device flow prompt modal to the user.
     record device-flow-prompt-info {
         /// The user code to display (e.g., "ABC-123").
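Note: from the extension side, the new host function surfaces through
zed_extension_api as `llm_get_provider_settings` (see the re-exports in the
first hunk). A minimal guest-side sketch, assuming a provider ID of
"google-ai"; the helper name and fallback URL here are illustrative, not
part of this patch:

    // Sketch only: resolve the effective API URL, falling back to the
    // provider's default when the user has not configured an override.
    use zed_extension_api::llm_get_provider_settings;

    fn effective_api_url() -> String {
        llm_get_provider_settings("google-ai")
            .and_then(|settings| settings.api_url)
            .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string())
    }

This mirrors what the google-ai extension below does in get_api_url().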
diff --git a/crates/extension_host/src/wasm_host/llm_provider.rs b/crates/extension_host/src/wasm_host/llm_provider.rs
index a5c2c286eb4f0876fb0c4994e710dc3c0a02af81..236a63d353cd558b6f9c639501dcd2d3ecaa5c63 100644
--- a/crates/extension_host/src/wasm_host/llm_provider.rs
+++ b/crates/extension_host/src/wasm_host/llm_provider.rs
@@ -28,7 +28,8 @@ use language_model::{
     LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
+    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, RateLimiter,
+    StopReason, TokenUsage,
 };
 use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
 use settings::Settings;
@@ -171,6 +172,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
                 provider_id: self.id(),
                 provider_name: self.name(),
                 provider_info: self.provider_info.clone(),
+                request_limiter: RateLimiter::new(4),
             }) as Arc<dyn LanguageModel>
         })
 }
@@ -188,6 +190,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
                 provider_id: self.id(),
                 provider_name: self.name(),
                 provider_info: self.provider_info.clone(),
+                request_limiter: RateLimiter::new(4),
             }) as Arc<dyn LanguageModel>
         })
 }
@@ -204,6 +207,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
                 provider_id: self.id(),
                 provider_name: self.name(),
                 provider_info: self.provider_info.clone(),
+                request_limiter: RateLimiter::new(4),
             }) as Arc<dyn LanguageModel>
         })
         .collect()
@@ -1590,6 +1594,7 @@ pub struct ExtensionLanguageModel {
     provider_id: LanguageModelProviderId,
     provider_name: LanguageModelProviderName,
     provider_info: LlmProviderInfo,
+    request_limiter: RateLimiter,
 }

 impl LanguageModel for ExtensionLanguageModel {
@@ -1694,7 +1699,7 @@ impl LanguageModel for ExtensionLanguageModel {

         let wit_request = convert_request_to_wit(request);

-        async move {
+        let future = self.request_limiter.stream(async move {
             // Start the stream
             let stream_id_result = extension
                 .call({
@@ -1781,8 +1786,9 @@ impl LanguageModel for ExtensionLanguageModel {
             );

             Ok(stream.boxed())
-        }
-        .boxed()
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
     }

     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs
index e9e039002f171199f12a4f209267e53b79301306..67bea1e5782c1805f4cb412b3d58894070ad2743 100644
--- a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs
+++ b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs
@@ -8,7 +8,7 @@ use crate::wasm_host::wit::since_v0_8_0::{
 };
 use crate::wasm_host::{WasmState, wit::ToWasmtimeResult};
 use ::http_client::{AsyncBody, HttpRequestExt};
-use ::settings::{Settings, WorktreeId};
+use ::settings::{ModelMode, Settings, SettingsStore, WorktreeId};
 use anyhow::{Context as _, Result, bail};
 use async_compression::futures::bufread::GzipDecoder;
 use async_tar::Archive;
@@ -1435,4 +1435,79 @@ impl llm_provider::Host for WasmState {
             .await
             .to_wasmtime_result()
     }
+
+    async fn get_provider_settings(
+        &mut self,
+        provider_id: String,
+    ) -> wasmtime::Result<Option<llm_provider::ProviderSettings>> {
+        let extension_id = self.manifest.id.clone();
+
+        let result = self
+            .on_main_thread(move |cx| {
+                async move {
+                    cx.update(|cx| {
+                        let settings_store = cx.global::<SettingsStore>();
+                        let user_settings = settings_store.raw_user_settings();
+                        let language_models = user_settings
+                            .and_then(|s| s.content.language_models.as_ref());
+
+                        // Map provider IDs to their settings.
+                        // The provider_id from the extension is just the provider part
+                        // (e.g., "google-ai"), so we match it to the appropriate settings.
+                        match provider_id.as_str() {
+                            "google-ai" => {
+                                let google = language_models.and_then(|lm| lm.google.as_ref());
+                                let google = google?;
+
+                                let api_url = google.api_url.clone().filter(|s| !s.is_empty());
+
+                                let available_models = google
+                                    .available_models
+                                    .as_ref()
+                                    .map(|models| {
+                                        models
+                                            .iter()
+                                            .map(|m| {
+                                                let thinking_budget = match &m.mode {
+                                                    Some(ModelMode::Thinking { budget_tokens }) => {
+                                                        *budget_tokens
+                                                    }
+                                                    _ => None,
+                                                };
+                                                llm_provider::CustomModelConfig {
+                                                    name: m.name.clone(),
+                                                    display_name: m.display_name.clone(),
+                                                    max_tokens: m.max_tokens,
+                                                    max_output_tokens: None,
+                                                    thinking_budget,
+                                                }
+                                            })
+                                            .collect()
+                                    })
+                                    .unwrap_or_default();
+
+                                Some(llm_provider::ProviderSettings {
+                                    api_url,
+                                    available_models,
+                                })
+                            }
+                            _ => {
+                                log::debug!(
+                                    "Extension {} requested settings for unknown provider: {}",
+                                    extension_id,
+                                    provider_id
+                                );
+                                None
+                            }
+                        }
+                    })
+                    .ok()
+                    .flatten()
+                }
+                .boxed_local()
+            })
+            .await;
+
+        Ok(result)
+    }
 }
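Note: on the host side, get_provider_settings reads the user's
`language_models` settings. A settings fragment along these lines is what
flows through to the extension; the field names follow Zed's existing
`language_models.google` schema, but treat the exact shape as an assumption:

    // Illustrative settings.json fragment, expressed as a serde_json value:
    fn main() {
        let example = serde_json::json!({
            "language_models": {
                "google": {
                    "api_url": "https://my-gemini-proxy.example.com",
                    "available_models": [{
                        "name": "gemini-2.5-pro-custom",
                        "display_name": "Gemini 2.5 Pro (proxy)",
                        "max_tokens": 1048576,
                        "mode": { "type": "thinking", "budget_tokens": 8192 }
                    }]
                }
            }
        });
        println!("{example:#}");
    }

A `mode` of `thinking` with `budget_tokens` is what populates
`thinking_budget` via `ModelMode::Thinking` in the host function above.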
diff --git a/extensions/google-ai/src/google_ai.rs b/extensions/google-ai/src/google_ai.rs
index 7fdd83e49e7234e8ec97730ddb85e8b607b6b698..dfa65004c0f65a80f5f7448aebfb5630a1485f0e 100644
--- a/extensions/google-ai/src/google_ai.rs
+++ b/extensions/google-ai/src/google_ai.rs
@@ -2,13 +2,26 @@ use std::collections::HashMap;

 use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use zed_extension_api::{
-    self as zed, http_client::HttpMethod, http_client::HttpRequest, llm_get_env_var,
-    LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmMessageContent,
-    LlmMessageRole, LlmModelCapabilities, LlmModelInfo, LlmProviderInfo, LlmStopReason,
-    LlmThinkingContent, LlmTokenUsage, LlmToolInputFormat, LlmToolUse,
+    self as zed, http_client::HttpMethod, http_client::HttpRequest,
+    llm_get_env_var, llm_get_provider_settings,
+    LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmCustomModelConfig,
+    LlmMessageContent, LlmMessageRole, LlmModelCapabilities, LlmModelInfo, LlmProviderInfo,
+    LlmStopReason, LlmThinkingContent, LlmTokenUsage, LlmToolInputFormat, LlmToolUse,
 };

-pub const API_URL: &str = "https://generativelanguage.googleapis.com";
+pub const DEFAULT_API_URL: &str = "https://generativelanguage.googleapis.com";
+
+fn get_api_url() -> String {
+    llm_get_provider_settings(PROVIDER_ID)
+        .and_then(|s| s.api_url)
+        .unwrap_or_else(|| DEFAULT_API_URL.to_string())
+}
+
+fn get_custom_models() -> Vec<LlmCustomModelConfig> {
+    llm_get_provider_settings(PROVIDER_ID)
+        .map(|s| s.available_models)
+        .unwrap_or_default()
+}

 fn stream_generate_content(
     model_id: &str,
@@ -21,9 +34,10 @@ fn stream_generate_content(
     let generate_content_request = build_generate_content_request(model_id, request)?;
     validate_generate_content_request(&generate_content_request)?;

+    let api_url = get_api_url();
     let uri = format!(
         "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
-        API_URL, model_id, api_key
+        api_url, model_id, api_key
     );

     let body = serde_json::to_vec(&generate_content_request)
@@ -47,6 +61,8 @@ fn stream_generate_content(
             response_stream,
             buffer: String::new(),
             usage: None,
+            pending_events: Vec::new(),
+            wants_to_use_tool: false,
         },
     );
@@ -62,9 +78,10 @@ fn count_tokens(model_id: &str, request: &LlmCompletionRequest) -> Result<u64, String>
 struct StreamState {
     response_stream: zed::http_client::HttpResponseStream,
     buffer: String,
     usage: Option<UsageMetadata>,
+    pending_events: Vec<LlmCompletionEvent>,
+    wants_to_use_tool: bool,
 }

 impl zed::Extension for GoogleAiExtension {
@@ -240,7 +259,7 @@ fn get_api_key() -> Option<String> {
     llm_get_env_var("GEMINI_API_KEY").or_else(|| llm_get_env_var("GOOGLE_AI_API_KEY"))
 }

-fn get_models() -> Vec<LlmModelInfo> {
+fn get_default_models() -> Vec<LlmModelInfo> {
     vec![
         LlmModelInfo {
             id: "gemini-2.5-flash-lite".to_string(),
@@ -330,6 +349,81 @@ fn get_models() -> Vec<LlmModelInfo> {
     ]
 }

+/// Model aliases for backward compatibility with old model names.
+/// Maps old names to canonical model IDs.
+fn get_model_aliases() -> Vec<(&'static str, &'static str)> {
+    vec![
+        // Gemini 2.5 Flash-Lite aliases
+        ("gemini-2.5-flash-lite-preview-06-17", "gemini-2.5-flash-lite"),
+        ("gemini-2.0-flash-lite-preview", "gemini-2.5-flash-lite"),
+        // Gemini 2.5 Flash aliases
+        ("gemini-2.0-flash-thinking-exp", "gemini-2.5-flash"),
+        ("gemini-2.5-flash-preview-04-17", "gemini-2.5-flash"),
+        ("gemini-2.5-flash-preview-05-20", "gemini-2.5-flash"),
+        ("gemini-2.5-flash-preview-latest", "gemini-2.5-flash"),
+        ("gemini-2.0-flash", "gemini-2.5-flash"),
+        // Gemini 2.5 Pro aliases
+        ("gemini-2.0-pro-exp", "gemini-2.5-pro"),
+        ("gemini-2.5-pro-preview-latest", "gemini-2.5-pro"),
+        ("gemini-2.5-pro-exp-03-25", "gemini-2.5-pro"),
+        ("gemini-2.5-pro-preview-03-25", "gemini-2.5-pro"),
+        ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"),
+        ("gemini-2.5-pro-preview-06-05", "gemini-2.5-pro"),
+    ]
+}
+
+fn get_models() -> Vec<LlmModelInfo> {
+    let mut models: HashMap<String, LlmModelInfo> = HashMap::new();
+
+    // Add default models
+    for model in get_default_models() {
+        models.insert(model.id.clone(), model);
+    }
+
+    // Add aliases as separate model entries (pointing to the same underlying model)
+    for (alias, canonical_id) in get_model_aliases() {
+        if let Some(canonical_model) = models.get(canonical_id) {
+            let mut alias_model = canonical_model.clone();
+            alias_model.id = alias.to_string();
+            alias_model.is_default = false;
+            alias_model.is_default_fast = false;
+            models.insert(alias.to_string(), alias_model);
+        }
+    }
+
+    // Add/override with custom models from settings
+    for custom_model in get_custom_models() {
+        let model = LlmModelInfo {
+            id: custom_model.name.clone(),
+            name: custom_model.display_name.unwrap_or(custom_model.name.clone()),
+            max_token_count: custom_model.max_tokens,
+            max_output_tokens: custom_model.max_output_tokens,
+            capabilities: LlmModelCapabilities {
+                supports_images: true,
+                supports_tools: true,
+                supports_tool_choice_auto: true,
+                supports_tool_choice_any: true,
+                supports_tool_choice_none: true,
+                supports_thinking: custom_model.thinking_budget.is_some(),
+                tool_input_format: LlmToolInputFormat::JsonSchema,
+            },
+            is_default: false,
+            is_default_fast: false,
+        };
+        models.insert(custom_model.name, model);
+    }
+
+    models.into_values().collect()
+}
+
+/// Get the thinking budget for a specific model from custom settings.
+fn get_model_thinking_budget(model_id: &str) -> Option<u32> {
+    get_custom_models()
+        .into_iter()
+        .find(|m| m.name == model_id)
+        .and_then(|m| m.thinking_budget)
+}
+
 fn stream_generate_content_next(
     stream_id: &str,
     streams: &mut HashMap<String, StreamState>,
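Note: get_models() builds the final list in three passes (defaults, then
aliases, then user-defined models), relying on HashMap::insert replacing
earlier entries. A user-defined model that reuses a default or alias name
therefore replaces it wholesale. Minimal self-contained illustration of
that precedence (values are placeholders):

    use std::collections::HashMap;

    fn main() {
        let mut models = HashMap::new();
        models.insert("gemini-2.5-pro", "default definition");
        models.insert("gemini-2.0-pro-exp", "alias of gemini-2.5-pro");
        // A custom model from settings with a colliding name wins:
        models.insert("gemini-2.5-pro", "user-defined override");
        assert_eq!(models["gemini-2.5-pro"], "user-defined override");
    }

Also worth noting: models.into_values().collect() yields hash-map order,
so any stable ordering presumably has to happen on the consuming side.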
@@ -339,6 +433,11 @@ fn stream_generate_content_next(
         .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

     loop {
+        // Return pending events first, oldest first (FIFO), so that event
+        // ordering is preserved.
+        if !state.pending_events.is_empty() {
+            return Ok(Some(state.pending_events.remove(0)));
+        }
+
         if let Some(newline_pos) = state.buffer.find('\n') {
             let line = state.buffer[..newline_pos].to_string();
             state.buffer = state.buffer[newline_pos + 1..].to_string();
@@ -348,11 +447,54 @@ fn stream_generate_content_next(
                     continue;
                 }

-                let response: GenerateContentResponse = serde_json::from_str(data)
-                    .map_err(|e| format!("Failed to parse SSE data: {} - {}", e, data))?;
+                let response: GenerateContentResponse = match serde_json::from_str(data) {
+                    Ok(response) => response,
+                    Err(parse_error) => {
+                        // Try to parse as an API error response
+                        if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(data) {
+                            let error_msg = api_error
+                                .error
+                                .message
+                                .unwrap_or_else(|| "Unknown API error".to_string());
+                            let status = api_error.error.status.unwrap_or_default();
+                            let code = api_error.error.code.unwrap_or(0);
+                            return Err(format!(
+                                "Google AI API error ({}): {} [status: {}]",
+                                code, error_msg, status
+                            ));
+                        }
+                        // If it's not an error response, surface the parse error
+                        return Err(format!(
+                            "Failed to parse SSE data: {} - {}",
+                            parse_error, data
+                        ));
+                    }
+                };
+
+                // Handle prompt feedback: blocked prompts ("SAFETY", "BLOCKLIST",
+                // "PROHIBITED_CONTENT", "IMAGE_SAFETY", "OTHER", ...) are all
+                // surfaced as refusals.
+                if let Some(ref prompt_feedback) = response.prompt_feedback {
+                    if prompt_feedback.block_reason.is_some() {
+                        return Ok(Some(LlmCompletionEvent::Stop(LlmStopReason::Refusal)));
+                    }
+                }

-                if let Some(usage) = response.usage_metadata {
-                    state.usage = Some(usage);
+                // Queue usage updates as soon as they are received
+                if let Some(ref usage) = response.usage_metadata {
+                    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
+                    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
+                    let input_tokens = prompt_tokens.saturating_sub(cached_tokens);
+                    state.pending_events.push(LlmCompletionEvent::Usage(LlmTokenUsage {
+                        input_tokens,
+                        output_tokens: usage.candidates_token_count.unwrap_or(0),
+                        cache_creation_input_tokens: None,
+                        cache_read_input_tokens: Some(cached_tokens).filter(|&c| c > 0),
+                    }));
+                    state.usage = Some(usage.clone());
                 }

                 if let Some(candidates) = response.candidates {
@@ -365,19 +507,23 @@ fn stream_generate_content_next(
                             Part::ThoughtPart(thought_part) => {
                                 return Ok(Some(LlmCompletionEvent::Thinking(
                                     LlmThinkingContent {
-                                        text: String::new(),
+                                        text: "(Encrypted thought)".to_string(),
                                         signature: Some(thought_part.thought_signature),
                                     },
                                 )));
                             }
                             Part::FunctionCallPart(fc_part) => {
+                                state.wants_to_use_tool = true;
+                                // Normalize empty string signatures to None
+                                let thought_signature =
+                                    fc_part.thought_signature.filter(|s| !s.is_empty());
                                 return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                     id: fc_part.function_call.name.clone(),
                                     name: fc_part.function_call.name,
                                     input: serde_json::to_string(&fc_part.function_call.args)
                                         .unwrap_or_default(),
                                     is_input_complete: true,
-                                    thought_signature: fc_part.thought_signature,
+                                    thought_signature,
                                 })));
                             }
                             _ => {}
@@ -385,23 +531,21 @@ fn stream_generate_content_next(
                         }

                         if let Some(finish_reason) = candidate.finish_reason {
-                            let stop_reason = match finish_reason.as_str() {
-                                "STOP" => LlmStopReason::EndTurn,
-                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
-                                "TOOL_USE" | "FUNCTION_CALL" => LlmStopReason::ToolUse,
-                                "SAFETY" | "RECITATION" | "OTHER" => LlmStopReason::Refusal,
-                                _ => LlmStopReason::EndTurn,
+                            // Even when Gemini wants to use a tool, the API responds
+                            // with `finish_reason: STOP`, so check wants_to_use_tool
+                            // and override the stop reason.
+                            let stop_reason = if state.wants_to_use_tool {
+                                LlmStopReason::ToolUse
+                            } else {
+                                match finish_reason.as_str() {
+                                    "STOP" => LlmStopReason::EndTurn,
+                                    "MAX_TOKENS" => LlmStopReason::MaxTokens,
+                                    "TOOL_USE" | "FUNCTION_CALL" => LlmStopReason::ToolUse,
+                                    "SAFETY" | "RECITATION" | "OTHER" => LlmStopReason::Refusal,
+                                    _ => LlmStopReason::EndTurn,
+                                }
                             };

-                            if let Some(usage) = state.usage.take() {
-                                return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
-                                    input_tokens: usage.prompt_token_count.unwrap_or(0),
-                                    output_tokens: usage.candidates_token_count.unwrap_or(0),
-                                    cache_creation_input_tokens: None,
-                                    cache_read_input_tokens: usage.cached_content_token_count,
-                                })));
-                            }
-
+                            // Deliver any queued usage event before the stop event,
+                            // since hosts may stop polling after a Stop.
+                            if !state.pending_events.is_empty() {
+                                state.pending_events.push(LlmCompletionEvent::Stop(stop_reason));
+                                return Ok(Some(state.pending_events.remove(0)));
+                            }
                             return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                         }
                     }
@@ -411,6 +555,28 @@ fn stream_generate_content_next(
             continue;
         }

+        // Check if the buffer contains a non-SSE error response (no "data: " prefix).
+        // This can happen when Google returns an immediate error without streaming.
+        if !state.buffer.is_empty()
+            && !state.buffer.contains("data: ")
+            && state.buffer.contains("\"error\"")
+        {
+            // Try to parse the entire buffer as an error response
+            if let Ok(api_error) = serde_json::from_str::<ApiErrorResponse>(&state.buffer) {
+                let error_msg = api_error
+                    .error
+                    .message
+                    .unwrap_or_else(|| "Unknown API error".to_string());
+                let status = api_error.error.status.unwrap_or_default();
+                let code = api_error.error.code.unwrap_or(0);
+                streams.remove(stream_id);
+                return Err(format!(
+                    "Google AI API error ({}): {} [status: {}]",
+                    code, error_msg, status
+                ));
+            }
+        }
+
         match state.response_stream.next_chunk() {
             Ok(Some(chunk)) => {
                 let chunk_str = String::from_utf8_lossy(&chunk);
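Note: both error branches above assume Google's standard JSON error
envelope. A self-contained sketch of what the ApiErrorResponse type
(defined near the bottom of this file) accepts; the payload values are
illustrative:

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct ApiErrorResponse { error: ApiError }

    #[derive(Debug, Deserialize)]
    struct ApiError { code: Option<u32>, message: Option<String>, status: Option<String> }

    fn main() {
        let payload = r#"{"error":{"code":429,"message":"Resource has been exhausted.","status":"RESOURCE_EXHAUSTED"}}"#;
        let parsed: ApiErrorResponse = serde_json::from_str(payload).expect("parses");
        assert_eq!(parsed.error.code, Some(429));
        assert_eq!(parsed.error.status.as_deref(), Some("RESOURCE_EXHAUSTED"));
    }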
@@ -491,12 +657,13 @@ fn build_generate_content_request(
                 Some(request.stop_sequences.clone())
             },
             max_output_tokens: request.max_tokens.map(|t| t as usize),
-            temperature: request.temperature.map(|t| t as f64),
+            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
             top_p: None,
             top_k: None,
             thinking_config: if request.thinking_allowed {
-                Some(ThinkingConfig {
-                    thinking_budget: 8192,
+                // Only send a thinking config when the user configured a custom
+                // thinking budget for this model; otherwise let the API choose.
+                get_model_thinking_budget(model_id).map(|thinking_budget| ThinkingConfig {
+                    thinking_budget,
                 })
             } else {
                 None
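Note: with these defaults, the serialized generationConfig for a
thinking-enabled custom model would look roughly like the following
(camelCase per the v1beta API; the exact field set depends on the
request, so treat this as a sketch rather than an exhaustive body):

    fn main() {
        // Illustrative JSON shape, not an exhaustive request body:
        let generation_config = serde_json::json!({
            "temperature": 1.0,
            "thinkingConfig": { "thinkingBudget": 8192 }
        });
        println!("{generation_config:#}");
    }

Models without a configured budget now omit thinkingConfig entirely,
whereas previously a hard-coded budget of 8192 was always sent when
thinking was allowed.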
@@ -533,29 +700,45 @@ fn convert_content_to_parts(content: &[LlmMessageContent]) -> Result<Vec<Part>, String> {
                 }));
             }
             LlmMessageContent::ToolUse(tool_use) => {
+                // Normalize empty string signatures to None
+                let thought_signature = tool_use
+                    .thought_signature
+                    .clone()
+                    .filter(|s| !s.is_empty());
                 parts.push(Part::FunctionCallPart(FunctionCallPart {
                     function_call: FunctionCall {
                         name: tool_use.name.clone(),
                         args: serde_json::from_str(&tool_use.input).unwrap_or_default(),
                     },
-                    thought_signature: tool_use.thought_signature.clone(),
+                    thought_signature,
                 }));
             }
             LlmMessageContent::ToolResult(tool_result) => {
-                let response_value = match &tool_result.content {
+                match &tool_result.content {
                     zed::LlmToolResultContent::Text(text) => {
-                        serde_json::json!({ "result": text })
+                        parts.push(Part::FunctionResponsePart(FunctionResponsePart {
+                            function_response: FunctionResponse {
+                                name: tool_result.tool_name.clone(),
+                                response: serde_json::json!({ "output": text }),
+                            },
+                        }));
                     }
-                    zed::LlmToolResultContent::Image(_) => {
-                        serde_json::json!({ "error": "Image results not supported" })
+                    zed::LlmToolResultContent::Image(image) => {
+                        // Send both the function response and the image inline
+                        parts.push(Part::FunctionResponsePart(FunctionResponsePart {
+                            function_response: FunctionResponse {
+                                name: tool_result.tool_name.clone(),
+                                response: serde_json::json!({ "output": "Tool responded with an image" }),
+                            },
+                        }));
+                        parts.push(Part::InlineDataPart(InlineDataPart {
+                            inline_data: GenerativeContentBlob {
+                                mime_type: "image/png".to_string(),
+                                data: image.source.clone(),
+                            },
+                        }));
                     }
-                };
-                parts.push(Part::FunctionResponsePart(FunctionResponsePart {
-                    function_response: FunctionResponse {
-                        name: tool_result.tool_name.clone(),
-                        response: response_value,
-                    },
-                }));
+                }
             }
             LlmMessageContent::Thinking(thinking) => {
                 if let Some(signature) = &thinking.signature {
@@ -721,7 +904,7 @@ pub struct PromptFeedback {
     pub block_reason_message: Option<String>,
 }

-#[derive(Debug, Serialize, Deserialize, Default)]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
 #[serde(rename_all = "camelCase")]
 pub struct UsageMetadata {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub prompt_token_count: Option<u64>,
@@ -897,6 +1080,19 @@ impl ModelName {
 const MODEL_NAME_PREFIX: &str = "models/";

+/// Google API error response structure
+#[derive(Debug, Deserialize)]
+pub struct ApiErrorResponse {
+    pub error: ApiError,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ApiError {
+    pub code: Option<u32>,
+    pub message: Option<String>,
+    pub status: Option<String>,
+}
+
 impl Serialize for ModelName {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where