1use anyhow::Result;
2use schemars::{
3 JsonSchema, Schema,
4 generate::SchemaSettings,
5 transform::{Transform, transform_subschemas},
6};
7use serde_json::Value;
8
/// The schema dialect in which the input schema of a language model tool is expressed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// Regular JSON Schema, see <https://json-schema.org>.
    JsonSchema,
    /// The subset of an OpenAPI 3.0 schema object that Google AI supports, see
    /// <https://ai.google.dev/api/caching#Schema>.
    JsonSchemaSubset,
}
17
18pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
19 let mut generator = match format {
20 LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07()
21 .with(|settings| {
22 settings.meta_schema = None;
23 settings.inline_subschemas = true;
24 })
25 .into_generator(),
26 LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
27 .with(|settings| {
28 settings.meta_schema = None;
29 settings.inline_subschemas = true;
30 })
31 .with_transform(ToJsonSchemaSubsetTransform)
32 .into_generator(),
33 };
34 generator.root_schema_for::<T>()
35}
36
37#[derive(Debug, Clone)]
38struct ToJsonSchemaSubsetTransform;
39
40impl Transform for ToJsonSchemaSubsetTransform {
41 fn transform(&mut self, schema: &mut Schema) {
42 // Ensure that the type field is not an array, this happens when we use
43 // Option<T>, the type will be [T, "null"].
44 if let Some(type_field) = schema.get_mut("type")
45 && let Some(types) = type_field.as_array()
46 && let Some(first_type) = types.first()
47 {
48 *type_field = first_type.clone();
49 }
50
51 // oneOf is not supported, use anyOf instead
52 if let Some(one_of) = schema.remove("oneOf") {
53 schema.insert("anyOf".to_string(), one_of);
54 }
55
56 transform_subschemas(self, schema);
57 }
58}
59
60/// Tries to adapt a JSON schema representation to be compatible with the specified format.
61///
62/// If the json cannot be made compatible with the specified format, an error is returned.
63pub fn adapt_schema_to_format(
64 json: &mut Value,
65 format: LanguageModelToolSchemaFormat,
66) -> Result<()> {
67 if let Value::Object(obj) = json {
68 obj.remove("$schema");
69 obj.remove("title");
70 obj.remove("description");
71 }
72
73 match format {
74 LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
75 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
76 }
77}
78
79fn preprocess_json_schema(json: &mut Value) -> Result<()> {
80 if let Value::Object(obj) = json
81 && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
82 {
83 if !obj.contains_key("additionalProperties") {
84 obj.insert("additionalProperties".to_string(), Value::Bool(false));
85 }
86
87 if !obj.contains_key("properties") {
88 obj.insert("properties".to_string(), Value::Object(Default::default()));
89 }
90 }
91 Ok(())
92}
93
94fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
95 if let Value::Object(obj) = json {
96 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
97
98 for key in UNSUPPORTED_KEYS {
99 anyhow::ensure!(
100 !obj.contains_key(key),
101 "Schema cannot be made compatible because it contains \"{key}\""
102 );
103 }
104
105 const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [
106 ("format", |value| value.is_string()),
107 ("additionalProperties", |_| true),
108 ("propertyNames", |_| true),
109 ("exclusiveMinimum", |value| value.is_number()),
110 ("exclusiveMaximum", |value| value.is_number()),
111 ("optional", |value| value.is_boolean()),
112 ];
113 for (key, predicate) in KEYS_TO_REMOVE {
114 if let Some(value) = obj.get(key)
115 && predicate(value)
116 {
117 obj.remove(key);
118 }
119 }
120
121 // Ensure that the type field is not an array. This can happen with MCP tool
122 // schemas that use multiple types (e.g. `["string", "number"]` or `["string", "null"]`).
123 if let Some(type_value) = obj.get_mut("type")
124 && let Some(types) = type_value.as_array()
125 && let Some(first_type) = types.first().cloned()
126 {
127 *type_value = first_type;
128 }
129
130 if matches!(obj.get("description"), Some(Value::String(_)))
131 && !obj.contains_key("type")
132 && !(obj.contains_key("anyOf")
133 || obj.contains_key("oneOf")
134 || obj.contains_key("allOf"))
135 {
136 obj.insert("type".to_string(), Value::String("string".to_string()));
137 }
138
139 if let Some(subschemas) = obj.get_mut("oneOf")
140 && subschemas.is_array()
141 {
142 let subschemas_clone = subschemas.clone();
143 obj.remove("oneOf");
144 obj.insert("anyOf".to_string(), subschemas_clone);
145 }
146
147 for (_, value) in obj.iter_mut() {
148 if let Value::Object(_) | Value::Array(_) = value {
149 adapt_to_json_schema_subset(value)?;
150 }
151 }
152 } else if let Value::Array(arr) = json {
153 for item in arr.iter_mut() {
154 adapt_to_json_schema_subset(item)?;
155 }
156 }
157 Ok(())
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs the subset adaptation on `schema` and returns the adapted value,
    /// panicking if the adaptation reports an error.
    fn adapted(mut schema: Value) -> Value {
        adapt_to_json_schema_subset(&mut schema).unwrap();
        schema
    }

    /// Runs the JSON schema preprocessing on `schema` and returns the result,
    /// panicking if preprocessing reports an error.
    fn preprocessed(mut schema: Value) -> Value {
        preprocess_json_schema(&mut schema).unwrap();
        schema
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        // A string description without a type gets the default "string" type.
        let described = json!({
            "description": "A test field without type"
        });
        assert_eq!(
            adapted(described),
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // A non-string `description` entry does not trigger the default type.
        let untouched = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });
        assert_eq!(adapted(untouched.clone()), untouched);
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let with_removable_keys = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });
        assert_eq!(
            adapted(with_removable_keys),
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // `format` is only removed when it is a string.
        let with_object_format = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });
        assert_eq!(adapted(with_object_format.clone()), with_object_format);

        // Schema-valued `additionalProperties` and `propertyNames` are removed too.
        let with_schema_valued_keys = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "additionalProperties": { "type": "string" },
            "propertyNames": { "pattern": "^[A-Za-z]+$" }
        });
        assert_eq!(
            adapted(with_schema_valued_keys),
            json!({
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                }
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let with_one_of = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });
        assert_eq!(
            adapted(with_one_of),
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        let nested = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });
        assert_eq!(
            adapted(nested),
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_array_type_to_single_type() {
        let with_type_arrays = json!({
            "type": "object",
            "properties": {
                "projectSlugOrId": {
                    "type": ["string", "number"],
                    "description": "Project slug or numeric ID"
                },
                "optionalName": {
                    "type": ["string", "null"],
                    "description": "An optional name"
                }
            }
        });
        assert_eq!(
            adapted(with_type_arrays),
            json!({
                "type": "object",
                "properties": {
                    "projectSlugOrId": {
                        "type": "string",
                        "description": "Project slug or numeric ID"
                    },
                    "optionalName": {
                        "type": "string",
                        "description": "An optional name"
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let rejected_entries = [
            ("$ref", "#/definitions/User"),
            ("if", "..."),
            ("then", "..."),
            ("else", "..."),
        ];

        for (keyword, value) in rejected_entries {
            let mut schema = json!({
                "type": "object",
                "properties": {
                    (keyword): value,
                }
            });

            assert!(
                adapt_to_json_schema_subset(&mut schema).is_err(),
                "expected a schema containing `{keyword}` to be rejected"
            );
        }
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let without_additional_properties = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });
        assert_eq!(
            preprocessed(without_additional_properties),
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let with_additional_properties = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });
        assert_eq!(
            preprocessed(with_additional_properties.clone()),
            with_additional_properties
        );
    }
}