tool_schema.rs

  1use anyhow::Result;
  2use schemars::{
  3    JsonSchema, Schema,
  4    generate::SchemaSettings,
  5    transform::{Transform, transform_subschemas},
  6};
  7use serde_json::Value;
  8
  9/// Indicates the format used to define the input schema for a language model tool.
 10#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
 11pub enum LanguageModelToolSchemaFormat {
 12    /// A JSON schema, see https://json-schema.org
 13    JsonSchema,
 14    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
 15    JsonSchemaSubset,
 16}
 17
 18pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
 19    let mut generator = match format {
 20        LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
 21        LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
 22            .with(|settings| {
 23                settings.meta_schema = None;
 24                settings.inline_subschemas = true;
 25            })
 26            .with_transform(ToJsonSchemaSubsetTransform)
 27            .into_generator(),
 28    };
 29    generator.root_schema_for::<T>()
 30}
 31
 32#[derive(Debug, Clone)]
 33struct ToJsonSchemaSubsetTransform;
 34
 35impl Transform for ToJsonSchemaSubsetTransform {
 36    fn transform(&mut self, schema: &mut Schema) {
 37        // Ensure that the type field is not an array, this happens when we use
 38        // Option<T>, the type will be [T, "null"].
 39        if let Some(type_field) = schema.get_mut("type")
 40            && let Some(types) = type_field.as_array()
 41            && let Some(first_type) = types.first()
 42        {
 43            *type_field = first_type.clone();
 44        }
 45
 46        // oneOf is not supported, use anyOf instead
 47        if let Some(one_of) = schema.remove("oneOf") {
 48            schema.insert("anyOf".to_string(), one_of);
 49        }
 50
 51        transform_subschemas(self, schema);
 52    }
 53}
 54
 55/// Tries to adapt a JSON schema representation to be compatible with the specified format.
 56///
 57/// If the json cannot be made compatible with the specified format, an error is returned.
 58pub fn adapt_schema_to_format(
 59    json: &mut Value,
 60    format: LanguageModelToolSchemaFormat,
 61) -> Result<()> {
 62    if let Value::Object(obj) = json {
 63        obj.remove("$schema");
 64        obj.remove("title");
 65    }
 66
 67    match format {
 68        LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
 69        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 70    }
 71}
 72
 73fn preprocess_json_schema(json: &mut Value) -> Result<()> {
 74    // `additionalProperties` defaults to `false` unless explicitly specified.
 75    // This prevents models from hallucinating tool parameters.
 76    if let Value::Object(obj) = json
 77        && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
 78    {
 79        if !obj.contains_key("additionalProperties") {
 80            obj.insert("additionalProperties".to_string(), Value::Bool(false));
 81        }
 82
 83        // OpenAI API requires non-missing `properties`
 84        if !obj.contains_key("properties") {
 85            obj.insert("properties".to_string(), Value::Object(Default::default()));
 86        }
 87    }
 88    Ok(())
 89}
 90
 91/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 92fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 93    if let Value::Object(obj) = json {
 94        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 95
 96        for key in UNSUPPORTED_KEYS {
 97            anyhow::ensure!(
 98                !obj.contains_key(key),
 99                "Schema cannot be made compatible because it contains \"{key}\""
100            );
101        }
102
103        const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
104            ("format", |value| value.is_string()),
105            ("additionalProperties", |value| value.is_boolean()),
106            ("exclusiveMinimum", |value| value.is_number()),
107            ("exclusiveMaximum", |value| value.is_number()),
108            ("optional", |value| value.is_boolean()),
109        ];
110        for (key, predicate) in KEYS_TO_REMOVE {
111            if let Some(value) = obj.get(key)
112                && predicate(value)
113            {
114                obj.remove(key);
115            }
116        }
117
118        // If a type is not specified for an input parameter, add a default type
119        if matches!(obj.get("description"), Some(Value::String(_)))
120            && !obj.contains_key("type")
121            && !(obj.contains_key("anyOf")
122                || obj.contains_key("oneOf")
123                || obj.contains_key("allOf"))
124        {
125            obj.insert("type".to_string(), Value::String("string".to_string()));
126        }
127
128        // Handle oneOf -> anyOf conversion
129        if let Some(subschemas) = obj.get_mut("oneOf")
130            && subschemas.is_array()
131        {
132            let subschemas_clone = subschemas.clone();
133            obj.remove("oneOf");
134            obj.insert("anyOf".to_string(), subschemas_clone);
135        }
136
137        // Recursively process all nested objects and arrays
138        for (_, value) in obj.iter_mut() {
139            if let Value::Object(_) | Value::Array(_) = value {
140                adapt_to_json_schema_subset(value)?;
141            }
142        }
143    } else if let Value::Array(arr) = json {
144        for item in arr.iter_mut() {
145            adapt_to_json_schema_subset(item)?;
146        }
147    }
148    Ok(())
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A described schema without a `type` receives the default `"string"` type.
    #[test]
    fn test_transform_adds_type_when_missing() {
        let mut schema = json!({
            "description": "A test field without type"
        });
        let expected = json!({
            "description": "A test field without type",
            "type": "string"
        });

        adapt_to_json_schema_subset(&mut schema).unwrap();

        assert_eq!(schema, expected);

        // Ensure that we do not add a type if it is an object
        let mut schema = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });
        let untouched = schema.clone();

        adapt_to_json_schema_subset(&mut schema).unwrap();

        assert_eq!(schema, untouched);
    }

    /// Keys outside the supported subset are stripped when their value has the expected type.
    #[test]
    fn test_transform_removes_unsupported_keys() {
        let mut schema = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });
        let expected = json!({
            "description": "A test field",
            "type": "integer"
        });

        adapt_to_json_schema_subset(&mut schema).unwrap();

        assert_eq!(schema, expected);

        // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
        let mut schema = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });
        let untouched = schema.clone();

        adapt_to_json_schema_subset(&mut schema).unwrap();

        assert_eq!(schema, untouched);
    }

    /// A top-level `oneOf` array is renamed to `anyOf` with its subschemas intact.
    #[test]
    fn test_transform_one_of_to_any_of() {
        let mut schema = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });
        let expected = json!({
            "description": "A test field",
            "anyOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        adapt_to_json_schema_subset(&mut schema).unwrap();

        assert_eq!(schema, expected);
    }

    /// The rewrite rules are applied to objects nested under `properties` too.
    #[test]
    fn test_transform_nested_objects() {
        let mut schema = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });
        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "anyOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ]
                }
            }
        });

        adapt_to_json_schema_subset(&mut schema).unwrap();

        assert_eq!(schema, expected);
    }

    /// Each unsupported key (`$ref`, `if`, `then`, `else`) makes the adaptation fail.
    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let is_rejected =
            |mut schema: serde_json::Value| adapt_to_json_schema_subset(&mut schema).is_err();

        assert!(is_rejected(json!({
            "type": "object",
            "properties": {
                "$ref": "#/definitions/User",
            }
        })));

        assert!(is_rejected(json!({
            "type": "object",
            "properties": {
                "if": "...",
            }
        })));

        assert!(is_rejected(json!({
            "type": "object",
            "properties": {
                "then": "...",
            }
        })));

        assert!(is_rejected(json!({
            "type": "object",
            "properties": {
                "else": "...",
            }
        })));
    }

    /// An object schema without `additionalProperties` gets it set to `false`.
    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let mut schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });
        let expected = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": false
        });

        preprocess_json_schema(&mut schema).unwrap();

        assert_eq!(schema, expected);
    }

    /// An explicit `additionalProperties` value is left exactly as provided.
    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let mut schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });
        let untouched = schema.clone();

        preprocess_json_schema(&mut schema).unwrap();

        assert_eq!(schema, untouched);
    }
}
383    }
384}