tool_schema.rs

  1use anyhow::Result;
  2use serde_json::Value;
  3
  4use crate::LanguageModelToolSchemaFormat;
  5
  6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
  7///
  8/// If the json cannot be made compatible with the specified format, an error is returned.
  9pub fn adapt_schema_to_format(
 10    json: &mut Value,
 11    format: LanguageModelToolSchemaFormat,
 12) -> Result<()> {
 13    if let Value::Object(obj) = json {
 14        obj.remove("$schema");
 15        obj.remove("title");
 16    }
 17
 18    match format {
 19        LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
 20        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 21    }
 22}
 23
 24/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 25fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 26    if let Value::Object(obj) = json {
 27        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 28
 29        for key in UNSUPPORTED_KEYS {
 30            if obj.contains_key(key) {
 31                return Err(anyhow::anyhow!(
 32                    "Schema cannot be made compatible because it contains \"{}\" ",
 33                    key
 34                ));
 35            }
 36        }
 37
 38        const KEYS_TO_REMOVE: [&str; 4] = [
 39            "format",
 40            "additionalProperties",
 41            "exclusiveMinimum",
 42            "exclusiveMaximum",
 43        ];
 44        for key in KEYS_TO_REMOVE {
 45            obj.remove(key);
 46        }
 47
 48        if let Some(default) = obj.get("default") {
 49            let is_null = default.is_null();
 50            // Default is not supported, so we need to remove it
 51            obj.remove("default");
 52            if is_null {
 53                obj.insert("nullable".to_string(), Value::Bool(true));
 54            }
 55        }
 56
 57        // If a type is not specified for an input parameter, add a default type
 58        if matches!(obj.get("description"), Some(Value::String(_)))
 59            && !obj.contains_key("type")
 60            && !(obj.contains_key("anyOf")
 61                || obj.contains_key("oneOf")
 62                || obj.contains_key("allOf"))
 63        {
 64            obj.insert("type".to_string(), Value::String("string".to_string()));
 65        }
 66
 67        // Handle oneOf -> anyOf conversion
 68        if let Some(subschemas) = obj.get_mut("oneOf") {
 69            if subschemas.is_array() {
 70                let subschemas_clone = subschemas.clone();
 71                obj.remove("oneOf");
 72                obj.insert("anyOf".to_string(), subschemas_clone);
 73            }
 74        }
 75
 76        // Recursively process all nested objects and arrays
 77        for (_, value) in obj.iter_mut() {
 78            if let Value::Object(_) | Value::Array(_) = value {
 79                adapt_to_json_schema_subset(value)?;
 80            }
 81        }
 82    } else if let Value::Array(arr) = json {
 83        for item in arr.iter_mut() {
 84            adapt_to_json_schema_subset(item)?;
 85        }
 86    }
 87    Ok(())
 88}
 89
 90#[cfg(test)]
 91mod tests {
 92    use super::*;
 93    use serde_json::json;
 94
 95    #[test]
 96    fn test_transform_default_null_to_nullable() {
 97        let mut json = json!({
 98            "description": "A test field",
 99            "type": "string",
100            "default": null
101        });
102
103        adapt_to_json_schema_subset(&mut json).unwrap();
104
105        assert_eq!(
106            json,
107            json!({
108                "description": "A test field",
109                "type": "string",
110                "nullable": true
111            })
112        );
113    }
114
115    #[test]
116    fn test_transform_adds_type_when_missing() {
117        let mut json = json!({
118            "description": "A test field without type"
119        });
120
121        adapt_to_json_schema_subset(&mut json).unwrap();
122
123        assert_eq!(
124            json,
125            json!({
126                "description": "A test field without type",
127                "type": "string"
128            })
129        );
130
131        // Ensure that we do not add a type if it is an object
132        let mut json = json!({
133            "description": {
134                "value": "abc",
135                "type": "string"
136            }
137        });
138
139        adapt_to_json_schema_subset(&mut json).unwrap();
140
141        assert_eq!(
142            json,
143            json!({
144                "description": {
145                    "value": "abc",
146                    "type": "string"
147                }
148            })
149        );
150    }
151
152    #[test]
153    fn test_transform_removes_unsupported_keys() {
154        let mut json = json!({
155            "description": "A test field",
156            "type": "integer",
157            "format": "uint32",
158            "exclusiveMinimum": 0,
159            "exclusiveMaximum": 100,
160            "additionalProperties": false
161        });
162
163        adapt_to_json_schema_subset(&mut json).unwrap();
164
165        assert_eq!(
166            json,
167            json!({
168                "description": "A test field",
169                "type": "integer"
170            })
171        );
172    }
173
174    #[test]
175    fn test_transform_one_of_to_any_of() {
176        let mut json = json!({
177            "description": "A test field",
178            "oneOf": [
179                { "type": "string" },
180                { "type": "integer" }
181            ]
182        });
183
184        adapt_to_json_schema_subset(&mut json).unwrap();
185
186        assert_eq!(
187            json,
188            json!({
189                "description": "A test field",
190                "anyOf": [
191                    { "type": "string" },
192                    { "type": "integer" }
193                ]
194            })
195        );
196    }
197
198    #[test]
199    fn test_transform_nested_objects() {
200        let mut json = json!({
201            "type": "object",
202            "properties": {
203                "nested": {
204                    "oneOf": [
205                        { "type": "string" },
206                        { "type": "null" }
207                    ],
208                    "format": "email"
209                }
210            }
211        });
212
213        adapt_to_json_schema_subset(&mut json).unwrap();
214
215        assert_eq!(
216            json,
217            json!({
218                "type": "object",
219                "properties": {
220                    "nested": {
221                        "anyOf": [
222                            { "type": "string" },
223                            { "type": "null" }
224                        ]
225                    }
226                }
227            })
228        );
229    }
230
231    #[test]
232    fn test_transform_fails_if_unsupported_keys_exist() {
233        let mut json = json!({
234            "type": "object",
235            "properties": {
236                "$ref": "#/definitions/User",
237            }
238        });
239
240        assert!(adapt_to_json_schema_subset(&mut json).is_err());
241
242        let mut json = json!({
243            "type": "object",
244            "properties": {
245                "if": "...",
246            }
247        });
248
249        assert!(adapt_to_json_schema_subset(&mut json).is_err());
250
251        let mut json = json!({
252            "type": "object",
253            "properties": {
254                "then": "...",
255            }
256        });
257
258        assert!(adapt_to_json_schema_subset(&mut json).is_err());
259
260        let mut json = json!({
261            "type": "object",
262            "properties": {
263                "else": "...",
264            }
265        });
266
267        assert!(adapt_to_json_schema_subset(&mut json).is_err());
268    }
269}