tool_schema.rs

  1use anyhow::Result;
  2use language_model::LanguageModelToolSchemaFormat;
  3use schemars::{
  4    JsonSchema, Schema,
  5    generate::SchemaSettings,
  6    transform::{Transform, transform_subschemas},
  7};
  8use serde_json::Value;
  9
 10pub(crate) fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
 11    let mut generator = match format {
 12        LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
 13        LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
 14            .with(|settings| {
 15                settings.meta_schema = None;
 16                settings.inline_subschemas = true;
 17            })
 18            .with_transform(ToJsonSchemaSubsetTransform)
 19            .into_generator(),
 20    };
 21    generator.root_schema_for::<T>()
 22}
 23
 24#[derive(Debug, Clone)]
 25struct ToJsonSchemaSubsetTransform;
 26
 27impl Transform for ToJsonSchemaSubsetTransform {
 28    fn transform(&mut self, schema: &mut Schema) {
 29        // Ensure that the type field is not an array, this happens when we use
 30        // Option<T>, the type will be [T, "null"].
 31        if let Some(type_field) = schema.get_mut("type")
 32            && let Some(types) = type_field.as_array()
 33            && let Some(first_type) = types.first()
 34        {
 35            *type_field = first_type.clone();
 36        }
 37
 38        // oneOf is not supported, use anyOf instead
 39        if let Some(one_of) = schema.remove("oneOf") {
 40            schema.insert("anyOf".to_string(), one_of);
 41        }
 42
 43        transform_subschemas(self, schema);
 44    }
 45}
 46
 47/// Tries to adapt a JSON schema representation to be compatible with the specified format.
 48///
 49/// If the json cannot be made compatible with the specified format, an error is returned.
 50pub fn adapt_schema_to_format(
 51    json: &mut Value,
 52    format: LanguageModelToolSchemaFormat,
 53) -> Result<()> {
 54    if let Value::Object(obj) = json {
 55        obj.remove("$schema");
 56        obj.remove("title");
 57    }
 58
 59    match format {
 60        LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
 61        LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
 62    }
 63}
 64
 65fn preprocess_json_schema(json: &mut Value) -> Result<()> {
 66    // `additionalProperties` defaults to `false` unless explicitly specified.
 67    // This prevents models from hallucinating tool parameters.
 68    if let Value::Object(obj) = json
 69        && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
 70    {
 71        if !obj.contains_key("additionalProperties") {
 72            obj.insert("additionalProperties".to_string(), Value::Bool(false));
 73        }
 74
 75        // OpenAI API requires non-missing `properties`
 76        if !obj.contains_key("properties") {
 77            obj.insert("properties".to_string(), Value::Object(Default::default()));
 78        }
 79    }
 80    Ok(())
 81}
 82
 83/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
 84fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
 85    if let Value::Object(obj) = json {
 86        const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
 87
 88        for key in UNSUPPORTED_KEYS {
 89            anyhow::ensure!(
 90                !obj.contains_key(key),
 91                "Schema cannot be made compatible because it contains \"{key}\""
 92            );
 93        }
 94
 95        const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
 96            ("format", |value| value.is_string()),
 97            ("additionalProperties", |value| value.is_boolean()),
 98            ("exclusiveMinimum", |value| value.is_number()),
 99            ("exclusiveMaximum", |value| value.is_number()),
100            ("optional", |value| value.is_boolean()),
101        ];
102        for (key, predicate) in KEYS_TO_REMOVE {
103            if let Some(value) = obj.get(key)
104                && predicate(value)
105            {
106                obj.remove(key);
107            }
108        }
109
110        // If a type is not specified for an input parameter, add a default type
111        if matches!(obj.get("description"), Some(Value::String(_)))
112            && !obj.contains_key("type")
113            && !(obj.contains_key("anyOf")
114                || obj.contains_key("oneOf")
115                || obj.contains_key("allOf"))
116        {
117            obj.insert("type".to_string(), Value::String("string".to_string()));
118        }
119
120        // Handle oneOf -> anyOf conversion
121        if let Some(subschemas) = obj.get_mut("oneOf")
122            && subschemas.is_array()
123        {
124            let subschemas_clone = subschemas.clone();
125            obj.remove("oneOf");
126            obj.insert("anyOf".to_string(), subschemas_clone);
127        }
128
129        // Recursively process all nested objects and arrays
130        for (_, value) in obj.iter_mut() {
131            if let Value::Object(_) | Value::Array(_) = value {
132                adapt_to_json_schema_subset(value)?;
133            }
134        }
135    } else if let Value::Array(arr) = json {
136        for item in arr.iter_mut() {
137            adapt_to_json_schema_subset(item)?;
138        }
139    }
140    Ok(())
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_transform_adds_type_when_missing() {
        let mut json = json!({
            "description": "A test field without type"
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // Ensure that we do not add a type if it is an object
        let mut json = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })
        );
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer",
                "format": {},
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let mut json = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        for key in ["$ref", "if", "then", "else"] {
            let mut json = json!({
                "type": "object",
                "properties": {
                    key: "...",
                }
            });

            assert!(
                adapt_to_json_schema_subset(&mut json).is_err(),
                "expected a schema containing {key:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": true
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_adds_missing_properties() {
        let mut json = json!({
            "type": "object"
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {}
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_ignores_non_object_root() {
        let original = json!({
            "type": "string",
            "description": "Not an object"
        });
        let mut json = original.clone();

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(json, original);
    }

    #[test]
    fn test_adapt_schema_to_format_strips_root_metadata_for_json_schema() {
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Tool",
            "type": "object"
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchema).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {}
            })
        );
    }

    #[test]
    fn test_adapt_schema_to_format_strips_root_metadata_for_subset() {
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Tool",
            "type": "object",
            "properties": {
                "count": {
                    "type": "integer",
                    "format": "uint32"
                }
            },
            "additionalProperties": false
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchemaSubset)
            .unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "count": {
                        "type": "integer"
                    }
                }
            })
        );
    }
}