tool_schema.rs

  1use anyhow::Result;
  2use serde_json::Value;
  3
  4use crate::LanguageModelToolSchemaFormat;
  5
  6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
  7///
  8/// If the json cannot be made compatible with the specified format, an error is returned.
  9pub fn adapt_schema_to_format(
 10    json: &mut Value,
 11    format: LanguageModelToolSchemaFormat,
 12) -> Result<()> {
 13    if let Value::Object(obj) = json {
 14        obj.remove("$schema");
 15        obj.remove("title");
 16    }
 17
 18    match format {
 19        LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
 20        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 21    }
 22}
 23
 24/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 25fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 26    if let Value::Object(obj) = json {
 27        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 28
 29        for key in UNSUPPORTED_KEYS {
 30            if obj.contains_key(key) {
 31                return Err(anyhow::anyhow!(
 32                    "Schema cannot be made compatible because it contains \"{}\" ",
 33                    key
 34                ));
 35            }
 36        }
 37
 38        obj.remove("format");
 39
 40        if let Some(default) = obj.get("default") {
 41            let is_null = default.is_null();
 42            // Default is not supported, so we need to remove it
 43            obj.remove("default");
 44            if is_null {
 45                obj.insert("nullable".to_string(), Value::Bool(true));
 46            }
 47        }
 48
 49        // If a type is not specified for an input parameter, add a default type
 50        if obj.contains_key("description")
 51            && !obj.contains_key("type")
 52            && !(obj.contains_key("anyOf")
 53                || obj.contains_key("oneOf")
 54                || obj.contains_key("allOf"))
 55        {
 56            obj.insert("type".to_string(), Value::String("string".to_string()));
 57        }
 58
 59        // Handle oneOf -> anyOf conversion
 60        if let Some(subschemas) = obj.get_mut("oneOf") {
 61            if subschemas.is_array() {
 62                let subschemas_clone = subschemas.clone();
 63                obj.remove("oneOf");
 64                obj.insert("anyOf".to_string(), subschemas_clone);
 65            }
 66        }
 67
 68        // Recursively process all nested objects and arrays
 69        for (_, value) in obj.iter_mut() {
 70            if let Value::Object(_) | Value::Array(_) = value {
 71                adapt_to_json_schema_subset(value)?;
 72            }
 73        }
 74    } else if let Value::Array(arr) = json {
 75        for item in arr.iter_mut() {
 76            adapt_to_json_schema_subset(item)?;
 77        }
 78    }
 79    Ok(())
 80}
 81
 82#[cfg(test)]
 83mod tests {
 84    use super::*;
 85    use serde_json::json;
 86
 87    #[test]
 88    fn test_transform_default_null_to_nullable() {
 89        let mut json = json!({
 90            "description": "A test field",
 91            "type": "string",
 92            "default": null
 93        });
 94
 95        adapt_to_json_schema_subset(&mut json).unwrap();
 96
 97        assert_eq!(
 98            json,
 99            json!({
100                "description": "A test field",
101                "type": "string",
102                "nullable": true
103            })
104        );
105    }
106
107    #[test]
108    fn test_transform_adds_type_when_missing() {
109        let mut json = json!({
110            "description": "A test field without type"
111        });
112
113        adapt_to_json_schema_subset(&mut json).unwrap();
114
115        assert_eq!(
116            json,
117            json!({
118                "description": "A test field without type",
119                "type": "string"
120            })
121        );
122    }
123
124    #[test]
125    fn test_transform_removes_format() {
126        let mut json = json!({
127            "description": "A test field",
128            "type": "integer",
129            "format": "uint32"
130        });
131
132        adapt_to_json_schema_subset(&mut json).unwrap();
133
134        assert_eq!(
135            json,
136            json!({
137                "description": "A test field",
138                "type": "integer"
139            })
140        );
141    }
142
143    #[test]
144    fn test_transform_one_of_to_any_of() {
145        let mut json = json!({
146            "description": "A test field",
147            "oneOf": [
148                { "type": "string" },
149                { "type": "integer" }
150            ]
151        });
152
153        adapt_to_json_schema_subset(&mut json).unwrap();
154
155        assert_eq!(
156            json,
157            json!({
158                "description": "A test field",
159                "anyOf": [
160                    { "type": "string" },
161                    { "type": "integer" }
162                ]
163            })
164        );
165    }
166
167    #[test]
168    fn test_transform_nested_objects() {
169        let mut json = json!({
170            "type": "object",
171            "properties": {
172                "nested": {
173                    "oneOf": [
174                        { "type": "string" },
175                        { "type": "null" }
176                    ],
177                    "format": "email"
178                }
179            }
180        });
181
182        adapt_to_json_schema_subset(&mut json).unwrap();
183
184        assert_eq!(
185            json,
186            json!({
187                "type": "object",
188                "properties": {
189                    "nested": {
190                        "anyOf": [
191                            { "type": "string" },
192                            { "type": "null" }
193                        ]
194                    }
195                }
196            })
197        );
198    }
199
200    #[test]
201    fn test_transform_fails_if_unsupported_keys_exist() {
202        let mut json = json!({
203            "type": "object",
204            "properties": {
205                "$ref": "#/definitions/User",
206            }
207        });
208
209        assert!(adapt_to_json_schema_subset(&mut json).is_err());
210
211        let mut json = json!({
212            "type": "object",
213            "properties": {
214                "if": "...",
215            }
216        });
217
218        assert!(adapt_to_json_schema_subset(&mut json).is_err());
219
220        let mut json = json!({
221            "type": "object",
222            "properties": {
223                "then": "...",
224            }
225        });
226
227        assert!(adapt_to_json_schema_subset(&mut json).is_err());
228
229        let mut json = json!({
230            "type": "object",
231            "properties": {
232                "else": "...",
233            }
234        });
235
236        assert!(adapt_to_json_schema_subset(&mut json).is_err());
237    }
238}