From ca8279ca79ec6155d1019c8cb4f412c07c5fcd9e Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Wed, 17 Dec 2025 18:21:02 -0500 Subject: [PATCH] More Gemini extension fixes --- crates/extension_host/src/extension_host.rs | 36 ++++++++- .../src/wasm_host/llm_provider.rs | 76 ++++++++++--------- extensions/google-ai/extension.toml | 4 +- extensions/google-ai/src/google_ai.rs | 4 +- 4 files changed, 77 insertions(+), 43 deletions(-) diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index b0f7fdfde57950697817c28b7e4203a71615a61d..ff4cb5985763a0d250c2abcadf621e862311efd3 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -66,12 +66,16 @@ use util::{ResultExt, paths::RemotePathBuf}; use wasm_host::llm_provider::ExtensionLanguageModelProvider; use wasm_host::{ WasmExtension, WasmHost, - wit::{LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version, wasm_api_version_range}, + wit::{ + LlmCacheConfiguration, LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version, + wasm_api_version_range, + }, }; struct LlmProviderWithModels { provider_info: LlmProviderInfo, models: Vec, + cache_configs: collections::HashMap, is_authenticated: bool, icon_path: Option, auth_config: Option, @@ -1635,6 +1639,32 @@ impl ExtensionStore { } }; + // Query cache configurations for each model + let mut cache_configs = collections::HashMap::default(); + for model in &models { + let cache_config_result = wasm_extension + .call({ + let provider_id = provider_info.id.clone(); + let model_id = model.id.clone(); + |ext, store| { + async move { + ext.call_llm_cache_configuration( + store, + &provider_id, + &model_id, + ) + .await + } + .boxed() + } + }) + .await; + + if let Ok(Ok(Some(config))) = cache_config_result { + cache_configs.insert(model.id.clone(), config); + } + } + // Query initial authentication state let is_authenticated = wasm_extension .call({ @@ -1677,6 +1707,7 @@ impl ExtensionStore { llm_providers_with_models.push(LlmProviderWithModels { provider_info, models, + cache_configs, is_authenticated, icon_path, auth_config, @@ -1776,6 +1807,7 @@ impl ExtensionStore { let wasm_ext = extension.as_ref().clone(); let pinfo = llm_provider.provider_info.clone(); let mods = llm_provider.models.clone(); + let cache_cfgs = llm_provider.cache_configs.clone(); let auth = llm_provider.is_authenticated; let icon = llm_provider.icon_path.clone(); let auth_config = llm_provider.auth_config.clone(); @@ -1784,7 +1816,7 @@ impl ExtensionStore { provider_id.clone(), Box::new(move |cx: &mut App| { let provider = Arc::new(ExtensionLanguageModelProvider::new( - wasm_ext, pinfo, mods, auth, icon, auth_config, cx, + wasm_ext, pinfo, mods, cache_cfgs, auth, icon, auth_config, cx, )); language_model::LanguageModelRegistry::global(cx).update( cx, diff --git a/crates/extension_host/src/wasm_host/llm_provider.rs b/crates/extension_host/src/wasm_host/llm_provider.rs index ca3fd780a83ee704ad6c080984965d4f0aa3f222..6ea6198f7b3f14430e22dae63e9ac1123987cb14 100644 --- a/crates/extension_host/src/wasm_host/llm_provider.rs +++ b/crates/extension_host/src/wasm_host/llm_provider.rs @@ -5,11 +5,12 @@ use crate::wasm_host::wit::LlmDeviceFlowPromptInfo; use collections::HashSet; use crate::wasm_host::wit::{ - LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole, - LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent, - LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent, - LlmToolUse, + LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmImageData, + LlmMessageContent, LlmMessageRole, LlmModelInfo, LlmProviderInfo, LlmRequestMessage, + LlmStopReason, LlmThinkingContent, LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, + LlmToolResult, LlmToolResultContent, LlmToolUse, }; +use collections::HashMap; use anyhow::{Result, anyhow}; use credentials_provider::CredentialsProvider; use extension::{LanguageModelAuthConfig, OAuthConfig}; @@ -58,6 +59,8 @@ pub struct ExtensionLanguageModelProvider { pub struct ExtensionLlmProviderState { is_authenticated: bool, available_models: Vec, + /// Cache configurations for each model, keyed by model ID. + cache_configs: HashMap, /// Set of env var names that are allowed to be read for this provider. allowed_env_vars: HashSet, /// If authenticated via env var, which one was used. @@ -71,6 +74,7 @@ impl ExtensionLanguageModelProvider { extension: WasmExtension, provider_info: LlmProviderInfo, models: Vec, + cache_configs: HashMap, is_authenticated: bool, icon_path: Option, auth_config: Option, @@ -118,6 +122,7 @@ impl ExtensionLanguageModelProvider { let state = cx.new(|_| ExtensionLlmProviderState { is_authenticated, available_models: models, + cache_configs, allowed_env_vars, env_var_name_used, }); @@ -139,6 +144,30 @@ impl ExtensionLanguageModelProvider { fn credential_key(&self) -> String { format!("extension-llm-{}", self.provider_id_string()) } + + fn create_model( + &self, + model_info: &LlmModelInfo, + cache_configs: &HashMap, + ) -> Arc { + let cache_config = cache_configs.get(&model_info.id).map(|config| { + LanguageModelCacheConfiguration { + max_cache_anchors: config.max_cache_anchors as usize, + should_speculate: false, + min_total_token: config.min_total_token_count, + } + }); + + Arc::new(ExtensionLanguageModel { + extension: self.extension.clone(), + model_info: model_info.clone(), + provider_id: self.id(), + provider_name: self.name(), + provider_info: self.provider_info.clone(), + request_limiter: RateLimiter::new(4), + cache_config, + }) + } } impl LanguageModelProvider for ExtensionLanguageModelProvider { @@ -165,16 +194,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider { .iter() .find(|m| m.is_default) .or_else(|| state.available_models.first()) - .map(|model_info| { - Arc::new(ExtensionLanguageModel { - extension: self.extension.clone(), - model_info: model_info.clone(), - provider_id: self.id(), - provider_name: self.name(), - provider_info: self.provider_info.clone(), - request_limiter: RateLimiter::new(4), - }) as Arc - }) + .map(|model_info| self.create_model(model_info, &state.cache_configs)) } fn default_fast_model(&self, cx: &App) -> Option> { @@ -183,16 +203,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider { .available_models .iter() .find(|m| m.is_default_fast) - .map(|model_info| { - Arc::new(ExtensionLanguageModel { - extension: self.extension.clone(), - model_info: model_info.clone(), - provider_id: self.id(), - provider_name: self.name(), - provider_info: self.provider_info.clone(), - request_limiter: RateLimiter::new(4), - }) as Arc - }) + .map(|model_info| self.create_model(model_info, &state.cache_configs)) } fn provided_models(&self, cx: &App) -> Vec> { @@ -200,16 +211,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider { state .available_models .iter() - .map(|model_info| { - Arc::new(ExtensionLanguageModel { - extension: self.extension.clone(), - model_info: model_info.clone(), - provider_id: self.id(), - provider_name: self.name(), - provider_info: self.provider_info.clone(), - request_limiter: RateLimiter::new(4), - }) as Arc - }) + .map(|model_info| self.create_model(model_info, &state.cache_configs)) .collect() } @@ -1595,6 +1597,7 @@ pub struct ExtensionLanguageModel { provider_name: LanguageModelProviderName, provider_info: LlmProviderInfo, request_limiter: RateLimiter, + cache_config: Option, } impl LanguageModel for ExtensionLanguageModel { @@ -1615,7 +1618,7 @@ impl LanguageModel for ExtensionLanguageModel { } fn telemetry_id(&self) -> String { - format!("extension-{}", self.model_info.id) + format!("{}/{}", self.provider_info.id, self.model_info.id) } fn supports_images(&self) -> bool { @@ -1795,8 +1798,7 @@ impl LanguageModel for ExtensionLanguageModel { } fn cache_configuration(&self) -> Option { - // Extensions can implement this via llm_cache_configuration - None + self.cache_config.clone() } } diff --git a/extensions/google-ai/extension.toml b/extensions/google-ai/extension.toml index aebe57f396f39e971f8d7647299ccad75448cce9..9b3a554ab8c9862e26b9618112133d11a194e43c 100644 --- a/extensions/google-ai/extension.toml +++ b/extensions/google-ai/extension.toml @@ -6,8 +6,8 @@ schema_version = 1 authors = ["Zed Team"] repository = "https://github.com/zed-industries/zed" -[language_model_providers.google-ai] +[language_model_providers.google] name = "Google AI" -[language_model_providers.google-ai.auth] +[language_model_providers.google.auth] env_vars = ["GEMINI_API_KEY", "GOOGLE_AI_API_KEY"] diff --git a/extensions/google-ai/src/google_ai.rs b/extensions/google-ai/src/google_ai.rs index 808e15c6576802ba2ced79db775cd264284e1498..b8198694284030eabdd55ab6742c69d7754ae876 100644 --- a/extensions/google-ai/src/google_ai.rs +++ b/extensions/google-ai/src/google_ai.rs @@ -128,7 +128,7 @@ fn validate_generate_content_request(request: &GenerateContentRequest) -> Result // Extension implementation -const PROVIDER_ID: &str = "google-ai"; +const PROVIDER_ID: &str = "google"; const PROVIDER_NAME: &str = "Google AI"; struct GoogleAiExtension { @@ -343,7 +343,7 @@ fn get_default_models() -> Vec { supports_tool_choice_auto: true, supports_tool_choice_any: true, supports_tool_choice_none: true, - supports_thinking: true, + supports_thinking: false, tool_input_format: LlmToolInputFormat::JsonSchemaSubset, }, is_default: false,