1use anyhow::Result;
2use schemars::{
3 JsonSchema, Schema,
4 generate::SchemaSettings,
5 transform::{Transform, transform_subschemas},
6};
7use serde_json::Value;
8
/// The dialect in which a language model tool's input schema is expressed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// Standard JSON Schema (<https://json-schema.org>).
    JsonSchema,
    /// The subset of the OpenAPI 3.0 schema object accepted by Google AI
    /// (<https://ai.google.dev/api/caching#Schema>).
    JsonSchemaSubset,
}
17
18pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
19 let mut generator = match format {
20 LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07()
21 .with(|settings| {
22 settings.meta_schema = None;
23 settings.inline_subschemas = true;
24 })
25 .into_generator(),
26 LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
27 .with(|settings| {
28 settings.meta_schema = None;
29 settings.inline_subschemas = true;
30 })
31 .with_transform(ToJsonSchemaSubsetTransform)
32 .into_generator(),
33 };
34 generator.root_schema_for::<T>()
35}
36
/// Schemars transform used when generating schemas in the Google AI subset format
/// (`LanguageModelToolSchemaFormat::JsonSchemaSubset`).
#[derive(Clone, Debug)]
struct ToJsonSchemaSubsetTransform;
39
40impl Transform for ToJsonSchemaSubsetTransform {
41 fn transform(&mut self, schema: &mut Schema) {
42 // Ensure that the type field is not an array, this happens when we use
43 // Option<T>, the type will be [T, "null"].
44 if let Some(type_field) = schema.get_mut("type")
45 && let Some(types) = type_field.as_array()
46 && let Some(first_type) = types.first()
47 {
48 *type_field = first_type.clone();
49 }
50
51 // oneOf is not supported, use anyOf instead
52 if let Some(one_of) = schema.remove("oneOf") {
53 schema.insert("anyOf".to_string(), one_of);
54 }
55
56 transform_subschemas(self, schema);
57 }
58}
59
60/// Tries to adapt a JSON schema representation to be compatible with the specified format.
61///
62/// If the json cannot be made compatible with the specified format, an error is returned.
63pub fn adapt_schema_to_format(
64 json: &mut Value,
65 format: LanguageModelToolSchemaFormat,
66) -> Result<()> {
67 if let Value::Object(obj) = json {
68 obj.remove("$schema");
69 obj.remove("title");
70 obj.remove("description");
71 }
72
73 match format {
74 LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
75 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
76 }
77}
78
79fn preprocess_json_schema(json: &mut Value) -> Result<()> {
80 // `additionalProperties` defaults to `false` unless explicitly specified.
81 // This prevents models from hallucinating tool parameters.
82 if let Value::Object(obj) = json
83 && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
84 {
85 if !obj.contains_key("additionalProperties") {
86 obj.insert("additionalProperties".to_string(), Value::Bool(false));
87 }
88
89 // OpenAI API requires non-missing `properties`
90 if !obj.contains_key("properties") {
91 obj.insert("properties".to_string(), Value::Object(Default::default()));
92 }
93 }
94 Ok(())
95}
96
97/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
98fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
99 if let Value::Object(obj) = json {
100 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
101
102 for key in UNSUPPORTED_KEYS {
103 anyhow::ensure!(
104 !obj.contains_key(key),
105 "Schema cannot be made compatible because it contains \"{key}\""
106 );
107 }
108
109 const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [
110 ("format", |value| value.is_string()),
111 // Gemini doesn't support `additionalProperties` in any form (boolean or schema object)
112 ("additionalProperties", |_| true),
113 // Gemini doesn't support `propertyNames`
114 ("propertyNames", |_| true),
115 ("exclusiveMinimum", |value| value.is_number()),
116 ("exclusiveMaximum", |value| value.is_number()),
117 ("optional", |value| value.is_boolean()),
118 ];
119 for (key, predicate) in KEYS_TO_REMOVE {
120 if let Some(value) = obj.get(key)
121 && predicate(value)
122 {
123 obj.remove(key);
124 }
125 }
126
127 // If a type is not specified for an input parameter, add a default type
128 if matches!(obj.get("description"), Some(Value::String(_)))
129 && !obj.contains_key("type")
130 && !(obj.contains_key("anyOf")
131 || obj.contains_key("oneOf")
132 || obj.contains_key("allOf"))
133 {
134 obj.insert("type".to_string(), Value::String("string".to_string()));
135 }
136
137 // Handle oneOf -> anyOf conversion
138 if let Some(subschemas) = obj.get_mut("oneOf")
139 && subschemas.is_array()
140 {
141 let subschemas_clone = subschemas.clone();
142 obj.remove("oneOf");
143 obj.insert("anyOf".to_string(), subschemas_clone);
144 }
145
146 // Recursively process all nested objects and arrays
147 for (_, value) in obj.iter_mut() {
148 if let Value::Object(_) | Value::Array(_) = value {
149 adapt_to_json_schema_subset(value)?;
150 }
151 }
152 } else if let Value::Array(arr) = json {
153 for item in arr.iter_mut() {
154 adapt_to_json_schema_subset(item)?;
155 }
156 }
157 Ok(())
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_transform_adds_type_when_missing() {
        let mut json = json!({
            "description": "A test field without type"
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // Ensure that we do not add a type if it is an object
        let mut json = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })
        );
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer",
                "format": {},
            })
        );

        // additionalProperties as an object schema is also unsupported by Gemini
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "additionalProperties": { "type": "string" },
            "propertyNames": { "pattern": "^[A-Za-z]+$" }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                }
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let mut json = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_recurses_into_arrays() {
        // Unsupported keys inside array elements (here: `anyOf` members) are stripped too.
        let mut json = json!({
            "anyOf": [
                { "type": "string", "format": "email" },
                { "type": "integer", "exclusiveMinimum": 0 }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "$ref": "#/definitions/User",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "if": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "then": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "else": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": true
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_adds_missing_properties() {
        // OpenAI rejects object schemas without `properties`, so an empty map is inserted.
        let mut json = json!({
            "type": "object"
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {},
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_ignores_non_object_schemas() {
        let mut json = json!({
            "type": "string"
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(json, json!({ "type": "string" }));
    }

    #[test]
    fn test_adapt_schema_to_format_strips_root_metadata() {
        // Only the root's `$schema`, `title` and `description` are removed; nested
        // descriptions survive.
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Params",
            "description": "Tool parameters",
            "type": "object",
            "properties": {
                "path": { "type": "string", "description": "A path" }
            }
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchema).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string", "description": "A path" }
                },
                "additionalProperties": false
            })
        );
    }
}