tool_schema.rs

  1use anyhow::Result;
  2use serde_json::Value;
  3
  4use crate::LanguageModelToolSchemaFormat;
  5
  6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
  7///
  8/// If the json cannot be made compatible with the specified format, an error is returned.
  9pub fn adapt_schema_to_format(
 10    json: &mut Value,
 11    format: LanguageModelToolSchemaFormat,
 12) -> Result<()> {
 13    if let Value::Object(obj) = json {
 14        obj.remove("$schema");
 15        obj.remove("title");
 16    }
 17
 18    match format {
 19        LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
 20        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 21    }
 22}
 23
 24fn preprocess_json_schema(json: &mut Value) -> Result<()> {
 25    // `additionalProperties` defaults to `false` unless explicitly specified.
 26    // This prevents models from hallucinating tool parameters.
 27    if let Value::Object(obj) = json
 28        && matches!(obj.get("type"), Some(Value::String(s)) if s == "object") {
 29            if !obj.contains_key("additionalProperties") {
 30                obj.insert("additionalProperties".to_string(), Value::Bool(false));
 31            }
 32
 33            // OpenAI API requires non-missing `properties`
 34            if !obj.contains_key("properties") {
 35                obj.insert("properties".to_string(), Value::Object(Default::default()));
 36            }
 37        }
 38    Ok(())
 39}
 40
 41/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 42fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 43    if let Value::Object(obj) = json {
 44        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 45
 46        for key in UNSUPPORTED_KEYS {
 47            anyhow::ensure!(
 48                !obj.contains_key(key),
 49                "Schema cannot be made compatible because it contains \"{key}\""
 50            );
 51        }
 52
 53        const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
 54            ("format", |value| value.is_string()),
 55            ("additionalProperties", |value| value.is_boolean()),
 56            ("exclusiveMinimum", |value| value.is_number()),
 57            ("exclusiveMaximum", |value| value.is_number()),
 58            ("optional", |value| value.is_boolean()),
 59        ];
 60        for (key, predicate) in KEYS_TO_REMOVE {
 61            if let Some(value) = obj.get(key)
 62                && predicate(value) {
 63                    obj.remove(key);
 64                }
 65        }
 66
 67        // If a type is not specified for an input parameter, add a default type
 68        if matches!(obj.get("description"), Some(Value::String(_)))
 69            && !obj.contains_key("type")
 70            && !(obj.contains_key("anyOf")
 71                || obj.contains_key("oneOf")
 72                || obj.contains_key("allOf"))
 73        {
 74            obj.insert("type".to_string(), Value::String("string".to_string()));
 75        }
 76
 77        // Handle oneOf -> anyOf conversion
 78        if let Some(subschemas) = obj.get_mut("oneOf")
 79            && subschemas.is_array() {
 80                let subschemas_clone = subschemas.clone();
 81                obj.remove("oneOf");
 82                obj.insert("anyOf".to_string(), subschemas_clone);
 83            }
 84
 85        // Recursively process all nested objects and arrays
 86        for (_, value) in obj.iter_mut() {
 87            if let Value::Object(_) | Value::Array(_) = value {
 88                adapt_to_json_schema_subset(value)?;
 89            }
 90        }
 91    } else if let Value::Array(arr) = json {
 92        for item in arr.iter_mut() {
 93            adapt_to_json_schema_subset(item)?;
 94        }
 95    }
 96    Ok(())
 97}
 98
 99#[cfg(test)]
