1use anyhow::Result;
2use schemars::{
3 JsonSchema, Schema,
4 generate::SchemaSettings,
5 transform::{Transform, transform_subschemas},
6};
7use serde_json::Value;
8
/// Selects which schema dialect a language model tool's input schema is expressed in.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// Full JSON Schema, see <https://json-schema.org>.
    JsonSchema,
    /// The subset of the OpenAPI 3.0 schema object that Google AI supports,
    /// see <https://ai.google.dev/api/caching#Schema>.
    JsonSchemaSubset,
}
17
18pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
19 let mut generator = match format {
20 LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07()
21 .with(|settings| {
22 settings.meta_schema = None;
23 settings.inline_subschemas = true;
24 })
25 .into_generator(),
26 LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
27 .with(|settings| {
28 settings.meta_schema = None;
29 settings.inline_subschemas = true;
30 })
31 .with_transform(ToJsonSchemaSubsetTransform)
32 .into_generator(),
33 };
34 generator.root_schema_for::<T>()
35}
36
37#[derive(Debug, Clone)]
38struct ToJsonSchemaSubsetTransform;
39
40impl Transform for ToJsonSchemaSubsetTransform {
41 fn transform(&mut self, schema: &mut Schema) {
42 // Ensure that the type field is not an array, this happens when we use
43 // Option<T>, the type will be [T, "null"].
44 if let Some(type_field) = schema.get_mut("type")
45 && let Some(types) = type_field.as_array()
46 && let Some(first_type) = types.first()
47 {
48 *type_field = first_type.clone();
49 }
50
51 // oneOf is not supported, use anyOf instead
52 if let Some(one_of) = schema.remove("oneOf") {
53 schema.insert("anyOf".to_string(), one_of);
54 }
55
56 transform_subschemas(self, schema);
57 }
58}
59
60/// Tries to adapt a JSON schema representation to be compatible with the specified format.
61///
62/// If the json cannot be made compatible with the specified format, an error is returned.
63pub fn adapt_schema_to_format(
64 json: &mut Value,
65 format: LanguageModelToolSchemaFormat,
66) -> Result<()> {
67 if let Value::Object(obj) = json {
68 obj.remove("$schema");
69 obj.remove("title");
70 obj.remove("description");
71 }
72
73 match format {
74 LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
75 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
76 }
77}
78
79fn preprocess_json_schema(json: &mut Value) -> Result<()> {
80 if let Value::Object(obj) = json
81 && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
82 {
83 if !obj.contains_key("additionalProperties") {
84 obj.insert("additionalProperties".to_string(), Value::Bool(false));
85 }
86
87 if !obj.contains_key("properties") {
88 obj.insert("properties".to_string(), Value::Object(Default::default()));
89 }
90 }
91 Ok(())
92}
93
94fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
95 if let Value::Object(obj) = json {
96 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
97
98 for key in UNSUPPORTED_KEYS {
99 anyhow::ensure!(
100 !obj.contains_key(key),
101 "Schema cannot be made compatible because it contains \"{key}\""
102 );
103 }
104
105 const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [
106 ("format", |value| value.is_string()),
107 ("additionalProperties", |_| true),
108 ("propertyNames", |_| true),
109 ("exclusiveMinimum", |value| value.is_number()),
110 ("exclusiveMaximum", |value| value.is_number()),
111 ("optional", |value| value.is_boolean()),
112 ];
113 for (key, predicate) in KEYS_TO_REMOVE {
114 if let Some(value) = obj.get(key)
115 && predicate(value)
116 {
117 obj.remove(key);
118 }
119 }
120
121 if matches!(obj.get("description"), Some(Value::String(_)))
122 && !obj.contains_key("type")
123 && !(obj.contains_key("anyOf")
124 || obj.contains_key("oneOf")
125 || obj.contains_key("allOf"))
126 {
127 obj.insert("type".to_string(), Value::String("string".to_string()));
128 }
129
130 if let Some(subschemas) = obj.get_mut("oneOf")
131 && subschemas.is_array()
132 {
133 let subschemas_clone = subschemas.clone();
134 obj.remove("oneOf");
135 obj.insert("anyOf".to_string(), subschemas_clone);
136 }
137
138 for (_, value) in obj.iter_mut() {
139 if let Value::Object(_) | Value::Array(_) = value {
140 adapt_to_json_schema_subset(value)?;
141 }
142 }
143 } else if let Value::Array(arr) = json {
144 for item in arr.iter_mut() {
145 adapt_to_json_schema_subset(item)?;
146 }
147 }
148 Ok(())
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Applies `adapt_to_json_schema_subset` to `schema`, panicking on error.
    fn adapted_subset(mut schema: Value) -> Value {
        adapt_to_json_schema_subset(&mut schema).unwrap();
        schema
    }

    /// Applies `preprocess_json_schema` to `schema`, panicking on error.
    fn preprocessed(mut schema: Value) -> Value {
        preprocess_json_schema(&mut schema).unwrap();
        schema
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        // A string description without a type gets `"type": "string"`.
        let adapted = adapted_subset(json!({
            "description": "A test field without type"
        }));
        let expected = json!({
            "description": "A test field without type",
            "type": "string"
        });
        assert_eq!(adapted, expected);

        // A non-string "description" (e.g. a property map entry) is left alone.
        let untouched = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });
        assert_eq!(adapted_subset(untouched.clone()), untouched);
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        // Keywords with the expected value shape are stripped.
        let adapted = adapted_subset(json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        }));
        let expected = json!({
            "description": "A test field",
            "type": "integer"
        });
        assert_eq!(adapted, expected);

        // A "format" whose value is not a string is preserved.
        let adapted = adapted_subset(json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        }));
        let expected = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });
        assert_eq!(adapted, expected);

        // Object-valued `additionalProperties` and `propertyNames` are removed too.
        let adapted = adapted_subset(json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "additionalProperties": { "type": "string" },
            "propertyNames": { "pattern": "^[A-Za-z]+$" }
        }));
        let expected = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        assert_eq!(adapted, expected);
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let adapted = adapted_subset(json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        }));
        let expected = json!({
            "description": "A test field",
            "anyOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });
        assert_eq!(adapted, expected);
    }

    #[test]
    fn test_transform_nested_objects() {
        let adapted = adapted_subset(json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        }));
        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "anyOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ]
                }
            }
        });
        assert_eq!(adapted, expected);
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let rejected_schemas = [
            json!({
                "type": "object",
                "properties": {
                    "$ref": "#/definitions/User",
                }
            }),
            json!({
                "type": "object",
                "properties": {
                    "if": "...",
                }
            }),
            json!({
                "type": "object",
                "properties": {
                    "then": "...",
                }
            }),
            json!({
                "type": "object",
                "properties": {
                    "else": "...",
                }
            }),
        ];

        for mut schema in rejected_schemas {
            let outcome = adapt_to_json_schema_subset(&mut schema);
            assert!(outcome.is_err(), "expected an error for schema {schema}");
        }
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let processed = preprocessed(json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        }));
        let expected = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": false
        });
        assert_eq!(processed, expected);
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let processed = preprocessed(json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        }));
        let expected = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });
        assert_eq!(processed, expected);
    }
}