tool_schema.rs

  1use anyhow::Result;
  2use serde_json::Value;
  3
  4use crate::LanguageModelToolSchemaFormat;
  5
  6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
  7///
  8/// If the json cannot be made compatible with the specified format, an error is returned.
  9pub fn adapt_schema_to_format(
 10    json: &mut Value,
 11    format: LanguageModelToolSchemaFormat,
 12) -> Result<()> {
 13    if let Value::Object(obj) = json {
 14        obj.remove("$schema");
 15        obj.remove("title");
 16    }
 17
 18    match format {
 19        LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
 20        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 21    }
 22}
 23
 24fn preprocess_json_schema(json: &mut Value) -> Result<()> {
 25    // `additionalProperties` defaults to `false` unless explicitly specified.
 26    // This prevents models from hallucinating tool parameters.
 27    if let Value::Object(obj) = json {
 28        if let Some(Value::String(type_str)) = obj.get("type") {
 29            if type_str == "object" && !obj.contains_key("additionalProperties") {
 30                obj.insert("additionalProperties".to_string(), Value::Bool(false));
 31            }
 32        }
 33    }
 34    Ok(())
 35}
 36
 37/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 38fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 39    if let Value::Object(obj) = json {
 40        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 41
 42        for key in UNSUPPORTED_KEYS {
 43            anyhow::ensure!(
 44                !obj.contains_key(key),
 45                "Schema cannot be made compatible because it contains \"{key}\""
 46            );
 47        }
 48
 49        const KEYS_TO_REMOVE: [&str; 5] = [
 50            "format",
 51            "additionalProperties",
 52            "exclusiveMinimum",
 53            "exclusiveMaximum",
 54            "optional",
 55        ];
 56        for key in KEYS_TO_REMOVE {
 57            obj.remove(key);
 58        }
 59
 60        // If a type is not specified for an input parameter, add a default type
 61        if matches!(obj.get("description"), Some(Value::String(_)))
 62            && !obj.contains_key("type")
 63            && !(obj.contains_key("anyOf")
 64                || obj.contains_key("oneOf")
 65                || obj.contains_key("allOf"))
 66        {
 67            obj.insert("type".to_string(), Value::String("string".to_string()));
 68        }
 69
 70        // Handle oneOf -> anyOf conversion
 71        if let Some(subschemas) = obj.get_mut("oneOf") {
 72            if subschemas.is_array() {
 73                let subschemas_clone = subschemas.clone();
 74                obj.remove("oneOf");
 75                obj.insert("anyOf".to_string(), subschemas_clone);
 76            }
 77        }
 78
 79        // Recursively process all nested objects and arrays
 80        for (_, value) in obj.iter_mut() {
 81            if let Value::Object(_) | Value::Array(_) = value {
 82                adapt_to_json_schema_subset(value)?;
 83            }
 84        }
 85    } else if let Value::Array(arr) = json {
 86        for item in arr.iter_mut() {
 87            adapt_to_json_schema_subset(item)?;
 88        }
 89    }
 90    Ok(())
 91}
 92
 93#[cfg(test)]
 94mod tests {
 95    use super::*;
 96    use serde_json::json;
 97
 98    #[test]
 99    fn test_transform_adds_type_when_missing() {
100        let mut json = json!({
101            "description": "A test field without type"
102        });
103
104        adapt_to_json_schema_subset(&mut json).unwrap();
105
106        assert_eq!(
107            json,
108            json!({
109                "description": "A test field without type",
110                "type": "string"
111            })
112        );
113
114        // Ensure that we do not add a type if it is an object
115        let mut json = json!({
116            "description": {
117                "value": "abc",
118                "type": "string"
119            }
120        });
121
122        adapt_to_json_schema_subset(&mut json).unwrap();
123
124        assert_eq!(
125            json,
126            json!({
127                "description": {
128                    "value": "abc",
129                    "type": "string"
130                }
131            })
132        );
133    }
134
135    #[test]
136    fn test_transform_removes_unsupported_keys() {
137        let mut json = json!({
138            "description": "A test field",
139            "type": "integer",
140            "format": "uint32",
141            "exclusiveMinimum": 0,
142            "exclusiveMaximum": 100,
143            "additionalProperties": false,
144            "optional": true
145        });
146
147        adapt_to_json_schema_subset(&mut json).unwrap();
148
149        assert_eq!(
150            json,
151            json!({
152                "description": "A test field",
153                "type": "integer"
154            })
155        );
156    }
157
158    #[test]
159    fn test_transform_one_of_to_any_of() {
160        let mut json = json!({
161            "description": "A test field",
162            "oneOf": [
163                { "type": "string" },
164                { "type": "integer" }
165            ]
166        });
167
168        adapt_to_json_schema_subset(&mut json).unwrap();
169
170        assert_eq!(
171            json,
172            json!({
173                "description": "A test field",
174                "anyOf": [
175                    { "type": "string" },
176                    { "type": "integer" }
177                ]
178            })
179        );
180    }
181
182    #[test]
183    fn test_transform_nested_objects() {
184        let mut json = json!({
185            "type": "object",
186            "properties": {
187                "nested": {
188                    "oneOf": [
189                        { "type": "string" },
190                        { "type": "null" }
191                    ],
192                    "format": "email"
193                }
194            }
195        });
196
197        adapt_to_json_schema_subset(&mut json).unwrap();
198
199        assert_eq!(
200            json,
201            json!({
202                "type": "object",
203                "properties": {
204                    "nested": {
205                        "anyOf": [
206                            { "type": "string" },
207                            { "type": "null" }
208                        ]
209                    }
210                }
211            })
212        );
213    }
214
215    #[test]
216    fn test_transform_fails_if_unsupported_keys_exist() {
217        let mut json = json!({
218            "type": "object",
219            "properties": {
220                "$ref": "#/definitions/User",
221            }
222        });
223
224        assert!(adapt_to_json_schema_subset(&mut json).is_err());
225
226        let mut json = json!({
227            "type": "object",
228            "properties": {
229                "if": "...",
230            }
231        });
232
233        assert!(adapt_to_json_schema_subset(&mut json).is_err());
234
235        let mut json = json!({
236            "type": "object",
237            "properties": {
238                "then": "...",
239            }
240        });
241
242        assert!(adapt_to_json_schema_subset(&mut json).is_err());
243
244        let mut json = json!({
245            "type": "object",
246            "properties": {
247                "else": "...",
248            }
249        });
250
251        assert!(adapt_to_json_schema_subset(&mut json).is_err());
252    }
253
254    #[test]
255    fn test_preprocess_json_schema_adds_additional_properties() {
256        let mut json = json!({
257            "type": "object",
258            "properties": {
259                "name": {
260                    "type": "string"
261                }
262            }
263        });
264
265        preprocess_json_schema(&mut json).unwrap();
266
267        assert_eq!(
268            json,
269            json!({
270                "type": "object",
271                "properties": {
272                    "name": {
273                        "type": "string"
274                    }
275                },
276                "additionalProperties": false
277            })
278        );
279    }
280
281    #[test]
282    fn test_preprocess_json_schema_preserves_additional_properties() {
283        let mut json = json!({
284            "type": "object",
285            "properties": {
286                "name": {
287                    "type": "string"
288                }
289            },
290            "additionalProperties": true
291        });
292
293        preprocess_json_schema(&mut json).unwrap();
294
295        assert_eq!(
296            json,
297            json!({
298                "type": "object",
299                "properties": {
300                    "name": {
301                        "type": "string"
302                    }
303                },
304                "additionalProperties": true
305            })
306        );
307    }
308}