tool_schema.rs

  1use anyhow::Result;
  2use schemars::{
  3    JsonSchema, Schema,
  4    generate::SchemaSettings,
  5    transform::{Transform, transform_subschemas},
  6};
  7use serde_json::Value;
  8
  9/// Indicates the format used to define the input schema for a language model tool.
 10#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
 11pub enum LanguageModelToolSchemaFormat {
 12    /// A JSON schema, see https://json-schema.org
 13    JsonSchema,
 14    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
 15    JsonSchemaSubset,
 16}
 17
 18pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
 19    let mut generator = match format {
 20        LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07()
 21            .with(|settings| {
 22                settings.meta_schema = None;
 23                settings.inline_subschemas = true;
 24            })
 25            .into_generator(),
 26        LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
 27            .with(|settings| {
 28                settings.meta_schema = None;
 29                settings.inline_subschemas = true;
 30            })
 31            .with_transform(ToJsonSchemaSubsetTransform)
 32            .into_generator(),
 33    };
 34    generator.root_schema_for::<T>()
 35}
 36
 37#[derive(Debug, Clone)]
 38struct ToJsonSchemaSubsetTransform;
 39
 40impl Transform for ToJsonSchemaSubsetTransform {
 41    fn transform(&mut self, schema: &mut Schema) {
 42        // Ensure that the type field is not an array, this happens when we use
 43        // Option<T>, the type will be [T, "null"].
 44        if let Some(type_field) = schema.get_mut("type")
 45            && let Some(types) = type_field.as_array()
 46            && let Some(first_type) = types.first()
 47        {
 48            *type_field = first_type.clone();
 49        }
 50
 51        // oneOf is not supported, use anyOf instead
 52        if let Some(one_of) = schema.remove("oneOf") {
 53            schema.insert("anyOf".to_string(), one_of);
 54        }
 55
 56        transform_subschemas(self, schema);
 57    }
 58}
 59
 60/// Tries to adapt a JSON schema representation to be compatible with the specified format.
 61///
 62/// If the json cannot be made compatible with the specified format, an error is returned.
 63pub fn adapt_schema_to_format(
 64    json: &mut Value,
 65    format: LanguageModelToolSchemaFormat,
 66) -> Result<()> {
 67    if let Value::Object(obj) = json {
 68        obj.remove("$schema");
 69        obj.remove("title");
 70        obj.remove("description");
 71    }
 72
 73    match format {
 74        LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
 75        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 76    }
 77}
 78
 79fn preprocess_json_schema(json: &mut Value) -> Result<()> {
 80    if let Value::Object(obj) = json
 81        && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
 82    {
 83        if !obj.contains_key("additionalProperties") {
 84            obj.insert("additionalProperties".to_string(), Value::Bool(false));
 85        }
 86
 87        if !obj.contains_key("properties") {
 88            obj.insert("properties".to_string(), Value::Object(Default::default()));
 89        }
 90    }
 91    Ok(())
 92}
 93
 94fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 95    if let Value::Object(obj) = json {
 96        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 97
 98        for key in UNSUPPORTED_KEYS {
 99            anyhow::ensure!(
100                !obj.contains_key(key),
101                "Schema cannot be made compatible because it contains \"{key}\""
102            );
103        }
104
105        const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [
106            ("format", |value| value.is_string()),
107            ("additionalProperties", |_| true),
108            ("propertyNames", |_| true),
109            ("exclusiveMinimum", |value| value.is_number()),
110            ("exclusiveMaximum", |value| value.is_number()),
111            ("optional", |value| value.is_boolean()),
112        ];
113        for (key, predicate) in KEYS_TO_REMOVE {
114            if let Some(value) = obj.get(key)
115                && predicate(value)
116            {
117                obj.remove(key);
118            }
119        }
120
121        if matches!(obj.get("description"), Some(Value::String(_)))
122            && !obj.contains_key("type")
123            && !(obj.contains_key("anyOf")
124                || obj.contains_key("oneOf")
125                || obj.contains_key("allOf"))
126        {
127            obj.insert("type".to_string(), Value::String("string".to_string()));
128        }
129
130        if let Some(subschemas) = obj.get_mut("oneOf")
131            && subschemas.is_array()
132        {
133            let subschemas_clone = subschemas.clone();
134            obj.remove("oneOf");
135            obj.insert("anyOf".to_string(), subschemas_clone);
136        }
137
138        for (_, value) in obj.iter_mut() {
139            if let Value::Object(_) | Value::Array(_) = value {
140                adapt_to_json_schema_subset(value)?;
141            }
142        }
143    } else if let Value::Array(arr) = json {
144        for item in arr.iter_mut() {
145            adapt_to_json_schema_subset(item)?;
146        }
147    }
148    Ok(())
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs [`adapt_to_json_schema_subset`] on `schema`, panicking on error, and returns the
    /// adapted value.
    fn adapted_subset(mut schema: Value) -> Value {
        adapt_to_json_schema_subset(&mut schema).unwrap();
        schema
    }

    /// Returns `true` when [`adapt_to_json_schema_subset`] rejects `schema`.
    fn subset_rejects(mut schema: Value) -> bool {
        adapt_to_json_schema_subset(&mut schema).is_err()
    }

    /// Runs [`preprocess_json_schema`] on `schema`, panicking on error, and returns the
    /// processed value.
    fn preprocessed(mut schema: Value) -> Value {
        preprocess_json_schema(&mut schema).unwrap();
        schema
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        // A bare string `description` with no type gets `"type": "string"`.
        assert_eq!(
            adapted_subset(json!({
                "description": "A test field without type"
            })),
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // A non-string `description` does not trigger the default type, and the nested
        // object already has a `type`, so the schema is unchanged.
        let object_description = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });
        assert_eq!(
            adapted_subset(object_description.clone()),
            object_description
        );
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        // Keywords whose values have the expected shape are stripped.
        assert_eq!(
            adapted_subset(json!({
                "description": "A test field",
                "type": "integer",
                "format": "uint32",
                "exclusiveMinimum": 0,
                "exclusiveMaximum": 100,
                "additionalProperties": false,
                "optional": true
            })),
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // A `format` whose value is not a string is kept.
        let object_format = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });
        assert_eq!(adapted_subset(object_format.clone()), object_format);

        // `additionalProperties` and `propertyNames` are stripped whatever their value.
        assert_eq!(
            adapted_subset(json!({
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                },
                "additionalProperties": { "type": "string" },
                "propertyNames": { "pattern": "^[A-Za-z]+$" }
            })),
            json!({
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                }
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        assert_eq!(
            adapted_subset(json!({
                "description": "A test field",
                "oneOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })),
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        // Rewrites also apply inside `properties`.
        assert_eq!(
            adapted_subset(json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "oneOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ],
                        "format": "email"
                    }
                }
            })),
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        // Unsupported keys are rejected at any depth, not only at the root.
        assert!(subset_rejects(json!({
            "type": "object",
            "properties": {
                "$ref": "#/definitions/User",
            }
        })));

        assert!(subset_rejects(json!({
            "type": "object",
            "properties": {
                "if": "...",
            }
        })));

        assert!(subset_rejects(json!({
            "type": "object",
            "properties": {
                "then": "...",
            }
        })));

        assert!(subset_rejects(json!({
            "type": "object",
            "properties": {
                "else": "...",
            }
        })));
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        assert_eq!(
            preprocessed(json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            })),
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        // An explicit `additionalProperties` is left untouched.
        let permissive = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });
        assert_eq!(preprocessed(permissive.clone()), permissive);
    }
}