1use anyhow::Result;
2use language_model::LanguageModelToolSchemaFormat;
3use schemars::{
4 JsonSchema, Schema,
5 generate::SchemaSettings,
6 transform::{Transform, transform_subschemas},
7};
8use serde_json::Value;
9
10pub(crate) fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
11 let mut generator = match format {
12 LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
13 LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
14 .with(|settings| {
15 settings.meta_schema = None;
16 settings.inline_subschemas = true;
17 })
18 .with_transform(ToJsonSchemaSubsetTransform)
19 .into_generator(),
20 };
21 generator.root_schema_for::<T>()
22}
23
24#[derive(Debug, Clone)]
25struct ToJsonSchemaSubsetTransform;
26
27impl Transform for ToJsonSchemaSubsetTransform {
28 fn transform(&mut self, schema: &mut Schema) {
29 // Ensure that the type field is not an array, this happens when we use
30 // Option<T>, the type will be [T, "null"].
31 if let Some(type_field) = schema.get_mut("type")
32 && let Some(types) = type_field.as_array()
33 && let Some(first_type) = types.first()
34 {
35 *type_field = first_type.clone();
36 }
37
38 // oneOf is not supported, use anyOf instead
39 if let Some(one_of) = schema.remove("oneOf") {
40 schema.insert("anyOf".to_string(), one_of);
41 }
42
43 transform_subschemas(self, schema);
44 }
45}
46
47/// Tries to adapt a JSON schema representation to be compatible with the specified format.
48///
49/// If the json cannot be made compatible with the specified format, an error is returned.
50pub fn adapt_schema_to_format(
51 json: &mut Value,
52 format: LanguageModelToolSchemaFormat,
53) -> Result<()> {
54 if let Value::Object(obj) = json {
55 obj.remove("$schema");
56 obj.remove("title");
57 }
58
59 match format {
60 LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
61 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
62 }
63}
64
65fn preprocess_json_schema(json: &mut Value) -> Result<()> {
66 // `additionalProperties` defaults to `false` unless explicitly specified.
67 // This prevents models from hallucinating tool parameters.
68 if let Value::Object(obj) = json
69 && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
70 {
71 if !obj.contains_key("additionalProperties") {
72 obj.insert("additionalProperties".to_string(), Value::Bool(false));
73 }
74
75 // OpenAI API requires non-missing `properties`
76 if !obj.contains_key("properties") {
77 obj.insert("properties".to_string(), Value::Object(Default::default()));
78 }
79 }
80 Ok(())
81}
82
83/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
84fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
85 if let Value::Object(obj) = json {
86 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
87
88 for key in UNSUPPORTED_KEYS {
89 anyhow::ensure!(
90 !obj.contains_key(key),
91 "Schema cannot be made compatible because it contains \"{key}\""
92 );
93 }
94
95 const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
96 ("format", |value| value.is_string()),
97 ("additionalProperties", |value| value.is_boolean()),
98 ("exclusiveMinimum", |value| value.is_number()),
99 ("exclusiveMaximum", |value| value.is_number()),
100 ("optional", |value| value.is_boolean()),
101 ];
102 for (key, predicate) in KEYS_TO_REMOVE {
103 if let Some(value) = obj.get(key)
104 && predicate(value)
105 {
106 obj.remove(key);
107 }
108 }
109
110 // If a type is not specified for an input parameter, add a default type
111 if matches!(obj.get("description"), Some(Value::String(_)))
112 && !obj.contains_key("type")
113 && !(obj.contains_key("anyOf")
114 || obj.contains_key("oneOf")
115 || obj.contains_key("allOf"))
116 {
117 obj.insert("type".to_string(), Value::String("string".to_string()));
118 }
119
120 // Handle oneOf -> anyOf conversion
121 if let Some(subschemas) = obj.get_mut("oneOf")
122 && subschemas.is_array()
123 {
124 let subschemas_clone = subschemas.clone();
125 obj.remove("oneOf");
126 obj.insert("anyOf".to_string(), subschemas_clone);
127 }
128
129 // Recursively process all nested objects and arrays
130 for (_, value) in obj.iter_mut() {
131 if let Value::Object(_) | Value::Array(_) = value {
132 adapt_to_json_schema_subset(value)?;
133 }
134 }
135 } else if let Value::Array(arr) = json {
136 for item in arr.iter_mut() {
137 adapt_to_json_schema_subset(item)?;
138 }
139 }
140 Ok(())
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    /// Adapts `input` to the JSON-schema subset, asserting success and that the result
    /// equals `expected`.
    fn assert_subset_adaptation(mut input: Value, expected: Value) {
        adapt_to_json_schema_subset(&mut input).unwrap();
        assert_eq!(input, expected);
    }

    /// Preprocesses `input` as a full JSON Schema, asserting success and that the result
    /// equals `expected`.
    fn assert_preprocessed(mut input: Value, expected: Value) {
        preprocess_json_schema(&mut input).unwrap();
        assert_eq!(input, expected);
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        assert_subset_adaptation(
            json!({ "description": "A test field without type" }),
            json!({
                "description": "A test field without type",
                "type": "string"
            }),
        );

        // Ensure that we do not add a type if the description is an object
        let unchanged = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });
        assert_subset_adaptation(unchanged.clone(), unchanged);
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        assert_subset_adaptation(
            json!({
                "description": "A test field",
                "type": "integer",
                "format": "uint32",
                "exclusiveMinimum": 0,
                "exclusiveMaximum": 100,
                "additionalProperties": false,
                "optional": true
            }),
            json!({
                "description": "A test field",
                "type": "integer"
            }),
        );

        // Keys that are actually supported must survive (e.g. "format" can just be used as
        // another property, in which case its value is an object).
        let unchanged = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });
        assert_subset_adaptation(unchanged.clone(), unchanged);
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        assert_subset_adaptation(
            json!({
                "description": "A test field",
                "oneOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            }),
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            }),
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        assert_subset_adaptation(
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "oneOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ],
                        "format": "email"
                    }
                }
            }),
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            }),
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let failing_schemas = [
            json!({
                "type": "object",
                "properties": { "$ref": "#/definitions/User" }
            }),
            json!({
                "type": "object",
                "properties": { "if": "..." }
            }),
            json!({
                "type": "object",
                "properties": { "then": "..." }
            }),
            json!({
                "type": "object",
                "properties": { "else": "..." }
            }),
        ];

        for mut schema in failing_schemas {
            assert!(adapt_to_json_schema_subset(&mut schema).is_err());
        }
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        assert_preprocessed(
            json!({
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                }
            }),
            json!({
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                },
                "additionalProperties": false
            }),
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let unchanged = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "additionalProperties": true
        });
        assert_preprocessed(unchanged.clone(), unchanged);
    }
}