More Gemini extension fixes

Created by Richard Feldman

Change summary

crates/extension_host/src/extension_host.rs         | 36 ++++++
crates/extension_host/src/wasm_host/llm_provider.rs | 76 +++++++-------
extensions/google-ai/extension.toml                 |  4 
extensions/google-ai/src/google_ai.rs               |  4 
4 files changed, 77 insertions(+), 43 deletions(-)

Detailed changes

crates/extension_host/src/extension_host.rs 🔗

@@ -66,12 +66,16 @@ use util::{ResultExt, paths::RemotePathBuf};
 use wasm_host::llm_provider::ExtensionLanguageModelProvider;
 use wasm_host::{
     WasmExtension, WasmHost,
-    wit::{LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version, wasm_api_version_range},
+    wit::{
+        LlmCacheConfiguration, LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version,
+        wasm_api_version_range,
+    },
 };
 
 struct LlmProviderWithModels {
     provider_info: LlmProviderInfo,
     models: Vec<LlmModelInfo>,
+    cache_configs: collections::HashMap<String, LlmCacheConfiguration>,
     is_authenticated: bool,
     icon_path: Option<SharedString>,
     auth_config: Option<extension::LanguageModelAuthConfig>,
@@ -1635,6 +1639,32 @@ impl ExtensionStore {
                                         }
                                     };
 
