tool_schema.rs

  1use anyhow::Result;
  2use serde_json::Value;
  3
  4use crate::LanguageModelToolSchemaFormat;
  5
  6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
  7///
  8/// If the json cannot be made compatible with the specified format, an error is returned.
  9pub fn adapt_schema_to_format(
 10    json: &mut Value,
 11    format: LanguageModelToolSchemaFormat,
 12) -> Result<()> {
 13    match format {
 14        LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
 15        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 16    }
 17}
 18
 19/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 20fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 21    if let Value::Object(obj) = json {
 22        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 23
 24        for key in UNSUPPORTED_KEYS {
 25            if obj.contains_key(key) {
 26                return Err(anyhow::anyhow!(
 27                    "Schema cannot be made compatible because it contains \"{}\" ",
 28                    key
 29                ));
 30            }
 31        }
 32
 33        const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
 34        for key in KEYS_TO_REMOVE {
 35            obj.remove(key);
 36        }
 37
 38        if let Some(default) = obj.get("default") {
 39            let is_null = default.is_null();
 40            // Default is not supported, so we need to remove it
 41            obj.remove("default");
 42            if is_null {
 43                obj.insert("nullable".to_string(), Value::Bool(true));
 44            }
 45        }
 46
 47        // If a type is not specified for an input parameter, add a default type
 48        if obj.contains_key("description")
 49            && !obj.contains_key("type")
 50            && !(obj.contains_key("anyOf")
 51                || obj.contains_key("oneOf")
 52                || obj.contains_key("allOf"))
 53        {
 54            obj.insert("type".to_string(), Value::String("string".to_string()));
 55        }
 56
 57        // Handle oneOf -> anyOf conversion
 58        if let Some(subschemas) = obj.get_mut("oneOf") {
 59            if subschemas.is_array() {
 60                let subschemas_clone = subschemas.clone();
 61                obj.remove("oneOf");
 62                obj.insert("anyOf".to_string(), subschemas_clone);
 63            }
 64        }
 65
 66        // Recursively process all nested objects and arrays
 67        for (_, value) in obj.iter_mut() {
 68            if let Value::Object(_) | Value::Array(_) = value {
 69                adapt_to_json_schema_subset(value)?;
 70            }
 71        }
 72    } else if let Value::Array(arr) = json {
 73        for item in arr.iter_mut() {
 74            adapt_to_json_schema_subset(item)?;
 75        }
 76    }
 77    Ok(())
 78}
 79
 #[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Adapts `schema` to the subset, panicking if adaptation fails, and returns the result.
    fn adapted(mut schema: Value) -> Value {
        adapt_to_json_schema_subset(&mut schema).unwrap();
        schema
    }

    #[test]
    fn test_transform_default_null_to_nullable() {
        let input = json!({
            "description": "A test field",
            "type": "string",
            "default": null
        });
        let expected = json!({
            "description": "A test field",
            "type": "string",
            "nullable": true
        });

        assert_eq!(adapted(input), expected);
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        let input = json!({ "description": "A test field without type" });
        let expected = json!({
            "description": "A test field without type",
            "type": "string"
        });

        assert_eq!(adapted(input), expected);
    }

    #[test]
    fn test_transform_removes_format() {
        let input = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32"
        });
        let expected = json!({
            "description": "A test field",
            "type": "integer"
        });

        assert_eq!(adapted(input), expected);
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let alternatives = json!([{ "type": "string" }, { "type": "integer" }]);
        let input = json!({
            "description": "A test field",
            "oneOf": alternatives.clone()
        });
        let expected = json!({
            "description": "A test field",
            "anyOf": alternatives
        });

        assert_eq!(adapted(input), expected);
    }

    #[test]
    fn test_transform_nested_objects() {
        let input = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [{ "type": "string" }, { "type": "null" }],
                    "format": "email"
                }
            }
        });
        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "anyOf": [{ "type": "string" }, { "type": "null" }]
                }
            }
        });

        assert_eq!(adapted(input), expected);
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        // Each unsupported key must be rejected even when it only appears nested under
        // `properties`.
        let cases = [
            ("$ref", "#/definitions/User"),
            ("if", "..."),
            ("then", "..."),
            ("else", "..."),
        ];

        for (key, value) in cases {
            let mut schema = json!({
                "type": "object",
                "properties": { key: value }
            });
            assert!(
                adapt_to_json_schema_subset(&mut schema).is_err(),
                "expected schema containing {key:?} to be rejected"
            );
        }
    }
}