1use anyhow::Result;
2use serde_json::Value;
3
4use crate::LanguageModelToolSchemaFormat;
5
6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
7///
8/// If the json cannot be made compatible with the specified format, an error is returned.
9pub fn adapt_schema_to_format(
10 json: &mut Value,
11 format: LanguageModelToolSchemaFormat,
12) -> Result<()> {
13 if let Value::Object(obj) = json {
14 obj.remove("$schema");
15 obj.remove("title");
16 }
17
18 match format {
19 LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
20 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
21 }
22}
23
24fn preprocess_json_schema(json: &mut Value) -> Result<()> {
25 // `additionalProperties` defaults to `false` unless explicitly specified.
26 // This prevents models from hallucinating tool parameters.
27 if let Value::Object(obj) = json
28 && matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
29 {
30 if !obj.contains_key("additionalProperties") {
31 obj.insert("additionalProperties".to_string(), Value::Bool(false));
32 }
33
34 // OpenAI API requires non-missing `properties`
35 if !obj.contains_key("properties") {
36 obj.insert("properties".to_string(), Value::Object(Default::default()));
37 }
38 }
39 Ok(())
40}
41
42/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
43fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
44 if let Value::Object(obj) = json {
45 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
46
47 for key in UNSUPPORTED_KEYS {
48 anyhow::ensure!(
49 !obj.contains_key(key),
50 "Schema cannot be made compatible because it contains \"{key}\""
51 );
52 }
53
54 const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
55 ("format", |value| value.is_string()),
56 ("additionalProperties", |value| value.is_boolean()),
57 ("exclusiveMinimum", |value| value.is_number()),
58 ("exclusiveMaximum", |value| value.is_number()),
59 ("optional", |value| value.is_boolean()),
60 ];
61 for (key, predicate) in KEYS_TO_REMOVE {
62 if let Some(value) = obj.get(key)
63 && predicate(value)
64 {
65 obj.remove(key);
66 }
67 }
68
69 // If a type is not specified for an input parameter, add a default type
70 if matches!(obj.get("description"), Some(Value::String(_)))
71 && !obj.contains_key("type")
72 && !(obj.contains_key("anyOf")
73 || obj.contains_key("oneOf")
74 || obj.contains_key("allOf"))
75 {
76 obj.insert("type".to_string(), Value::String("string".to_string()));
77 }
78
79 // Handle oneOf -> anyOf conversion
80 if let Some(subschemas) = obj.get_mut("oneOf")
81 && subschemas.is_array()
82 {
83 let subschemas_clone = subschemas.clone();
84 obj.remove("oneOf");
85 obj.insert("anyOf".to_string(), subschemas_clone);
86 }
87
88 // Recursively process all nested objects and arrays
89 for (_, value) in obj.iter_mut() {
90 if let Value::Object(_) | Value::Array(_) = value {
91 adapt_to_json_schema_subset(value)?;
92 }
93 }
94 } else if let Value::Array(arr) = json {
95 for item in arr.iter_mut() {
96 adapt_to_json_schema_subset(item)?;
97 }
98 }
99 Ok(())
100}
101
/// Unit tests for the schema adaptation passes defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_transform_adds_type_when_missing() {
        let mut json = json!({
            "description": "A test field without type"
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // Ensure that we do not add a type if it is an object
        let mut json = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })
        );
    }

    /// A described schema that already carries `anyOf` / `allOf` must not get a default type.
    #[test]
    fn test_transform_does_not_add_type_next_to_combinators() {
        let mut json = json!({
            "description": "A test field",
            "anyOf": [
                { "type": "string" },
                { "type": "null" }
            ]
        });
        let expected = json.clone();

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(json, expected);

        let mut json = json!({
            "description": "A test field",
            "allOf": [
                { "type": "string" }
            ]
        });
        let expected = json.clone();

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(json, expected);
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer",
                "format": {},
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let mut json = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    /// Only an array-valued `oneOf` is renamed; an object-valued one (such as a property
    /// that happens to be called `oneOf`) stays under its original key.
    #[test]
    fn test_transform_keeps_non_array_one_of() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "oneOf": { "type": "string" }
            }
        });
        let expected = json.clone();

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(json, expected);
    }

    #[test]
    fn test_transform_nested_objects() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    /// Array elements are adapted too, and an unsupported key inside an array is an error.
    #[test]
    fn test_transform_recurses_into_arrays() {
        let mut json = json!({
            "anyOf": [
                { "type": "string", "format": "email" },
                { "type": "integer", "exclusiveMinimum": 0 }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );

        let mut json = json!({
            "anyOf": [
                { "$ref": "#/definitions/User" }
            ]
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "$ref": "#/definitions/User",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "if": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "then": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "else": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": true
            })
        );
    }

    /// An object schema without `properties` gets an empty `properties` map.
    #[test]
    fn test_preprocess_json_schema_adds_properties_when_missing() {
        let mut json = json!({
            "type": "object"
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {},
                "additionalProperties": false
            })
        );
    }

    /// Schemas whose `type` is not the string `"object"` are left untouched.
    #[test]
    fn test_preprocess_json_schema_ignores_non_object_schemas() {
        let mut json = json!({
            "type": "string"
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(json, json!({ "type": "string" }));

        let mut json = json!({
            "properties": {}
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(json, json!({ "properties": {} }));
    }

    /// The public entry point strips top-level `$schema` / `title` (but not nested ones)
    /// before running the format-specific pass.
    #[test]
    fn test_adapt_schema_to_format_strips_top_level_metadata() {
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Tool",
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "title": "Name"
                }
            }
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchema).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string",
                        "title": "Name"
                    }
                },
                "additionalProperties": false
            })
        );

        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Tool",
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "format": "email"
                }
            }
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchemaSubset)
            .unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            })
        );
    }
}