tool_schema.rs

  1use anyhow::Result;
  2use schemars::{
  3    JsonSchema, Schema,
  4    generate::SchemaSettings,
  5    transform::{Transform, transform_subschemas},
  6};
  7use serde_json::Value;
  8
  9/// Indicates the format used to define the input schema for a language model tool.
 10#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
 11pub enum LanguageModelToolSchemaFormat {
 12    /// A JSON schema, see https://json-schema.org
 13    JsonSchema,
 14    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
 15    JsonSchemaSubset,
 16}
 17
 18pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
 19    let mut generator = match format {
 20        LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07()
 21            .with(|settings| {
 22                settings.meta_schema = None;
 23                settings.inline_subschemas = true;
 24            })
 25            .into_generator(),
 26        LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
 27            .with(|settings| {
 28                settings.meta_schema = None;
 29                settings.inline_subschemas = true;
 30            })
 31            .with_transform(ToJsonSchemaSubsetTransform)
 32            .into_generator(),
 33    };
 34    generator.root_schema_for::<T>()
 35}
 36
 37#[derive(Debug, Clone)]
 38struct ToJsonSchemaSubsetTransform;
 39
 40impl Transform for ToJsonSchemaSubsetTransform {
 41    fn transform(&mut self, schema: &mut Schema) {
 42        // Ensure that the type field is not an array, this happens when we use
 43        // Option<T>, the type will be [T, "null"].
 44        if let Some(type_field) = schema.get_mut("type")
 45            && let Some(types) = type_field.as_array()
 46            && let Some(first_type) = types.first()
 47        {
 48            *type_field = first_type.clone();
 49        }
 50
 51        // oneOf is not supported, use anyOf instead
 52        if let Some(one_of) = schema.remove("oneOf") {
 53            schema.insert("anyOf".to_string(), one_of);
 54        }
 55
 56        transform_subschemas(self, schema);
 57    }
 58}
 59
 60/// Tries to adapt a JSON schema representation to be compatible with the specified format.
 61///
 62/// If the json cannot be made compatible with the specified format, an error is returned.
 63pub fn adapt_schema_to_format(
 64    json: &mut Value,
 65    format: LanguageModelToolSchemaFormat,
 66) -> Result<()> {
 67    if let Value::Object(obj) = json {
 68        obj.remove("$schema");
 69        obj.remove("title");
 70        obj.remove("description");
 71    }
 72
 73    match format {
 74        LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
 75        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 76    }
 77}
 78
 79fn preprocess_json_schema(json: &mut Value) -> Result<()> {
 80    // `additionalProperties` defaults to `false` unless explicitly specified.
 81    // This prevents models from hallucinating tool parameters.
 82    if let Value::Object(obj) = json
 83        && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
 84    {
 85        if !obj.contains_key("additionalProperties") {
 86            obj.insert("additionalProperties".to_string(), Value::Bool(false));
 87        }
 88
 89        // OpenAI API requires non-missing `properties`
 90        if !obj.contains_key("properties") {
 91            obj.insert("properties".to_string(), Value::Object(Default::default()));
 92        }
 93    }
 94    Ok(())
 95}
 96
 97/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 98fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 99    if let Value::Object(obj) = json {
100        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
101
102        for key in UNSUPPORTED_KEYS {
103            anyhow::ensure!(
104                !obj.contains_key(key),
105                "Schema cannot be made compatible because it contains \"{key}\""
106            );
107        }
108
109        const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [
110            ("format", |value| value.is_string()),
111            // Gemini doesn't support `additionalProperties` in any form (boolean or schema object)
112            ("additionalProperties", |_| true),
113            // Gemini doesn't support `propertyNames`
114            ("propertyNames", |_| true),
115            ("exclusiveMinimum", |value| value.is_number()),
116            ("exclusiveMaximum", |value| value.is_number()),
117            ("optional", |value| value.is_boolean()),
118        ];
119        for (key, predicate) in KEYS_TO_REMOVE {
120            if let Some(value) = obj.get(key)
121                && predicate(value)
122            {
123                obj.remove(key);
124            }
125        }
126
127        // If a type is not specified for an input parameter, add a default type
128        if matches!(obj.get("description"), Some(Value::String(_)))
129            && !obj.contains_key("type")
130            && !(obj.contains_key("anyOf")
131                || obj.contains_key("oneOf")
132                || obj.contains_key("allOf"))
133        {
134            obj.insert("type".to_string(), Value::String("string".to_string()));
135        }
136
137        // Handle oneOf -> anyOf conversion
138        if let Some(subschemas) = obj.get_mut("oneOf")
139            && subschemas.is_array()
140        {
141            let subschemas_clone = subschemas.clone();
142            obj.remove("oneOf");
143            obj.insert("anyOf".to_string(), subschemas_clone);
144        }
145
146        // Recursively process all nested objects and arrays
147        for (_, value) in obj.iter_mut() {
148            if let Value::Object(_) | Value::Array(_) = value {
149                adapt_to_json_schema_subset(value)?;
150            }
151        }
152    } else if let Value::Array(arr) = json {
153        for item in arr.iter_mut() {
154            adapt_to_json_schema_subset(item)?;
155        }
156    }
157    Ok(())
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_transform_adds_type_when_missing() {
        let mut json = json!({
            "description": "A test field without type"
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // Ensure that we do not add a type if it is an object
        let mut json = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })
        );
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer",
                "format": {},
            })
        );

        // additionalProperties as an object schema is also unsupported by Gemini
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "additionalProperties": { "type": "string" },
            "propertyNames": { "pattern": "^[A-Za-z]+$" }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                }
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let mut json = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "$ref": "#/definitions/User",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "if": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "then": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "else": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": true
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_adds_missing_properties() {
        let mut json = json!({ "type": "object" });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {},
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_ignores_non_object_schemas() {
        let mut json = json!({ "type": "string" });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(json, json!({ "type": "string" }));
    }

    #[test]
    fn test_adapt_schema_to_format_strips_top_level_metadata() {
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Input",
            "description": "Tool input",
            "type": "object",
            "properties": {
                "path": { "type": "string", "description": "A path" }
            }
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchema).unwrap();

        // Only the root metadata is stripped; nested descriptions are kept.
        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string", "description": "A path" }
                },
                "additionalProperties": false
            })
        );

        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Input",
            "type": "object",
            "properties": {
                "n": { "type": "integer", "format": "uint32" }
            },
            "additionalProperties": false
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchemaSubset)
            .unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "n": { "type": "integer" }
                }
            })
        );
    }
}