tool_schema.rs

  1use anyhow::Result;
  2use schemars::{
  3    JsonSchema, Schema,
  4    generate::SchemaSettings,
  5    transform::{Transform, transform_subschemas},
  6};
  7use serde_json::Value;
  8
  9/// Indicates the format used to define the input schema for a language model tool.
 10#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
 11pub enum LanguageModelToolSchemaFormat {
 12    /// A JSON schema, see https://json-schema.org
 13    JsonSchema,
 14    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
 15    JsonSchemaSubset,
 16}
 17
 18pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
 19    let mut generator = match format {
 20        LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07()
 21            .with(|settings| {
 22                settings.meta_schema = None;
 23                settings.inline_subschemas = true;
 24            })
 25            .into_generator(),
 26        LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
 27            .with(|settings| {
 28                settings.meta_schema = None;
 29                settings.inline_subschemas = true;
 30            })
 31            .with_transform(ToJsonSchemaSubsetTransform)
 32            .into_generator(),
 33    };
 34    generator.root_schema_for::<T>()
 35}
 36
 37#[derive(Debug, Clone)]
 38struct ToJsonSchemaSubsetTransform;
 39
 40impl Transform for ToJsonSchemaSubsetTransform {
 41    fn transform(&mut self, schema: &mut Schema) {
 42        // Ensure that the type field is not an array, this happens when we use
 43        // Option<T>, the type will be [T, "null"].
 44        if let Some(type_field) = schema.get_mut("type")
 45            && let Some(types) = type_field.as_array()
 46            && let Some(first_type) = types.first()
 47        {
 48            *type_field = first_type.clone();
 49        }
 50
 51        // oneOf is not supported, use anyOf instead
 52        if let Some(one_of) = schema.remove("oneOf") {
 53            schema.insert("anyOf".to_string(), one_of);
 54        }
 55
 56        transform_subschemas(self, schema);
 57    }
 58}
 59
 60/// Tries to adapt a JSON schema representation to be compatible with the specified format.
 61///
 62/// If the json cannot be made compatible with the specified format, an error is returned.
 63pub fn adapt_schema_to_format(
 64    json: &mut Value,
 65    format: LanguageModelToolSchemaFormat,
 66) -> Result<()> {
 67    if let Value::Object(obj) = json {
 68        obj.remove("$schema");
 69        obj.remove("title");
 70        obj.remove("description");
 71    }
 72
 73    match format {
 74        LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
 75        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 76    }
 77}
 78
 79fn preprocess_json_schema(json: &mut Value) -> Result<()> {
 80    if let Value::Object(obj) = json
 81        && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
 82    {
 83        if !obj.contains_key("additionalProperties") {
 84            obj.insert("additionalProperties".to_string(), Value::Bool(false));
 85        }
 86
 87        if !obj.contains_key("properties") {
 88            obj.insert("properties".to_string(), Value::Object(Default::default()));
 89        }
 90    }
 91    Ok(())
 92}
 93
 94fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 95    if let Value::Object(obj) = json {
 96        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 97
 98        for key in UNSUPPORTED_KEYS {
 99            anyhow::ensure!(
100                !obj.contains_key(key),
101                "Schema cannot be made compatible because it contains \"{key}\""
102            );
103        }
104
105        const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [
106            ("format", |value| value.is_string()),
107            ("additionalProperties", |_| true),
108            ("propertyNames", |_| true),
109            ("exclusiveMinimum", |value| value.is_number()),
110            ("exclusiveMaximum", |value| value.is_number()),
111            ("optional", |value| value.is_boolean()),
112        ];
113        for (key, predicate) in KEYS_TO_REMOVE {
114            if let Some(value) = obj.get(key)
115                && predicate(value)
116            {
117                obj.remove(key);
118            }
119        }
120
121        // Ensure that the type field is not an array. This can happen with MCP tool
122        // schemas that use multiple types (e.g. `["string", "number"]` or `["string", "null"]`).
123        if let Some(type_value) = obj.get_mut("type")
124            && let Some(types) = type_value.as_array()
125            && let Some(first_type) = types.first().cloned()
126        {
127            *type_value = first_type;
128        }
129
130        if matches!(obj.get("description"), Some(Value::String(_)))
131            && !obj.contains_key("type")
132            && !(obj.contains_key("anyOf")
133                || obj.contains_key("oneOf")
134                || obj.contains_key("allOf"))
135        {
136            obj.insert("type".to_string(), Value::String("string".to_string()));
137        }
138
139        if let Some(subschemas) = obj.get_mut("oneOf")
140            && subschemas.is_array()
141        {
142            let subschemas_clone = subschemas.clone();
143            obj.remove("oneOf");
144            obj.insert("anyOf".to_string(), subschemas_clone);
145        }
146
147        for (_, value) in obj.iter_mut() {
148            if let Value::Object(_) | Value::Array(_) = value {
149                adapt_to_json_schema_subset(value)?;
150            }
151        }
152    } else if let Value::Array(arr) = json {
153        for item in arr.iter_mut() {
154            adapt_to_json_schema_subset(item)?;
155        }
156    }
157    Ok(())
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_transform_adds_type_when_missing() {
        // A string description with no type gets `"type": "string"`.
        let mut json = json!({
            "description": "A test field without type"
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // A non-string `description` must not trigger type inference.
        let mut json = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })
        );
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // `format` is only removed when it is a string keyword value.
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer",
                "format": {},
            })
        );

        // Schema-valued `additionalProperties` / `propertyNames` are removed too.
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "additionalProperties": { "type": "string" },
            "propertyNames": { "pattern": "^[A-Za-z]+$" }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                }
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let mut json = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_array_type_to_single_type() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "projectSlugOrId": {
                    "type": ["string", "number"],
                    "description": "Project slug or numeric ID"
                },
                "optionalName": {
                    "type": ["string", "null"],
                    "description": "An optional name"
                }
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "projectSlugOrId": {
                        "type": "string",
                        "description": "Project slug or numeric ID"
                    },
                    "optionalName": {
                        "type": "string",
                        "description": "An optional name"
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        // The property-name map under `properties` is visited as if it were a
        // schema, so these keys are rejected even when they appear there.
        let mut json = json!({
            "type": "object",
            "properties": {
                "$ref": "#/definitions/User",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "if": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "then": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "else": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    #[test]
    fn test_transform_fails_on_ref_inside_property_schema() {
        // A realistic `$ref` nested inside a property's own schema.
        let mut json = json!({
            "type": "object",
            "properties": {
                "user": { "$ref": "#/definitions/User" }
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": true
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_adds_missing_properties() {
        let mut json = json!({ "type": "object" });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {},
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_ignores_non_object_types() {
        let mut json = json!({ "type": "string" });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(json, json!({ "type": "string" }));
    }

    #[test]
    fn test_adapt_schema_to_format_strips_top_level_metadata() {
        // Only the root `description` is stripped; nested ones are kept.
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Params",
            "description": "Tool parameters",
            "type": "object",
            "properties": {
                "name": { "type": "string", "description": "kept" }
            }
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchema).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": { "type": "string", "description": "kept" }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_adapt_schema_to_format_json_schema_subset() {
        let mut json = json!({
            "title": "Params",
            "type": "object",
            "properties": {
                "count": { "type": "integer", "format": "uint32" }
            },
            "additionalProperties": false
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchemaSubset)
            .unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "count": { "type": "integer" }
                }
            })
        );
    }
}