@@ -14,7 +14,7 @@ use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use strum::EnumIter;
use thiserror::Error;
-use crate::LanguageModelAvailability;
+use crate::{LanguageModelAvailability, LanguageModelToolSchemaFormat};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
@@ -113,6 +113,13 @@ impl CloudModel {
},
}
}
+
+ pub fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
+ match self {
+ Self::Anthropic(_) | Self::OpenAi(_) => LanguageModelToolSchemaFormat::JsonSchema,
+ Self::Google(_) => LanguageModelToolSchemaFormat::JsonSchemaSubset,
+ }
+ }
}
#[derive(Error, Debug)]
@@ -15,8 +15,8 @@ use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, RateLimiter,
- ZED_CLOUD_PROVIDER_ID,
+ LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
+ LanguageModelToolSchemaFormat, RateLimiter, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -559,6 +559,10 @@ impl LanguageModel for CloudLanguageModel {
self.model.availability()
}
+ fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
+ self.model.tool_input_format()
+ }
+
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}