+                                    // Query cache configurations for each model
+                                    let mut cache_configs = collections::HashMap::default();
+                                    for model in &models {
+                                        let cache_config_result = wasm_extension
+                                            .call({
+                                                let provider_id = provider_info.id.clone();
+                                                let model_id = model.id.clone();
+                                                |ext, store| {
+                                                    async move {
+                                                        ext.call_llm_cache_configuration(
+                                                            store,
+                                                            &provider_id,
+                                                            &model_id,
+                                                        )
+                                                        .await
+                                                    }
+                                                    .boxed()
+                                                }
+                                            })
+                                            .await;
+
+                                        if let Ok(Ok(Some(config))) = cache_config_result {
+                                            cache_configs.insert(model.id.clone(), config);
+                                        }
+                                    }
+
                                     // Query initial authentication state
                                     let is_authenticated = wasm_extension
                                         .call({
@@ -1677,6 +1707,7 @@ impl ExtensionStore {
                                     llm_providers_with_models.push(LlmProviderWithModels {
                                         provider_info,
                                         models,
+                                        cache_configs,
                                         is_authenticated,
                                         icon_path,
                                         auth_config,
@@ -1776,6 +1807,7 @@ impl ExtensionStore {
                         let wasm_ext = extension.as_ref().clone();
                         let pinfo = llm_provider.provider_info.clone();
                         let mods = llm_provider.models.clone();
+                        let cache_cfgs = llm_provider.cache_configs.clone();
                         let auth = llm_provider.is_authenticated;
                         let icon = llm_provider.icon_path.clone();
                         let auth_config = llm_provider.auth_config.clone();
@@ -1784,7 +1816,7 @@ impl ExtensionStore {
                             provider_id.clone(),
                             Box::new(move |cx: &mut App| {
                                 let provider = Arc::new(ExtensionLanguageModelProvider::new(
-                                    wasm_ext, pinfo, mods, auth, icon, auth_config, cx,
+                                    wasm_ext, pinfo, mods, cache_cfgs, auth, icon, auth_config, cx,
                                 ));
                                 language_model::LanguageModelRegistry::global(cx).update(
                                     cx,

crates/extension_host/src/wasm_host/llm_provider.rs 🔗

@@ -5,11 +5,12 @@ use crate::wasm_host::wit::LlmDeviceFlowPromptInfo;
 use collections::HashSet;
 
 use crate::wasm_host::wit::{
-    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
-    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
-    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
-    LlmToolUse,
+    LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmImageData,
+    LlmMessageContent, LlmMessageRole, LlmModelInfo, LlmProviderInfo, LlmRequestMessage,
+    LlmStopReason, LlmThinkingContent, LlmToolChoice, LlmToolDefinition, LlmToolInputFormat,
+    LlmToolResult, LlmToolResultContent, LlmToolUse,
 };
+use collections::HashMap;
 use anyhow::{Result, anyhow};
 use credentials_provider::CredentialsProvider;
 use extension::{LanguageModelAuthConfig, OAuthConfig};
@@ -58,6 +59,8 @@ pub struct ExtensionLanguageModelProvider {
 pub struct ExtensionLlmProviderState {
     is_authenticated: bool,
     available_models: Vec<LlmModelInfo>,
+    /// Cache configurations for each model, keyed by model ID.
+    cache_configs: HashMap<String, LlmCacheConfiguration>,
     /// Set of env var names that are allowed to be read for this provider.
     allowed_env_vars: HashSet<String>,
     /// If authenticated via env var, which one was used.
@@ -71,6 +74,7 @@ impl ExtensionLanguageModelProvider {
         extension: WasmExtension,
         provider_info: LlmProviderInfo,
         models: Vec<LlmModelInfo>,
+        cache_configs: HashMap<String, LlmCacheConfiguration>,
         is_authenticated: bool,
         icon_path: Option<SharedString>,
         auth_config: Option<LanguageModelAuthConfig>,
@@ -118,6 +122,7 @@ impl ExtensionLanguageModelProvider {
         let state = cx.new(|_| ExtensionLlmProviderState {
             is_authenticated,
             available_models: models,
+            cache_configs,
             allowed_env_vars,
             env_var_name_used,
         });
@@ -139,6 +144,30 @@ impl ExtensionLanguageModelProvider {
     fn credential_key(&self) -> String {
         format!("extension-llm-{}", self.provider_id_string())
     }
+
+    fn create_model(
+        &self,
+        model_info: &LlmModelInfo,
+        cache_configs: &HashMap<String, LlmCacheConfiguration>,
+    ) -> Arc<dyn LanguageModel> {
+        let cache_config = cache_configs.get(&model_info.id).map(|config| {
+            LanguageModelCacheConfiguration {
+                max_cache_anchors: config.max_cache_anchors as usize,
+                should_speculate: false,
+                min_total_token: config.min_total_token_count,
+            }
+        });
+
+        Arc::new(ExtensionLanguageModel {
+            extension: self.extension.clone(),
+            model_info: model_info.clone(),
+            provider_id: self.id(),
+            provider_name: self.name(),
+            provider_info: self.provider_info.clone(),
+            request_limiter: RateLimiter::new(4),
+            cache_config,
+        })
+    }
 }
 
 impl LanguageModelProvider for ExtensionLanguageModelProvider {
@@ -165,16 +194,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
             .iter()
             .find(|m| m.is_default)
             .or_else(|| state.available_models.first())
-            .map(|model_info| {
-                Arc::new(ExtensionLanguageModel {
-                    extension: self.extension.clone(),
-                    model_info: model_info.clone(),
-                    provider_id: self.id(),
-                    provider_name: self.name(),
-                    provider_info: self.provider_info.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model_info| self.create_model(model_info, &state.cache_configs))
     }
 
     fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -183,16 +203,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
             .available_models
             .iter()
             .find(|m| m.is_default_fast)
-            .map(|model_info| {
-                Arc::new(ExtensionLanguageModel {
-                    extension: self.extension.clone(),
-                    model_info: model_info.clone(),
-                    provider_id: self.id(),
-                    provider_name: self.name(),
-                    provider_info: self.provider_info.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model_info| self.create_model(model_info, &state.cache_configs))
     }
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -200,16 +211,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
         state
             .available_models
             .iter()
-            .map(|model_info| {
-                Arc::new(ExtensionLanguageModel {
-                    extension: self.extension.clone(),
-                    model_info: model_info.clone(),
-                    provider_id: self.id(),
-                    provider_name: self.name(),
-                    provider_info: self.provider_info.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model_info| self.create_model(model_info, &state.cache_configs))
             .collect()
     }
 
@@ -1595,6 +1597,7 @@ pub struct ExtensionLanguageModel {
     provider_name: LanguageModelProviderName,
     provider_info: LlmProviderInfo,
     request_limiter: RateLimiter,
+    cache_config: Option<LanguageModelCacheConfiguration>,
 }
 
 impl LanguageModel for ExtensionLanguageModel {
@@ -1615,7 +1618,7 @@ impl LanguageModel for ExtensionLanguageModel {
     }
 
     fn telemetry_id(&self) -> String {
-        format!("extension-{}", self.model_info.id)
+        format!("{}/{}", self.provider_info.id, self.model_info.id)
     }
 
     fn supports_images(&self) -> bool {
@@ -1795,8 +1798,7 @@ impl LanguageModel for ExtensionLanguageModel {
     }
 
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
-        // Extensions can implement this via llm_cache_configuration
-        None
+        self.cache_config.clone()
     }
 }
 

extensions/google-ai/extension.toml 🔗

@@ -6,8 +6,8 @@ schema_version = 1
 authors = ["Zed Team"]
 repository = "https://github.com/zed-industries/zed"
 
-[language_model_providers.google-ai]
+[language_model_providers.google]
 name = "Google AI"
 
-[language_model_providers.google-ai.auth]
+[language_model_providers.google.auth]
 env_vars = ["GEMINI_API_KEY", "GOOGLE_AI_API_KEY"]

extensions/google-ai/src/google_ai.rs 🔗

@@ -128,7 +128,7 @@ fn validate_generate_content_request(request: &GenerateContentRequest) -> Result
 
 // Extension implementation
 
-const PROVIDER_ID: &str = "google-ai";
+const PROVIDER_ID: &str = "google";
 const PROVIDER_NAME: &str = "Google AI";
 
 struct GoogleAiExtension {
@@ -343,7 +343,7 @@ fn get_default_models() -> Vec<LlmModelInfo> {
                 supports_tool_choice_auto: true,
                 supports_tool_choice_any: true,
                 supports_tool_choice_none: true,
-                supports_thinking: true,
+                supports_thinking: false,
                 tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
             },
             is_default: false,