1use anyhow::Result;
2use schemars::{
3 JsonSchema, Schema,
4 generate::SchemaSettings,
5 transform::{Transform, transform_subschemas},
6};
7use serde_json::Value;
8
/// Selects which schema dialect a language model tool's input schema is expressed in.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// Regular JSON Schema, see <https://json-schema.org>.
    JsonSchema,
    /// The subset of the OpenAPI 3.0 schema object that Google AI supports,
    /// see <https://ai.google.dev/api/caching#Schema>.
    JsonSchemaSubset,
}
17
18pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
19 let mut generator = match format {
20 LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07()
21 .with(|settings| {
22 settings.meta_schema = None;
23 settings.inline_subschemas = true;
24 })
25 .into_generator(),
26 LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
27 .with(|settings| {
28 settings.meta_schema = None;
29 settings.inline_subschemas = true;
30 })
31 .with_transform(ToJsonSchemaSubsetTransform)
32 .into_generator(),
33 };
34 generator.root_schema_for::<T>()
35}
36
/// Schema transform registered by `root_schema_for` when generating schemas
/// for the `JsonSchemaSubset` format.
#[derive(Clone, Debug)]
struct ToJsonSchemaSubsetTransform;
39
40impl Transform for ToJsonSchemaSubsetTransform {
41 fn transform(&mut self, schema: &mut Schema) {
42 // Ensure that the type field is not an array, this happens when we use
43 // Option<T>, the type will be [T, "null"].
44 if let Some(type_field) = schema.get_mut("type")
45 && let Some(types) = type_field.as_array()
46 && let Some(first_type) = types.first()
47 {
48 *type_field = first_type.clone();
49 }
50
51 // oneOf is not supported, use anyOf instead
52 if let Some(one_of) = schema.remove("oneOf") {
53 schema.insert("anyOf".to_string(), one_of);
54 }
55
56 transform_subschemas(self, schema);
57 }
58}
59
60/// Tries to adapt a JSON schema representation to be compatible with the specified format.
61///
62/// If the json cannot be made compatible with the specified format, an error is returned.
63pub fn adapt_schema_to_format(
64 json: &mut Value,
65 format: LanguageModelToolSchemaFormat,
66) -> Result<()> {
67 if let Value::Object(obj) = json {
68 obj.remove("$schema");
69 obj.remove("title");
70 }
71
72 match format {
73 LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
74 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
75 }
76}
77
78fn preprocess_json_schema(json: &mut Value) -> Result<()> {
79 // `additionalProperties` defaults to `false` unless explicitly specified.
80 // This prevents models from hallucinating tool parameters.
81 if let Value::Object(obj) = json
82 && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
83 {
84 if !obj.contains_key("additionalProperties") {
85 obj.insert("additionalProperties".to_string(), Value::Bool(false));
86 }
87
88 // OpenAI API requires non-missing `properties`
89 if !obj.contains_key("properties") {
90 obj.insert("properties".to_string(), Value::Object(Default::default()));
91 }
92 }
93 Ok(())
94}
95
96/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
97fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
98 if let Value::Object(obj) = json {
99 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
100
101 for key in UNSUPPORTED_KEYS {
102 anyhow::ensure!(
103 !obj.contains_key(key),
104 "Schema cannot be made compatible because it contains \"{key}\""
105 );
106 }
107
108 const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
109 ("format", |value| value.is_string()),
110 ("additionalProperties", |value| value.is_boolean()),
111 ("exclusiveMinimum", |value| value.is_number()),
112 ("exclusiveMaximum", |value| value.is_number()),
113 ("optional", |value| value.is_boolean()),
114 ];
115 for (key, predicate) in KEYS_TO_REMOVE {
116 if let Some(value) = obj.get(key)
117 && predicate(value)
118 {
119 obj.remove(key);
120 }
121 }
122
123 // If a type is not specified for an input parameter, add a default type
124 if matches!(obj.get("description"), Some(Value::String(_)))
125 && !obj.contains_key("type")
126 && !(obj.contains_key("anyOf")
127 || obj.contains_key("oneOf")
128 || obj.contains_key("allOf"))
129 {
130 obj.insert("type".to_string(), Value::String("string".to_string()));
131 }
132
133 // Handle oneOf -> anyOf conversion
134 if let Some(subschemas) = obj.get_mut("oneOf")
135 && subschemas.is_array()
136 {
137 let subschemas_clone = subschemas.clone();
138 obj.remove("oneOf");
139 obj.insert("anyOf".to_string(), subschemas_clone);
140 }
141
142 // Recursively process all nested objects and arrays
143 for (_, value) in obj.iter_mut() {
144 if let Value::Object(_) | Value::Array(_) = value {
145 adapt_to_json_schema_subset(value)?;
146 }
147 }
148 } else if let Value::Array(arr) = json {
149 for item in arr.iter_mut() {
150 adapt_to_json_schema_subset(item)?;
151 }
152 }
153 Ok(())
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs `adapt_to_json_schema_subset` on `schema` and returns the adapted
    /// value, panicking if the adaptation fails.
    fn adapted_subset(mut schema: Value) -> Value {
        adapt_to_json_schema_subset(&mut schema).unwrap();
        schema
    }

    /// Runs `preprocess_json_schema` on `schema` and returns the result,
    /// panicking if preprocessing fails.
    fn preprocessed(mut schema: Value) -> Value {
        preprocess_json_schema(&mut schema).unwrap();
        schema
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        let with_default_type = adapted_subset(json!({
            "description": "A test field without type"
        }));
        assert_eq!(
            with_default_type,
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // A `description` that is itself an object must not trigger the
        // default type on the outer object.
        let object_description = adapted_subset(json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        }));
        assert_eq!(
            object_description,
            json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })
        );
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let stripped = adapted_subset(json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        }));
        assert_eq!(
            stripped,
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // Keys are only removed when their value has the unsupported type;
        // e.g. "format" holding an object can just be another property.
        let kept = adapted_subset(json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        }));
        assert_eq!(
            kept,
            json!({
                "description": "A test field",
                "type": "integer",
                "format": {},
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let converted = adapted_subset(json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        }));

        assert_eq!(
            converted,
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        let converted = adapted_subset(json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        }));

        assert_eq!(
            converted,
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        // Each entry is an unsupported key together with the value it is given.
        let unsupported_entries = [
            ("$ref", "#/definitions/User"),
            ("if", "..."),
            ("then", "..."),
            ("else", "..."),
        ];

        for (key, value) in unsupported_entries {
            let mut schema = json!({
                "type": "object",
                "properties": {
                    key: value,
                }
            });

            assert!(adapt_to_json_schema_subset(&mut schema).is_err());
        }
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let processed = preprocessed(json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        }));

        assert_eq!(
            processed,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let processed = preprocessed(json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        }));

        assert_eq!(
            processed,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": true
            })
        );
    }
}