tool_schema.rs

  1use anyhow::Result;
  2use serde_json::Value;
  3
  4use crate::LanguageModelToolSchemaFormat;
  5
  6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
  7///
  8/// If the json cannot be made compatible with the specified format, an error is returned.
  9pub fn adapt_schema_to_format(
 10    json: &mut Value,
 11    format: LanguageModelToolSchemaFormat,
 12) -> Result<()> {
 13    if let Value::Object(obj) = json {
 14        obj.remove("$schema");
 15        obj.remove("title");
 16    }
 17
 18    match format {
 19        LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
 20        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 21    }
 22}
 23
 24/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 25fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 26    if let Value::Object(obj) = json {
 27        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 28
 29        for key in UNSUPPORTED_KEYS {
 30            if obj.contains_key(key) {
 31                return Err(anyhow::anyhow!(
 32                    "Schema cannot be made compatible because it contains \"{}\" ",
 33                    key
 34                ));
 35            }
 36        }
 37
 38        const KEYS_TO_REMOVE: [&str; 5] = [
 39            "format",
 40            "additionalProperties",
 41            "exclusiveMinimum",
 42            "exclusiveMaximum",
 43            "optional",
 44        ];
 45        for key in KEYS_TO_REMOVE {
 46            obj.remove(key);
 47        }
 48
 49        // If a type is not specified for an input parameter, add a default type
 50        if matches!(obj.get("description"), Some(Value::String(_)))
 51            && !obj.contains_key("type")
 52            && !(obj.contains_key("anyOf")
 53                || obj.contains_key("oneOf")
 54                || obj.contains_key("allOf"))
 55        {
 56            obj.insert("type".to_string(), Value::String("string".to_string()));
 57        }
 58
 59        // Handle oneOf -> anyOf conversion
 60        if let Some(subschemas) = obj.get_mut("oneOf") {
 61            if subschemas.is_array() {
 62                let subschemas_clone = subschemas.clone();
 63                obj.remove("oneOf");
 64                obj.insert("anyOf".to_string(), subschemas_clone);
 65            }
 66        }
 67
 68        // Recursively process all nested objects and arrays
 69        for (_, value) in obj.iter_mut() {
 70            if let Value::Object(_) | Value::Array(_) = value {
 71                adapt_to_json_schema_subset(value)?;
 72            }
 73        }
 74    } else if let Value::Array(arr) = json {
 75        for item in arr.iter_mut() {
 76            adapt_to_json_schema_subset(item)?;
 77        }
 78    }
 79    Ok(())
 80}
 81
 82#[cfg(test)]
 83mod tests {
 84    use super::*;
 85    use serde_json::json;
 86
 87    #[test]
 88    fn test_transform_adds_type_when_missing() {
 89        let mut json = json!({
 90            "description": "A test field without type"
 91        });
 92
 93        adapt_to_json_schema_subset(&mut json).unwrap();
 94
 95        assert_eq!(
 96            json,
 97            json!({
 98                "description": "A test field without type",
 99                "type": "string"
100            })
101        );
102
103        // Ensure that we do not add a type if it is an object
104        let mut json = json!({
105            "description": {
106                "value": "abc",
107                "type": "string"
108            }
109        });
110
111        adapt_to_json_schema_subset(&mut json).unwrap();
112
113        assert_eq!(
114            json,
115            json!({
116                "description": {
117                    "value": "abc",
118                    "type": "string"
119                }
120            })
121        );
122    }
123
124    #[test]
125    fn test_transform_removes_unsupported_keys() {
126        let mut json = json!({
127            "description": "A test field",
128            "type": "integer",
129            "format": "uint32",
130            "exclusiveMinimum": 0,
131            "exclusiveMaximum": 100,
132            "additionalProperties": false,
133            "optional": true
134        });
135
136        adapt_to_json_schema_subset(&mut json).unwrap();
137
138        assert_eq!(
139            json,
140            json!({
141                "description": "A test field",
142                "type": "integer"
143            })
144        );
145    }
146
147    #[test]
148    fn test_transform_one_of_to_any_of() {
149        let mut json = json!({
150            "description": "A test field",
151            "oneOf": [
152                { "type": "string" },
153                { "type": "integer" }
154            ]
155        });
156
157        adapt_to_json_schema_subset(&mut json).unwrap();
158
159        assert_eq!(
160            json,
161            json!({
162                "description": "A test field",
163                "anyOf": [
164                    { "type": "string" },
165                    { "type": "integer" }
166                ]
167            })
168        );
169    }
170
171    #[test]
172    fn test_transform_nested_objects() {
173        let mut json = json!({
174            "type": "object",
175            "properties": {
176                "nested": {
177                    "oneOf": [
178                        { "type": "string" },
179                        { "type": "null" }
180                    ],
181                    "format": "email"
182                }
183            }
184        });
185
186        adapt_to_json_schema_subset(&mut json).unwrap();
187
188        assert_eq!(
189            json,
190            json!({
191                "type": "object",
192                "properties": {
193                    "nested": {
194                        "anyOf": [
195                            { "type": "string" },
196                            { "type": "null" }
197                        ]
198                    }
199                }
200            })
201        );
202    }
203
204    #[test]
205    fn test_transform_fails_if_unsupported_keys_exist() {
206        let mut json = json!({
207            "type": "object",
208            "properties": {
209                "$ref": "#/definitions/User",
210            }
211        });
212
213        assert!(adapt_to_json_schema_subset(&mut json).is_err());
214
215        let mut json = json!({
216            "type": "object",
217            "properties": {
218                "if": "...",
219            }
220        });
221
222        assert!(adapt_to_json_schema_subset(&mut json).is_err());
223
224        let mut json = json!({
225            "type": "object",
226            "properties": {
227                "then": "...",
228            }
229        });
230
231        assert!(adapt_to_json_schema_subset(&mut json).is_err());
232
233        let mut json = json!({
234            "type": "object",
235            "properties": {
236                "else": "...",
237            }
238        });
239
240        assert!(adapt_to_json_schema_subset(&mut json).is_err());
241    }
242}