tool_schema.rs

  1use anyhow::Result;
  2use serde_json::Value;
  3
  4use crate::LanguageModelToolSchemaFormat;
  5
  6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
  7///
  8/// If the json cannot be made compatible with the specified format, an error is returned.
  9pub fn adapt_schema_to_format(
 10    json: &mut Value,
 11    format: LanguageModelToolSchemaFormat,
 12) -> Result<()> {
 13    if let Value::Object(obj) = json {
 14        obj.remove("$schema");
 15        obj.remove("title");
 16    }
 17
 18    match format {
 19        LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
 20        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 21    }
 22}
 23
 24fn preprocess_json_schema(json: &mut Value) -> Result<()> {
 25    // `additionalProperties` defaults to `false` unless explicitly specified.
 26    // This prevents models from hallucinating tool parameters.
 27    if let Value::Object(obj) = json
 28        && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
 29    {
 30        if !obj.contains_key("additionalProperties") {
 31            obj.insert("additionalProperties".to_string(), Value::Bool(false));
 32        }
 33
 34        // OpenAI API requires non-missing `properties`
 35        if !obj.contains_key("properties") {
 36            obj.insert("properties".to_string(), Value::Object(Default::default()));
 37        }
 38    }
 39    Ok(())
 40}
 41
 42/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 43fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 44    if let Value::Object(obj) = json {
 45        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 46
 47        for key in UNSUPPORTED_KEYS {
 48            anyhow::ensure!(
 49                !obj.contains_key(key),
 50                "Schema cannot be made compatible because it contains \"{key}\""
 51            );
 52        }
 53
 54        const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
 55            ("format", |value| value.is_string()),
 56            ("additionalProperties", |value| value.is_boolean()),
 57            ("exclusiveMinimum", |value| value.is_number()),
 58            ("exclusiveMaximum", |value| value.is_number()),
 59            ("optional", |value| value.is_boolean()),
 60        ];
 61        for (key, predicate) in KEYS_TO_REMOVE {
 62            if let Some(value) = obj.get(key)
 63                && predicate(value)
 64            {
 65                obj.remove(key);
 66            }
 67        }
 68
 69        // If a type is not specified for an input parameter, add a default type
 70        if matches!(obj.get("description"), Some(Value::String(_)))
 71            && !obj.contains_key("type")
 72            && !(obj.contains_key("anyOf")
 73                || obj.contains_key("oneOf")
 74                || obj.contains_key("allOf"))
 75        {
 76            obj.insert("type".to_string(), Value::String("string".to_string()));
 77        }
 78
 79        // Handle oneOf -> anyOf conversion
 80        if let Some(subschemas) = obj.get_mut("oneOf")
 81            && subschemas.is_array()
 82        {
 83            let subschemas_clone = subschemas.clone();
 84            obj.remove("oneOf");
 85            obj.insert("anyOf".to_string(), subschemas_clone);
 86        }
 87
 88        // Recursively process all nested objects and arrays
 89        for (_, value) in obj.iter_mut() {
 90            if let Value::Object(_) | Value::Array(_) = value {
 91                adapt_to_json_schema_subset(value)?;
 92            }
 93        }
 94    } else if let Value::Array(arr) = json {
 95        for item in arr.iter_mut() {
 96            adapt_to_json_schema_subset(item)?;
 97        }
 98    }
 99    Ok(())
100}
101
#[cfg(test)]
mod tests {
    //! Unit tests for the schema adaptation helpers.
    //!
    //! The `test_transform_*` tests exercise `adapt_to_json_schema_subset`, the
    //! `test_preprocess_*` tests exercise `preprocess_json_schema`, and the
    //! `test_adapt_schema_to_format_*` tests exercise the public entry point.

    use super::*;
    use serde_json::json;

    /// A string `description` without `type` gets `"type": "string"`; a non-string
    /// `description` (i.e. a regular nested object) does not.
    #[test]
    fn test_transform_adds_type_when_missing() {
        let mut json = json!({
            "description": "A test field without type"
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // Ensure that we do not add a type if it is an object
        let mut json = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })
        );
    }

    /// Keywords outside the subset are dropped, but only when their value has the
    /// JSON type the keyword normally carries.
    #[test]
    fn test_transform_removes_unsupported_keys() {
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer",
                "format": {},
            })
        );
    }

    /// An array-valued `oneOf` is renamed to `anyOf` and no default type is added.
    #[test]
    fn test_transform_one_of_to_any_of() {
        let mut json = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    /// The transformation is applied recursively to nested objects.
    #[test]
    fn test_transform_nested_objects() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    /// `$ref`, `if`, `then` and `else` anywhere in a nested object make the adaptation fail.
    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "$ref": "#/definitions/User",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "if": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "then": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "else": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    /// Unsupported keys are also detected inside array elements.
    #[test]
    fn test_transform_fails_if_unsupported_keys_exist_in_arrays() {
        let mut json = json!({
            "anyOf": [
                { "type": "string" },
                { "$ref": "#/definitions/User" }
            ]
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    /// `additionalProperties: false` is added to object schemas that do not specify it.
    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    /// An explicit `additionalProperties` value is never overwritten.
    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": true
            })
        );
    }

    /// An object schema without `properties` gets an empty `properties` map.
    #[test]
    fn test_preprocess_json_schema_adds_missing_properties() {
        let mut json = json!({
            "type": "object"
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {},
                "additionalProperties": false
            })
        );
    }

    /// Schemas whose `type` is not `"object"` are left untouched.
    #[test]
    fn test_preprocess_json_schema_ignores_non_object_schemas() {
        let mut json = json!({
            "type": "string"
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "string"
            })
        );
    }

    /// The public entry point strips top-level `$schema` and `title` for both formats
    /// before running the format-specific pass.
    #[test]
    fn test_adapt_schema_to_format_strips_schema_and_title() {
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Tool",
            "type": "object",
            "properties": {}
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchema).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {},
                "additionalProperties": false
            })
        );

        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Tool",
            "description": "A test field"
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchemaSubset)
            .unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "string"
            })
        );
    }
}
335}