100mod tests {
101    use super::*;
102    use serde_json::json;
103
104    #[test]
105    fn test_transform_adds_type_when_missing() {
106        let mut json = json!({
107            "description": "A test field without type"
108        });
109
110        adapt_to_json_schema_subset(&mut json).unwrap();
111
112        assert_eq!(
113            json,
114            json!({
115                "description": "A test field without type",
116                "type": "string"
117            })
118        );
119
120        // Ensure that we do not add a type if it is an object
121        let mut json = json!({
122            "description": {
123                "value": "abc",
124                "type": "string"
125            }
126        });
127
128        adapt_to_json_schema_subset(&mut json).unwrap();
129
130        assert_eq!(
131            json,
132            json!({
133                "description": {
134                    "value": "abc",
135                    "type": "string"
136                }
137            })
138        );
139    }
140
141    #[test]
142    fn test_transform_removes_unsupported_keys() {
143        let mut json = json!({
144            "description": "A test field",
145            "type": "integer",
146            "format": "uint32",
147            "exclusiveMinimum": 0,
148            "exclusiveMaximum": 100,
149            "additionalProperties": false,
150            "optional": true
151        });
152
153        adapt_to_json_schema_subset(&mut json).unwrap();
154
155        assert_eq!(
156            json,
157            json!({
158                "description": "A test field",
159                "type": "integer"
160            })
161        );
162
163        // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
164        let mut json = json!({
165            "description": "A test field",
166            "type": "integer",
167            "format": {},
168        });
169
170        adapt_to_json_schema_subset(&mut json).unwrap();
171
172        assert_eq!(
173            json,
174            json!({
175                "description": "A test field",
176                "type": "integer",
177                "format": {},
178            })
179        );
180    }
181
182    #[test]
183    fn test_transform_one_of_to_any_of() {
184        let mut json = json!({
185            "description": "A test field",
186            "oneOf": [
187                { "type": "string" },
188                { "type": "integer" }
189            ]
190        });
191
192        adapt_to_json_schema_subset(&mut json).unwrap();
193
194        assert_eq!(
195            json,
196            json!({
197                "description": "A test field",
198                "anyOf": [
199                    { "type": "string" },
200                    { "type": "integer" }
201                ]
202            })
203        );
204    }
205
206    #[test]
207    fn test_transform_nested_objects() {
208        let mut json = json!({
209            "type": "object",
210            "properties": {
211                "nested": {
212                    "oneOf": [
213                        { "type": "string" },
214                        { "type": "null" }
215                    ],
216                    "format": "email"
217                }
218            }
219        });
220
221        adapt_to_json_schema_subset(&mut json).unwrap();
222
223        assert_eq!(
224            json,
225            json!({
226                "type": "object",
227                "properties": {
228                    "nested": {
229                        "anyOf": [
230                            { "type": "string" },
231                            { "type": "null" }
232                        ]
233                    }
234                }
235            })
236        );
237    }
238
239    #[test]
240    fn test_transform_fails_if_unsupported_keys_exist() {
241        let mut json = json!({
242            "type": "object",
243            "properties": {
244                "$ref": "#/definitions/User",
245            }
246        });
247
248        assert!(adapt_to_json_schema_subset(&mut json).is_err());
249
250        let mut json = json!({
251            "type": "object",
252            "properties": {
253                "if": "...",
254            }
255        });
256
257        assert!(adapt_to_json_schema_subset(&mut json).is_err());
258
259        let mut json = json!({
260            "type": "object",
261            "properties": {
262                "then": "...",
263            }
264        });
265
266        assert!(adapt_to_json_schema_subset(&mut json).is_err());
267
268        let mut json = json!({
269            "type": "object",
270            "properties": {
271                "else": "...",
272            }
273        });
274
275        assert!(adapt_to_json_schema_subset(&mut json).is_err());
276    }
277
278    #[test]
279    fn test_preprocess_json_schema_adds_additional_properties() {
280        let mut json = json!({
281            "type": "object",
282            "properties": {
283                "name": {
284                    "type": "string"
285                }
286            }
287        });
288
289        preprocess_json_schema(&mut json).unwrap();
290
291        assert_eq!(
292            json,
293            json!({
294                "type": "object",
295                "properties": {
296                    "name": {
297                        "type": "string"
298                    }
299                },
300                "additionalProperties": false
301            })
302        );
303    }
304
305    #[test]
306    fn test_preprocess_json_schema_preserves_additional_properties() {
307        let mut json = json!({
308            "type": "object",
309            "properties": {
310                "name": {
311                    "type": "string"
312                }
313            },
314            "additionalProperties": true
315        });
316
317        preprocess_json_schema(&mut json).unwrap();
318
319        assert_eq!(
320            json,
321            json!({
322                "type": "object",
323                "properties": {
324                    "name": {
325                        "type": "string"
326                    }
327                },
328                "additionalProperties": true
329            })
330        );
331    }
332}