From 01a2c8eb017dfac688784ec12c7e5bc81ddab0a8 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Mon, 31 Mar 2025 18:49:59 +0200 Subject: [PATCH] Set tool schema format for zed.dev language model (#27788) Release Notes: - N/A --- crates/language_model/src/model/cloud_model.rs | 9 ++++++++- crates/language_models/src/provider/cloud.rs | 8 ++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 21627f38e2b618b6be2d2a85d0c661f5ff0ae0e3..fcb67b0c6d32306f708c92987e1ca18f36e6b133 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -14,7 +14,7 @@ use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use strum::EnumIter; use thiserror::Error; -use crate::LanguageModelAvailability; +use crate::{LanguageModelAvailability, LanguageModelToolSchemaFormat}; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[serde(tag = "provider", rename_all = "lowercase")] @@ -113,6 +113,13 @@ impl CloudModel { }, } } + + pub fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + match self { + Self::Anthropic(_) | Self::OpenAi(_) => LanguageModelToolSchemaFormat::JsonSchema, + Self::Google(_) => LanguageModelToolSchemaFormat::JsonSchemaSubset, + } + } } #[derive(Error, Debug)] diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index bc3cc8718103fb3c44f38b299c38f0a5b81cb088..376d5cd3f5822b3f4379175a7cbb3fa4fc5bf24d 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -15,8 +15,8 @@ use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use language_model::{ AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, RateLimiter, - ZED_CLOUD_PROVIDER_ID, + LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, + LanguageModelToolSchemaFormat, RateLimiter, ZED_CLOUD_PROVIDER_ID, }; use language_model::{ LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, @@ -559,6 +559,10 @@ impl LanguageModel for CloudLanguageModel { self.model.availability() } + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + self.model.tool_input_format() + } + fn max_token_count(&self) -> usize { self.model.max_token_count() }