More minor Google AI fixes

Richard Feldman created

Change summary

crates/extension_api/wit/since_v0.8.0/llm-provider.wit |  3 +
crates/extension_host/src/wasm_host/llm_provider.rs    |  3 +
extensions/google-ai/src/google_ai.rs                  | 20 +++++++----
3 files changed, 19 insertions(+), 7 deletions(-)

Detailed changes

crates/extension_api/wit/since_v0.8.0/llm-provider.wit 🔗

@@ -33,6 +33,9 @@ interface llm-provider {
     enum tool-input-format {
         /// Standard JSON Schema format.
         json-schema,
+        /// A subset of JSON Schema supported by Google AI.
+        /// See https://ai.google.dev/api/caching#Schema
+        json-schema-subset,
         /// Simplified schema format for certain providers.
         simplified,
     }

crates/extension_host/src/wasm_host/llm_provider.rs 🔗

@@ -1637,6 +1637,9 @@ impl LanguageModel for ExtensionLanguageModel {
     fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
         match self.model_info.capabilities.tool_input_format {
             LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
+            LlmToolInputFormat::JsonSchemaSubset => {
+                LanguageModelToolSchemaFormat::JsonSchemaSubset
+            }
             LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
         }
     }

extensions/google-ai/src/google_ai.rs 🔗

@@ -1,4 +1,7 @@
 use std::collections::HashMap;
+use std::sync::atomic::{AtomicU64, Ordering};
+
+static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
 
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use zed_extension_api::{
@@ -273,7 +276,7 @@ fn get_default_models() -> Vec<LlmModelInfo> {
                 supports_tool_choice_any: true,
                 supports_tool_choice_none: true,
                 supports_thinking: true,
-                tool_input_format: LlmToolInputFormat::JsonSchema,
+                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
             },
             is_default: false,
             is_default_fast: true,
@@ -290,7 +293,7 @@ fn get_default_models() -> Vec<LlmModelInfo> {
                 supports_tool_choice_any: true,
                 supports_tool_choice_none: true,
                 supports_thinking: true,
-                tool_input_format: LlmToolInputFormat::JsonSchema,
+                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
             },
             is_default: true,
             is_default_fast: false,
@@ -307,7 +310,7 @@ fn get_default_models() -> Vec<LlmModelInfo> {
                 supports_tool_choice_any: true,
                 supports_tool_choice_none: true,
                 supports_thinking: true,
-                tool_input_format: LlmToolInputFormat::JsonSchema,
+                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
             },
             is_default: false,
             is_default_fast: false,
@@ -324,7 +327,7 @@ fn get_default_models() -> Vec<LlmModelInfo> {
                 supports_tool_choice_any: true,
                 supports_tool_choice_none: true,
                 supports_thinking: true,
-                tool_input_format: LlmToolInputFormat::JsonSchema,
+                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
             },
             is_default: false,
             is_default_fast: false,
@@ -341,7 +344,7 @@ fn get_default_models() -> Vec<LlmModelInfo> {
                 supports_tool_choice_any: true,
                 supports_tool_choice_none: true,
                 supports_thinking: true,
-                tool_input_format: LlmToolInputFormat::JsonSchema,
+                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
             },
             is_default: false,
             is_default_fast: false,
@@ -405,7 +408,7 @@ fn get_models() -> Vec<LlmModelInfo> {
                 supports_tool_choice_any: true,
                 supports_tool_choice_none: true,
                 supports_thinking: custom_model.thinking_budget.is_some(),
-                tool_input_format: LlmToolInputFormat::JsonSchema,
+                tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
             },
             is_default: false,
             is_default_fast: false,
@@ -517,8 +520,11 @@ fn stream_generate_content_next(
                                     // Normalize empty string signatures to None
                                     let thought_signature =
                                         fc_part.thought_signature.filter(|s| !s.is_empty());
+                                    // Generate unique tool use ID like hardcoded implementation
+                                    let next_tool_id = TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
+                                    let tool_use_id = format!("{}-{}", fc_part.function_call.name, next_tool_id);
                                     return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
-                                        id: fc_part.function_call.name.clone(),
+                                        id: tool_use_id,
                                         name: fc_part.function_call.name,
                                         input: serde_json::to_string(&fc_part.function_call.args)
                                             .unwrap_or_default(),