diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 29709490c9973b3311aa60577c12f1723a39be9f..34fbecc44bc73c3da9be8a228e20f17ac5304f36 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -931,26 +931,21 @@ impl Thread { let mut request = self.to_completion_request(cx); if model.supports_tools() { - request.tools = { - let mut tools = Vec::new(); - tools.extend( - self.tools() - .read(cx) - .enabled_tools(cx) - .into_iter() - .filter_map(|tool| { - // Skip tools that cannot be supported - let input_schema = tool.input_schema(model.tool_input_format()).ok()?; - Some(LanguageModelRequestTool { - name: tool.name(), - description: tool.description(), - input_schema, - }) - }), - ); - - tools - }; + request.tools = self + .tools() + .read(cx) + .enabled_tools(cx) + .into_iter() + .filter_map(|tool| { + // Skip tools that cannot be supported + let input_schema = tool.input_schema(model.tool_input_format()).ok()?; + Some(LanguageModelRequestTool { + name: tool.name(), + description: tool.description(), + input_schema, + }) + }) + .collect(); } self.stream_completion(request, model, window, cx); diff --git a/crates/assistant_tool/src/tool_schema.rs b/crates/assistant_tool/src/tool_schema.rs index 225c1c22efd5504b5406c832a148255bc4e0f3ea..c7d7ba1c33e4b4858ce6d7d8b831104e0a53ff55 100644 --- a/crates/assistant_tool/src/tool_schema.rs +++ b/crates/assistant_tool/src/tool_schema.rs @@ -10,6 +10,11 @@ pub fn adapt_schema_to_format( json: &mut Value, format: LanguageModelToolSchemaFormat, ) -> Result<()> { + if let Value::Object(obj) = json { + obj.remove("$schema"); + obj.remove("title"); + } + match format { LanguageModelToolSchemaFormat::JsonSchema => Ok(()), LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json), @@ -30,10 +35,7 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { } } - const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"]; - for key in KEYS_TO_REMOVE { - obj.remove(key); - } + obj.remove("format"); if let Some(default) = obj.get("default") { let is_null = default.is_null(); diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 1eca808f6da6d043be8faa797d153b380f79d8a8..edb435d7d0a0c01d3e8b536d532ffeb4d3bea7aa 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -110,11 +110,38 @@ pub fn init(http_client: Arc, cx: &mut App) { #[cfg(test)] mod tests { + use super::*; use client::Client; use clock::FakeSystemClock; use http_client::FakeHttpClient; + use schemars::JsonSchema; + use serde::Serialize; - use super::*; + #[test] + fn test_json_schema() { + #[derive(Serialize, JsonSchema)] + struct GetWeatherTool { + location: String, + } + + let schema = schema::json_schema_for::( + language_model::LanguageModelToolSchemaFormat::JsonSchema, + ) + .unwrap(); + + assert_eq!( + schema, + serde_json::json!({ + "type": "object", + "properties": { + "location": { + "type": "string" + } + }, + "required": ["location"], + }) + ); + } #[gpui::test] fn test_builtin_tool_schema_compatibility(cx: &mut App) {