Detailed changes
@@ -844,13 +844,20 @@ impl Thread {
if model.supports_tools() {
request.tools = {
let mut tools = Vec::new();
- tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
- LanguageModelRequestTool {
- name: tool.name(),
- description: tool.description(),
- input_schema: tool.input_schema(model.tool_input_format()),
- }
- }));
+ tools.extend(
+ self.tools()
+ .enabled_tools(cx)
+ .into_iter()
+ .filter_map(|tool| {
+ // Skip tools that cannot be supported
+ let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
+ Some(LanguageModelRequestTool {
+ name: tool.name(),
+ description: tool.description(),
+ input_schema,
+ })
+ }),
+ );
tools
};
@@ -1,5 +1,6 @@
mod action_log;
mod tool_registry;
+mod tool_schema;
mod tool_working_set;
use std::fmt;
@@ -16,6 +17,7 @@ use project::Project;
pub use crate::action_log::*;
pub use crate::tool_registry::*;
+pub use crate::tool_schema::*;
pub use crate::tool_working_set::*;
pub fn init(cx: &mut App) {
@@ -51,8 +53,8 @@ pub trait Tool: 'static + Send + Sync {
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
/// Returns the JSON schema that describes the tool's input.
- fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value {
- serde_json::Value::Object(serde_json::Map::default())
+ fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
+ Ok(serde_json::Value::Object(serde_json::Map::default()))
}
/// Returns markdown to be displayed in the UI for this tool.
@@ -0,0 +1,236 @@
+use anyhow::Result;
+use serde_json::Value;
+
+use crate::LanguageModelToolSchemaFormat;
+
+/// Tries to adapt a JSON schema representation to be compatible with the specified format.
+///
+/// If the json cannot be made compatible with the specified format, an error is returned.
+pub fn adapt_schema_to_format(
+ json: &mut Value,
+ format: LanguageModelToolSchemaFormat,
+) -> Result<()> {
+ match format {
+ LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
+ LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
+ }
+}
+
+/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
+fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
+ if let Value::Object(obj) = json {
+ const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
+
+ for key in UNSUPPORTED_KEYS {
+ if obj.contains_key(key) {
+ return Err(anyhow::anyhow!(
+ "Schema cannot be made compatible because it contains \"{}\" ",
+ key
+ ));
+ }
+ }
+
+ const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
+ for key in KEYS_TO_REMOVE {
+ obj.remove(key);
+ }
+
+ if let Some(default) = obj.get("default") {
+ let is_null = default.is_null();
+ // Default is not supported, so we need to remove it
+ obj.remove("default");
+ if is_null {
+ obj.insert("nullable".to_string(), Value::Bool(true));
+ }
+ }
+
+ // If a type is not specified for an input parameter, add a default type
+ if obj.contains_key("description")
+ && !obj.contains_key("type")
+ && !(obj.contains_key("anyOf")
+ || obj.contains_key("oneOf")
+ || obj.contains_key("allOf"))
+ {
+ obj.insert("type".to_string(), Value::String("string".to_string()));
+ }
+
+ // Handle oneOf -> anyOf conversion
+ if let Some(subschemas) = obj.get_mut("oneOf") {
+ if subschemas.is_array() {
+ let subschemas_clone = subschemas.clone();
+ obj.remove("oneOf");
+ obj.insert("anyOf".to_string(), subschemas_clone);
+ }
+ }
+
+ // Recursively process all nested objects and arrays
+ for (_, value) in obj.iter_mut() {
+ if let Value::Object(_) | Value::Array(_) = value {
+ adapt_to_json_schema_subset(value)?;
+ }
+ }
+ } else if let Value::Array(arr) = json {
+ for item in arr.iter_mut() {
+ adapt_to_json_schema_subset(item)?;
+ }
+ }
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use serde_json::json;
+
+ #[test]
+ fn test_transform_default_null_to_nullable() {
+ let mut json = json!({
+ "description": "A test field",
+ "type": "string",
+ "default": null
+ });
+
+ adapt_to_json_schema_subset(&mut json).unwrap();
+
+ assert_eq!(
+ json,
+ json!({
+ "description": "A test field",
+ "type": "string",
+ "nullable": true
+ })
+ );
+ }
+
+ #[test]
+ fn test_transform_adds_type_when_missing() {
+ let mut json = json!({
+ "description": "A test field without type"
+ });
+
+ adapt_to_json_schema_subset(&mut json).unwrap();
+
+ assert_eq!(
+ json,
+ json!({
+ "description": "A test field without type",
+ "type": "string"
+ })
+ );
+ }
+
+ #[test]
+ fn test_transform_removes_format() {
+ let mut json = json!({
+ "description": "A test field",
+ "type": "integer",
+ "format": "uint32"
+ });
+
+ adapt_to_json_schema_subset(&mut json).unwrap();
+
+ assert_eq!(
+ json,
+ json!({
+ "description": "A test field",
+ "type": "integer"
+ })
+ );
+ }
+
+ #[test]
+ fn test_transform_one_of_to_any_of() {
+ let mut json = json!({
+ "description": "A test field",
+ "oneOf": [
+ { "type": "string" },
+ { "type": "integer" }
+ ]
+ });
+
+ adapt_to_json_schema_subset(&mut json).unwrap();
+
+ assert_eq!(
+ json,
+ json!({
+ "description": "A test field",
+ "anyOf": [
+ { "type": "string" },
+ { "type": "integer" }
+ ]
+ })
+ );
+ }
+
+ #[test]
+ fn test_transform_nested_objects() {
+ let mut json = json!({
+ "type": "object",
+ "properties": {
+ "nested": {
+ "oneOf": [
+ { "type": "string" },
+ { "type": "null" }
+ ],
+ "format": "email"
+ }
+ }
+ });
+
+ adapt_to_json_schema_subset(&mut json).unwrap();
+
+ assert_eq!(
+ json,
+ json!({
+ "type": "object",
+ "properties": {
+ "nested": {
+ "anyOf": [
+ { "type": "string" },
+ { "type": "null" }
+ ]
+ }
+ }
+ })
+ );
+ }
+
+ #[test]
+ fn test_transform_fails_if_unsupported_keys_exist() {
+ let mut json = json!({
+ "type": "object",
+ "properties": {
+ "$ref": "#/definitions/User",
+ }
+ });
+
+ assert!(adapt_to_json_schema_subset(&mut json).is_err());
+
+ let mut json = json!({
+ "type": "object",
+ "properties": {
+ "if": "...",
+ }
+ });
+
+ assert!(adapt_to_json_schema_subset(&mut json).is_err());
+
+ let mut json = json!({
+ "type": "object",
+ "properties": {
+ "then": "...",
+ }
+ });
+
+ assert!(adapt_to_json_schema_subset(&mut json).is_err());
+
+ let mut json = json!({
+ "type": "object",
+ "properties": {
+ "else": "...",
+ }
+ });
+
+ assert!(adapt_to_json_schema_subset(&mut json).is_err());
+ }
+}
@@ -84,7 +84,7 @@ mod tests {
use super::*;
#[gpui::test]
- fn test_tool_schema_compatibility(cx: &mut App) {
+ fn test_builtin_tool_schema_compatibility(cx: &mut App) {
crate::init(
Arc::new(http_client::HttpClientWithUrl::new(
FakeHttpClient::with_200_response(),
@@ -95,18 +95,23 @@ mod tests {
);
for tool in ToolRegistry::global(cx).tools() {
- let schema =
- tool.input_schema(language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset);
- assert!(schema.is_object());
- if schema.as_object().unwrap().contains_key("$schema") {
- let error_message = format!(
- "Tool schema for `{}` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).\n\
- Are you using `schema::json_schema_for<T>(format)` to generate the schema?",
- tool.name()
- );
+ let actual_schema = tool
+ .input_schema(language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset)
+ .unwrap();
+ let mut expected_schema = actual_schema.clone();
+ assistant_tool::adapt_schema_to_format(
+ &mut expected_schema,
+ language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
+ )
+ .unwrap();
- panic!("{}", error_message)
- }
+ let error_message = format!(
+ "Tool schema for `{}` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).\n\
+ Are you using `schema::json_schema_for<T>(format)` to generate the schema?",
+ tool.name(),
+ );
+
+ assert_eq!(actual_schema, expected_schema, "{}", error_message)
}
}
}
@@ -172,7 +172,7 @@ impl Tool for BatchTool {
IconName::Cog
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<BatchToolInput>(format)
}
@@ -2,7 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language::{self, Anchor, Buffer, ToPointUtf16};
-use language_model::LanguageModelRequestMessage;
+use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{self, LspAction, Project};
use regex::Regex;
use schemars::JsonSchema;
@@ -97,10 +97,7 @@ impl Tool for CodeActionTool {
IconName::Wand
}
- fn input_schema(
- &self,
- format: language_model::LanguageModelToolSchemaFormat,
- ) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CodeActionToolInput>(format)
}
@@ -91,7 +91,7 @@ impl Tool for CodeSymbolsTool {
IconName::Code
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CodeSymbolsInput>(format)
}
@@ -55,7 +55,7 @@ impl Tool for CopyPathTool {
IconName::Clipboard
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CopyPathToolInput>(format)
}
@@ -45,7 +45,7 @@ impl Tool for CreateDirectoryTool {
IconName::Folder
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CreateDirectoryToolInput>(format)
}
@@ -52,7 +52,7 @@ impl Tool for CreateFileTool {
IconName::FileCreate
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CreateFileToolInput>(format)
}
@@ -45,7 +45,7 @@ impl Tool for DeletePathTool {
IconName::FileDelete
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<DeletePathToolInput>(format)
}
@@ -58,7 +58,7 @@ impl Tool for DiagnosticsTool {
IconName::XCircle
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<DiagnosticsToolInput>(format)
}
@@ -128,7 +128,7 @@ impl Tool for FetchTool {
IconName::Globe
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<FetchToolInput>(format)
}
@@ -151,7 +151,7 @@ impl Tool for FindReplaceFileTool {
IconName::Pencil
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<FindReplaceFileToolInput>(format)
}
@@ -56,7 +56,7 @@ impl Tool for ListDirectoryTool {
IconName::Folder
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ListDirectoryToolInput>(format)
}
@@ -54,7 +54,7 @@ impl Tool for MovePathTool {
IconName::ArrowRightLeft
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<MovePathToolInput>(format)
}
@@ -45,7 +45,7 @@ impl Tool for NowTool {
IconName::Info
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<NowToolInput>(format)
}
@@ -35,7 +35,7 @@ impl Tool for OpenTool {
IconName::ArrowUpRight
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<OpenToolInput>(format)
}
@@ -53,7 +53,7 @@ impl Tool for PathSearchTool {
IconName::SearchCode
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<PathSearchToolInput>(format)
}
@@ -63,7 +63,7 @@ impl Tool for ReadFileTool {
IconName::FileSearch
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ReadFileToolInput>(format)
}
@@ -60,7 +60,7 @@ impl Tool for RegexSearchTool {
IconName::Regex
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<RegexSearchToolInput>(format)
}
@@ -2,7 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language::{self, Buffer, ToPointUtf16};
-use language_model::LanguageModelRequestMessage;
+use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -68,10 +68,7 @@ impl Tool for RenameTool {
IconName::Pencil
}
- fn input_schema(
- &self,
- format: language_model::LanguageModelToolSchemaFormat,
- ) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<RenameToolInput>(format)
}
@@ -5,23 +5,20 @@ use schemars::{
schema::{RootSchema, Schema, SchemaObject},
};
-pub fn json_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+pub fn json_schema_for<T: JsonSchema>(
+ format: LanguageModelToolSchemaFormat,
+) -> Result<serde_json::Value> {
let schema = root_schema_for::<T>(format);
- schema_to_json(&schema, format).expect("Failed to convert tool calling schema to JSON")
+ schema_to_json(&schema, format)
}
-pub fn schema_to_json(
+fn schema_to_json(
schema: &RootSchema,
format: LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let mut value = serde_json::to_value(schema)?;
- match format {
- LanguageModelToolSchemaFormat::JsonSchema => Ok(value),
- LanguageModelToolSchemaFormat::JsonSchemaSubset => {
- transform_fields_to_json_schema_subset(&mut value);
- Ok(value)
- }
- }
+ assistant_tool::adapt_schema_to_format(&mut value, format)?;
+ Ok(value)
}
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> RootSchema {
@@ -79,42 +76,3 @@ impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor {
schemars::visit::visit_schema_object(self, schema)
}
}
-
-fn transform_fields_to_json_schema_subset(json: &mut serde_json::Value) {
- if let serde_json::Value::Object(obj) = json {
- if let Some(default) = obj.get("default") {
- let is_null = default.is_null();
- //Default is not supported, so we need to remove it.
- obj.remove("default");
- if is_null {
- obj.insert("nullable".to_string(), serde_json::Value::Bool(true));
- }
- }
-
- // If a type is not specified for an input parameter we need to add it.
- if obj.contains_key("description")
- && !obj.contains_key("type")
- && !(obj.contains_key("anyOf")
- || obj.contains_key("oneOf")
- || obj.contains_key("allOf"))
- {
- obj.insert(
- "type".to_string(),
- serde_json::Value::String("string".to_string()),
- );
- }
-
- //Format field is only partially supported (e.g. not uint compatibility)
- obj.remove("format");
-
- for (_, value) in obj.iter_mut() {
- if let serde_json::Value::Object(_) | serde_json::Value::Array(_) = value {
- transform_fields_to_json_schema_subset(value);
- }
- }
- } else if let serde_json::Value::Array(arr) = json {
- for item in arr.iter_mut() {
- transform_fields_to_json_schema_subset(item);
- }
- }
-}
@@ -84,7 +84,7 @@ impl Tool for SymbolInfoTool {
IconName::Code
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<SymbolInfoToolInput>(format)
}
@@ -44,7 +44,7 @@ impl Tool for TerminalTool {
IconName::Terminal
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<TerminalToolInput>(format)
}
@@ -36,7 +36,7 @@ impl Tool for ThinkingTool {
IconName::LightBulb
}
- fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ThinkingToolInput>(format)
}
@@ -53,16 +53,18 @@ impl Tool for ContextServerTool {
true
}
- fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value {
- match &self.tool.input_schema {
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
+ let mut schema = self.tool.input_schema.clone();
+ assistant_tool::adapt_schema_to_format(&mut schema, format)?;
+ Ok(match schema {
serde_json::Value::Null => {
serde_json::json!({ "type": "object", "properties": [] })
}
serde_json::Value::Object(map) if map.is_empty() => {
serde_json::json!({ "type": "object", "properties": [] })
}
- _ => self.tool.input_schema.clone(),
- }
+ _ => schema,
+ })
}
fn ui_text(&self, _input: &serde_json::Value) -> String {