1use anyhow::Result;
2use schemars::{
3 JsonSchema, Schema,
4 generate::SchemaSettings,
5 transform::{Transform, transform_subschemas},
6};
7use serde_json::Value;
8
/// The schema dialect in which a language model tool's input schema is expressed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A regular JSON schema, see <https://json-schema.org>.
    JsonSchema,
    /// The subset of an OpenAPI 3.0 schema object that Google AI supports,
    /// see <https://ai.google.dev/api/caching#Schema>.
    JsonSchemaSubset,
}
17
18pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
19 let mut generator = match format {
20 LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
21 LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
22 .with(|settings| {
23 settings.meta_schema = None;
24 settings.inline_subschemas = true;
25 })
26 .with_transform(ToJsonSchemaSubsetTransform)
27 .into_generator(),
28 };
29 generator.root_schema_for::<T>()
30}
31
/// Schema transform registered by `root_schema_for` when generating schemas for
/// [`LanguageModelToolSchemaFormat::JsonSchemaSubset`].
#[derive(Clone, Debug)]
struct ToJsonSchemaSubsetTransform;
34
35impl Transform for ToJsonSchemaSubsetTransform {
36 fn transform(&mut self, schema: &mut Schema) {
37 // Ensure that the type field is not an array, this happens when we use
38 // Option<T>, the type will be [T, "null"].
39 if let Some(type_field) = schema.get_mut("type")
40 && let Some(types) = type_field.as_array()
41 && let Some(first_type) = types.first()
42 {
43 *type_field = first_type.clone();
44 }
45
46 // oneOf is not supported, use anyOf instead
47 if let Some(one_of) = schema.remove("oneOf") {
48 schema.insert("anyOf".to_string(), one_of);
49 }
50
51 transform_subschemas(self, schema);
52 }
53}
54
55/// Tries to adapt a JSON schema representation to be compatible with the specified format.
56///
57/// If the json cannot be made compatible with the specified format, an error is returned.
58pub fn adapt_schema_to_format(
59 json: &mut Value,
60 format: LanguageModelToolSchemaFormat,
61) -> Result<()> {
62 if let Value::Object(obj) = json {
63 obj.remove("$schema");
64 obj.remove("title");
65 }
66
67 match format {
68 LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
69 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
70 }
71}
72
73fn preprocess_json_schema(json: &mut Value) -> Result<()> {
74 // `additionalProperties` defaults to `false` unless explicitly specified.
75 // This prevents models from hallucinating tool parameters.
76 if let Value::Object(obj) = json
77 && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
78 {
79 if !obj.contains_key("additionalProperties") {
80 obj.insert("additionalProperties".to_string(), Value::Bool(false));
81 }
82
83 // OpenAI API requires non-missing `properties`
84 if !obj.contains_key("properties") {
85 obj.insert("properties".to_string(), Value::Object(Default::default()));
86 }
87 }
88 Ok(())
89}
90
91/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
92fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
93 if let Value::Object(obj) = json {
94 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
95
96 for key in UNSUPPORTED_KEYS {
97 anyhow::ensure!(
98 !obj.contains_key(key),
99 "Schema cannot be made compatible because it contains \"{key}\""
100 );
101 }
102
103 const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
104 ("format", |value| value.is_string()),
105 ("additionalProperties", |value| value.is_boolean()),
106 ("exclusiveMinimum", |value| value.is_number()),
107 ("exclusiveMaximum", |value| value.is_number()),
108 ("optional", |value| value.is_boolean()),
109 ];
110 for (key, predicate) in KEYS_TO_REMOVE {
111 if let Some(value) = obj.get(key)
112 && predicate(value)
113 {
114 obj.remove(key);
115 }
116 }
117
118 // If a type is not specified for an input parameter, add a default type
119 if matches!(obj.get("description"), Some(Value::String(_)))
120 && !obj.contains_key("type")
121 && !(obj.contains_key("anyOf")
122 || obj.contains_key("oneOf")
123 || obj.contains_key("allOf"))
124 {
125 obj.insert("type".to_string(), Value::String("string".to_string()));
126 }
127
128 // Handle oneOf -> anyOf conversion
129 if let Some(subschemas) = obj.get_mut("oneOf")
130 && subschemas.is_array()
131 {
132 let subschemas_clone = subschemas.clone();
133 obj.remove("oneOf");
134 obj.insert("anyOf".to_string(), subschemas_clone);
135 }
136
137 // Recursively process all nested objects and arrays
138 for (_, value) in obj.iter_mut() {
139 if let Value::Object(_) | Value::Array(_) = value {
140 adapt_to_json_schema_subset(value)?;
141 }
142 }
143 } else if let Value::Array(arr) = json {
144 for item in arr.iter_mut() {
145 adapt_to_json_schema_subset(item)?;
146 }
147 }
148 Ok(())
149}
150
/// Unit tests for the schema adaptation helpers in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_transform_adds_type_when_missing() {
        let mut json = json!({
            "description": "A test field without type"
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // Ensure that we do not add a type if it is an object
        let mut json = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })
        );
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer",
                "format": {},
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let mut json = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    // Array elements are visited too, including `oneOf` nested inside an `anyOf` item.
    #[test]
    fn test_transform_processes_array_items() {
        let mut json = json!({
            "anyOf": [
                { "type": "string", "format": "date" },
                { "oneOf": [{ "type": "integer" }] }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "anyOf": [
                    { "type": "string" },
                    { "anyOf": [{ "type": "integer" }] }
                ]
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "$ref": "#/definitions/User",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "if": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "then": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "else": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    // An unsupported key inside an actual property schema is rejected as well.
    #[test]
    fn test_transform_fails_on_ref_in_property_schema() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "user": {
                    "$ref": "#/definitions/User"
                }
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": true
            })
        );
    }

    // An object schema without `properties` gets an empty `properties` map.
    #[test]
    fn test_preprocess_json_schema_adds_missing_properties() {
        let mut json = json!({
            "type": "object"
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {},
                "additionalProperties": false
            })
        );
    }

    // Schemas whose `type` is not "object" are left untouched.
    #[test]
    fn test_preprocess_json_schema_ignores_non_object_types() {
        let mut json = json!({
            "type": "string"
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "string"
            })
        );
    }

    // Only the top-level `$schema` and `title` are stripped; nested titles survive.
    #[test]
    fn test_adapt_schema_to_format_strips_schema_and_title() {
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Input",
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "title": "Name"
                }
            }
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchema).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string",
                        "title": "Name"
                    }
                },
                "additionalProperties": false
            })
        );

        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Input",
            "description": "A test field"
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchemaSubset)
            .unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "string"
            })
        );
    }
}