tool_schema.rs

  1use anyhow::Result;
  2use serde_json::Value;
  3
  4use crate::LanguageModelToolSchemaFormat;
  5
  6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
  7///
  8/// If the json cannot be made compatible with the specified format, an error is returned.
  9pub fn adapt_schema_to_format(
 10    json: &mut Value,
 11    format: LanguageModelToolSchemaFormat,
 12) -> Result<()> {
 13    if let Value::Object(obj) = json {
 14        obj.remove("$schema");
 15        obj.remove("title");
 16    }
 17
 18    match format {
 19        LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
 20        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 21    }
 22}
 23
 24fn preprocess_json_schema(json: &mut Value) -> Result<()> {
 25    // `additionalProperties` defaults to `false` unless explicitly specified.
 26    // This prevents models from hallucinating tool parameters.
 27    if let Value::Object(obj) = json {
 28        if let Some(Value::String(type_str)) = obj.get("type") {
 29            if type_str == "object" && !obj.contains_key("additionalProperties") {
 30                obj.insert("additionalProperties".to_string(), Value::Bool(false));
 31            }
 32        }
 33    }
 34    Ok(())
 35}
 36
 37/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 38fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 39    if let Value::Object(obj) = json {
 40        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 41
 42        for key in UNSUPPORTED_KEYS {
 43            anyhow::ensure!(
 44                !obj.contains_key(key),
 45                "Schema cannot be made compatible because it contains \"{key}\""
 46            );
 47        }
 48
 49        const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
 50            ("format", |value| value.is_string()),
 51            ("additionalProperties", |value| value.is_boolean()),
 52            ("exclusiveMinimum", |value| value.is_number()),
 53            ("exclusiveMaximum", |value| value.is_number()),
 54            ("optional", |value| value.is_boolean()),
 55        ];
 56        for (key, predicate) in KEYS_TO_REMOVE {
 57            if let Some(value) = obj.get(key) {
 58                if predicate(value) {
 59                    obj.remove(key);
 60                }
 61            }
 62        }
 63
 64        // If a type is not specified for an input parameter, add a default type
 65        if matches!(obj.get("description"), Some(Value::String(_)))
 66            && !obj.contains_key("type")
 67            && !(obj.contains_key("anyOf")
 68                || obj.contains_key("oneOf")
 69                || obj.contains_key("allOf"))
 70        {
 71            obj.insert("type".to_string(), Value::String("string".to_string()));
 72        }
 73
 74        // Handle oneOf -> anyOf conversion
 75        if let Some(subschemas) = obj.get_mut("oneOf") {
 76            if subschemas.is_array() {
 77                let subschemas_clone = subschemas.clone();
 78                obj.remove("oneOf");
 79                obj.insert("anyOf".to_string(), subschemas_clone);
 80            }
 81        }
 82
 83        // Recursively process all nested objects and arrays
 84        for (_, value) in obj.iter_mut() {
 85            if let Value::Object(_) | Value::Array(_) = value {
 86                adapt_to_json_schema_subset(value)?;
 87            }
 88        }
 89    } else if let Value::Array(arr) = json {
 90        for item in arr.iter_mut() {
 91            adapt_to_json_schema_subset(item)?;
 92        }
 93    }
 94    Ok(())
 95}
 96
 97#[cfg(test)]
 98mod tests {
 99    use super::*;
100    use serde_json::json;
101
102    #[test]
103    fn test_transform_adds_type_when_missing() {
104        let mut json = json!({
105            "description": "A test field without type"
106        });
107
108        adapt_to_json_schema_subset(&mut json).unwrap();
109
110        assert_eq!(
111            json,
112            json!({
113                "description": "A test field without type",
114                "type": "string"
115            })
116        );
117
118        // Ensure that we do not add a type if it is an object
119        let mut json = json!({
120            "description": {
121                "value": "abc",
122                "type": "string"
123            }
124        });
125
126        adapt_to_json_schema_subset(&mut json).unwrap();
127
128        assert_eq!(
129            json,
130            json!({
131                "description": {
132                    "value": "abc",
133                    "type": "string"
134                }
135            })
136        );
137    }
138
139    #[test]
140    fn test_transform_removes_unsupported_keys() {
141        let mut json = json!({
142            "description": "A test field",
143            "type": "integer",
144            "format": "uint32",
145            "exclusiveMinimum": 0,
146            "exclusiveMaximum": 100,
147            "additionalProperties": false,
148            "optional": true
149        });
150
151        adapt_to_json_schema_subset(&mut json).unwrap();
152
153        assert_eq!(
154            json,
155            json!({
156                "description": "A test field",
157                "type": "integer"
158            })
159        );
160
161        // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
162        let mut json = json!({
163            "description": "A test field",
164            "type": "integer",
165            "format": {},
166        });
167
168        adapt_to_json_schema_subset(&mut json).unwrap();
169
170        assert_eq!(
171            json,
172            json!({
173                "description": "A test field",
174                "type": "integer",
175                "format": {},
176            })
177        );
178    }
179
180    #[test]
181    fn test_transform_one_of_to_any_of() {
182        let mut json = json!({
183            "description": "A test field",
184            "oneOf": [
185                { "type": "string" },
186                { "type": "integer" }
187            ]
188        });
189
190        adapt_to_json_schema_subset(&mut json).unwrap();
191
192        assert_eq!(
193            json,
194            json!({
195                "description": "A test field",
196                "anyOf": [
197                    { "type": "string" },
198                    { "type": "integer" }
199                ]
200            })
201        );
202    }
203
204    #[test]
205    fn test_transform_nested_objects() {
206        let mut json = json!({
207            "type": "object",
208            "properties": {
209                "nested": {
210                    "oneOf": [
211                        { "type": "string" },
212                        { "type": "null" }
213                    ],
214                    "format": "email"
215                }
216            }
217        });
218
219        adapt_to_json_schema_subset(&mut json).unwrap();
220
221        assert_eq!(
222            json,
223            json!({
224                "type": "object",
225                "properties": {
226                    "nested": {
227                        "anyOf": [
228                            { "type": "string" },
229                            { "type": "null" }
230                        ]
231                    }
232                }
233            })
234        );
235    }
236
237    #[test]
238    fn test_transform_fails_if_unsupported_keys_exist() {
239        let mut json = json!({
240            "type": "object",
241            "properties": {
242                "$ref": "#/definitions/User",
243            }
244        });
245
246        assert!(adapt_to_json_schema_subset(&mut json).is_err());
247
248        let mut json = json!({
249            "type": "object",
250            "properties": {
251                "if": "...",
252            }
253        });
254
255        assert!(adapt_to_json_schema_subset(&mut json).is_err());
256
257        let mut json = json!({
258            "type": "object",
259            "properties": {
260                "then": "...",
261            }
262        });
263
264        assert!(adapt_to_json_schema_subset(&mut json).is_err());
265
266        let mut json = json!({
267            "type": "object",
268            "properties": {
269                "else": "...",
270            }
271        });
272
273        assert!(adapt_to_json_schema_subset(&mut json).is_err());
274    }
275
276    #[test]
277    fn test_preprocess_json_schema_adds_additional_properties() {
278        let mut json = json!({
279            "type": "object",
280            "properties": {
281                "name": {
282                    "type": "string"
283                }
284            }
285        });
286
287        preprocess_json_schema(&mut json).unwrap();
288
289        assert_eq!(
290            json,
291            json!({
292                "type": "object",
293                "properties": {
294                    "name": {
295                        "type": "string"
296                    }
297                },
298                "additionalProperties": false
299            })
300        );
301    }
302
303    #[test]
304    fn test_preprocess_json_schema_preserves_additional_properties() {
305        let mut json = json!({
306            "type": "object",
307            "properties": {
308                "name": {
309                    "type": "string"
310                }
311            },
312            "additionalProperties": true
313        });
314
315        preprocess_json_schema(&mut json).unwrap();
316
317        assert_eq!(
318            json,
319            json!({
320                "type": "object",
321                "properties": {
322                    "name": {
323                        "type": "string"
324                    }
325                },
326                "additionalProperties": true
327            })
328        );
329    }
330}