tool_schema.rs

  1use anyhow::Result;
  2use schemars::{
  3    JsonSchema, Schema,
  4    generate::SchemaSettings,
  5    transform::{Transform, transform_subschemas},
  6};
  7use serde_json::Value;
  8
  9/// Indicates the format used to define the input schema for a language model tool.
 10#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
 11pub enum LanguageModelToolSchemaFormat {
 12    /// A JSON schema, see https://json-schema.org
 13    JsonSchema,
 14    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
 15    JsonSchemaSubset,
 16}
 17
 18pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
 19    let mut generator = match format {
 20        LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07()
 21            .with(|settings| {
 22                settings.meta_schema = None;
 23                settings.inline_subschemas = true;
 24            })
 25            .into_generator(),
 26        LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
 27            .with(|settings| {
 28                settings.meta_schema = None;
 29                settings.inline_subschemas = true;
 30            })
 31            .with_transform(ToJsonSchemaSubsetTransform)
 32            .into_generator(),
 33    };
 34    generator.root_schema_for::<T>()
 35}
 36
 37#[derive(Debug, Clone)]
 38struct ToJsonSchemaSubsetTransform;
 39
 40impl Transform for ToJsonSchemaSubsetTransform {
 41    fn transform(&mut self, schema: &mut Schema) {
 42        // Ensure that the type field is not an array, this happens when we use
 43        // Option<T>, the type will be [T, "null"].
 44        if let Some(type_field) = schema.get_mut("type")
 45            && let Some(types) = type_field.as_array()
 46            && let Some(first_type) = types.first()
 47        {
 48            *type_field = first_type.clone();
 49        }
 50
 51        // oneOf is not supported, use anyOf instead
 52        if let Some(one_of) = schema.remove("oneOf") {
 53            schema.insert("anyOf".to_string(), one_of);
 54        }
 55
 56        transform_subschemas(self, schema);
 57    }
 58}
 59
 60/// Tries to adapt a JSON schema representation to be compatible with the specified format.
 61///
 62/// If the json cannot be made compatible with the specified format, an error is returned.
 63pub fn adapt_schema_to_format(
 64    json: &mut Value,
 65    format: LanguageModelToolSchemaFormat,
 66) -> Result<()> {
 67    if let Value::Object(obj) = json {
 68        obj.remove("$schema");
 69        obj.remove("title");
 70    }
 71
 72    match format {
 73        LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
 74        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 75    }
 76}
 77
 78fn preprocess_json_schema(json: &mut Value) -> Result<()> {
 79    // `additionalProperties` defaults to `false` unless explicitly specified.
 80    // This prevents models from hallucinating tool parameters.
 81    if let Value::Object(obj) = json
 82        && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
 83    {
 84        if !obj.contains_key("additionalProperties") {
 85            obj.insert("additionalProperties".to_string(), Value::Bool(false));
 86        }
 87
 88        // OpenAI API requires non-missing `properties`
 89        if !obj.contains_key("properties") {
 90            obj.insert("properties".to_string(), Value::Object(Default::default()));
 91        }
 92    }
 93    Ok(())
 94}
 95
 96/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 97fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 98    if let Value::Object(obj) = json {
 99        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
100
101        for key in UNSUPPORTED_KEYS {
102            anyhow::ensure!(
103                !obj.contains_key(key),
104                "Schema cannot be made compatible because it contains \"{key}\""
105            );
106        }
107
108        const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
109            ("format", |value| value.is_string()),
110            ("additionalProperties", |value| value.is_boolean()),
111            ("exclusiveMinimum", |value| value.is_number()),
112            ("exclusiveMaximum", |value| value.is_number()),
113            ("optional", |value| value.is_boolean()),
114        ];
115        for (key, predicate) in KEYS_TO_REMOVE {
116            if let Some(value) = obj.get(key)
117                && predicate(value)
118            {
119                obj.remove(key);
120            }
121        }
122
123        // If a type is not specified for an input parameter, add a default type
124        if matches!(obj.get("description"), Some(Value::String(_)))
125            && !obj.contains_key("type")
126            && !(obj.contains_key("anyOf")
127                || obj.contains_key("oneOf")
128                || obj.contains_key("allOf"))
129        {
130            obj.insert("type".to_string(), Value::String("string".to_string()));
131        }
132
133        // Handle oneOf -> anyOf conversion
134        if let Some(subschemas) = obj.get_mut("oneOf")
135            && subschemas.is_array()
136        {
137            let subschemas_clone = subschemas.clone();
138            obj.remove("oneOf");
139            obj.insert("anyOf".to_string(), subschemas_clone);
140        }
141
142        // Recursively process all nested objects and arrays
143        for (_, value) in obj.iter_mut() {
144            if let Value::Object(_) | Value::Array(_) = value {
145                adapt_to_json_schema_subset(value)?;
146            }
147        }
148    } else if let Value::Array(arr) = json {
149        for item in arr.iter_mut() {
150            adapt_to_json_schema_subset(item)?;
151        }
152    }
153    Ok(())
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs the subset adaptation on `input` and returns the adapted value,
    /// panicking if the adaptation reports an error.
    fn adapt_subset(mut input: Value) -> Value {
        adapt_to_json_schema_subset(&mut input).unwrap();
        input
    }

    /// Runs the JSON-schema preprocessing on `input` and returns the result.
    fn preprocess(mut input: Value) -> Value {
        preprocess_json_schema(&mut input).unwrap();
        input
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        assert_eq!(
            adapt_subset(json!({
                "description": "A test field without type"
            })),
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // Ensure that we do not add a type if it is an object
        let unchanged = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });
        assert_eq!(adapt_subset(unchanged.clone()), unchanged);
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        assert_eq!(
            adapt_subset(json!({
                "description": "A test field",
                "type": "integer",
                "format": "uint32",
                "exclusiveMinimum": 0,
                "exclusiveMaximum": 100,
                "additionalProperties": false,
                "optional": true
            })),
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // Ensure that we do not remove keys that are actually supported
        // (e.g. "format" can just be used as another property)
        let unchanged = json!({
            "description": "A test field",
            "type": "integer",
            "format": {}
        });
        assert_eq!(adapt_subset(unchanged.clone()), unchanged);
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        assert_eq!(
            adapt_subset(json!({
                "description": "A test field",
                "oneOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })),
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        assert_eq!(
            adapt_subset(json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "oneOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ],
                        "format": "email"
                    }
                }
            })),
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_processes_array_items() {
        // Objects that are only reachable through an array are adapted too.
        assert_eq!(
            adapt_subset(json!({
                "anyOf": [
                    { "type": "string", "format": "date" },
                    { "type": "integer", "exclusiveMinimum": 0 }
                ]
            })),
            json!({
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let cases = [
            ("$ref", "#/definitions/User"),
            ("if", "..."),
            ("then", "..."),
            ("else", "..."),
        ];

        for (key, value) in cases {
            let mut json = json!({
                "type": "object",
                "properties": {
                    key: value
                }
            });

            let error = adapt_to_json_schema_subset(&mut json)
                .expect_err("schema with an unsupported key must be rejected");
            assert!(
                error.to_string().contains(key),
                "error message should name the offending key {key:?}, got: {error}"
            );
        }
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        assert_eq!(
            preprocess(json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            })),
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let unchanged = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });
        assert_eq!(preprocess(unchanged.clone()), unchanged);
    }

    #[test]
    fn test_preprocess_json_schema_adds_missing_properties() {
        assert_eq!(
            preprocess(json!({ "type": "object" })),
            json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {}
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_ignores_non_object_types() {
        let unchanged = json!({ "type": "string" });
        assert_eq!(preprocess(unchanged.clone()), unchanged);

        let untyped = json!({ "description": "no type at all" });
        assert_eq!(preprocess(untyped.clone()), untyped);
    }

    #[test]
    fn test_adapt_schema_to_format_strips_schema_and_title() {
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Params",
            "type": "object"
        });
        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchema).unwrap();
        assert_eq!(
            json,
            json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {}
            })
        );

        // Only the top-level `title` is dropped; nested ones are kept.
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Params",
            "type": "object",
            "properties": {
                "email": {
                    "type": "string",
                    "format": "email",
                    "title": "Email"
                }
            }
        });
        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchemaSubset)
            .unwrap();
        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "email": {
                        "type": "string",
                        "title": "Email"
                    }
                }
            })
        );
    }
}