tool_schema.rs

  1use anyhow::Result;
  2use serde_json::Value;
  3
  4use crate::LanguageModelToolSchemaFormat;
  5
  6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
  7///
  8/// If the json cannot be made compatible with the specified format, an error is returned.
  9pub fn adapt_schema_to_format(
 10    json: &mut Value,
 11    format: LanguageModelToolSchemaFormat,
 12) -> Result<()> {
 13    if let Value::Object(obj) = json {
 14        obj.remove("$schema");
 15        obj.remove("title");
 16    }
 17
 18    match format {
 19        LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
 20        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 21    }
 22}
 23
 24/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 25fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 26    if let Value::Object(obj) = json {
 27        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 28
 29        for key in UNSUPPORTED_KEYS {
 30            anyhow::ensure!(
 31                !obj.contains_key(key),
 32                "Schema cannot be made compatible because it contains \"{key}\""
 33            );
 34        }
 35
 36        const KEYS_TO_REMOVE: [&str; 5] = [
 37            "format",
 38            "additionalProperties",
 39            "exclusiveMinimum",
 40            "exclusiveMaximum",
 41            "optional",
 42        ];
 43        for key in KEYS_TO_REMOVE {
 44            obj.remove(key);
 45        }
 46
 47        // If a type is not specified for an input parameter, add a default type
 48        if matches!(obj.get("description"), Some(Value::String(_)))
 49            && !obj.contains_key("type")
 50            && !(obj.contains_key("anyOf")
 51                || obj.contains_key("oneOf")
 52                || obj.contains_key("allOf"))
 53        {
 54            obj.insert("type".to_string(), Value::String("string".to_string()));
 55        }
 56
 57        // Handle oneOf -> anyOf conversion
 58        if let Some(subschemas) = obj.get_mut("oneOf") {
 59            if subschemas.is_array() {
 60                let subschemas_clone = subschemas.clone();
 61                obj.remove("oneOf");
 62                obj.insert("anyOf".to_string(), subschemas_clone);
 63            }
 64        }
 65
 66        // Recursively process all nested objects and arrays
 67        for (_, value) in obj.iter_mut() {
 68            if let Value::Object(_) | Value::Array(_) = value {
 69                adapt_to_json_schema_subset(value)?;
 70            }
 71        }
 72    } else if let Value::Array(arr) = json {
 73        for item in arr.iter_mut() {
 74            adapt_to_json_schema_subset(item)?;
 75        }
 76    }
 77    Ok(())
 78}
 79
 80#[cfg(test)]
 81mod tests {
 82    use super::*;
 83    use serde_json::json;
 84
 85    #[test]
 86    fn test_transform_adds_type_when_missing() {
 87        let mut json = json!({
 88            "description": "A test field without type"
 89        });
 90
 91        adapt_to_json_schema_subset(&mut json).unwrap();
 92
 93        assert_eq!(
 94            json,
 95            json!({
 96                "description": "A test field without type",
 97                "type": "string"
 98            })
 99        );
100
101        // Ensure that we do not add a type if it is an object
102        let mut json = json!({
103            "description": {
104                "value": "abc",
105                "type": "string"
106            }
107        });
108
109        adapt_to_json_schema_subset(&mut json).unwrap();
110
111        assert_eq!(
112            json,
113            json!({
114                "description": {
115                    "value": "abc",
116                    "type": "string"
117                }
118            })
119        );
120    }
121
122    #[test]
123    fn test_transform_removes_unsupported_keys() {
124        let mut json = json!({
125            "description": "A test field",
126            "type": "integer",
127            "format": "uint32",
128            "exclusiveMinimum": 0,
129            "exclusiveMaximum": 100,
130            "additionalProperties": false,
131            "optional": true
132        });
133
134        adapt_to_json_schema_subset(&mut json).unwrap();
135
136        assert_eq!(
137            json,
138            json!({
139                "description": "A test field",
140                "type": "integer"
141            })
142        );
143    }
144
145    #[test]
146    fn test_transform_one_of_to_any_of() {
147        let mut json = json!({
148            "description": "A test field",
149            "oneOf": [
150                { "type": "string" },
151                { "type": "integer" }
152            ]
153        });
154
155        adapt_to_json_schema_subset(&mut json).unwrap();
156
157        assert_eq!(
158            json,
159            json!({
160                "description": "A test field",
161                "anyOf": [
162                    { "type": "string" },
163                    { "type": "integer" }
164                ]
165            })
166        );
167    }
168
169    #[test]
170    fn test_transform_nested_objects() {
171        let mut json = json!({
172            "type": "object",
173            "properties": {
174                "nested": {
175                    "oneOf": [
176                        { "type": "string" },
177                        { "type": "null" }
178                    ],
179                    "format": "email"
180                }
181            }
182        });
183
184        adapt_to_json_schema_subset(&mut json).unwrap();
185
186        assert_eq!(
187            json,
188            json!({
189                "type": "object",
190                "properties": {
191                    "nested": {
192                        "anyOf": [
193                            { "type": "string" },
194                            { "type": "null" }
195                        ]
196                    }
197                }
198            })
199        );
200    }
201
202    #[test]
203    fn test_transform_fails_if_unsupported_keys_exist() {
204        let mut json = json!({
205            "type": "object",
206            "properties": {
207                "$ref": "#/definitions/User",
208            }
209        });
210
211        assert!(adapt_to_json_schema_subset(&mut json).is_err());
212
213        let mut json = json!({
214            "type": "object",
215            "properties": {
216                "if": "...",
217            }
218        });
219
220        assert!(adapt_to_json_schema_subset(&mut json).is_err());
221
222        let mut json = json!({
223            "type": "object",
224            "properties": {
225                "then": "...",
226            }
227        });
228
229        assert!(adapt_to_json_schema_subset(&mut json).is_err());
230
231        let mut json = json!({
232            "type": "object",
233            "properties": {
234                "else": "...",
235            }
236        });
237
238        assert!(adapt_to_json_schema_subset(&mut json).is_err());
239    }
240}