1use anyhow::Result;
2use serde_json::Value;
3
4use crate::LanguageModelToolSchemaFormat;
5
6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
7///
8/// If the json cannot be made compatible with the specified format, an error is returned.
9pub fn adapt_schema_to_format(
10 json: &mut Value,
11 format: LanguageModelToolSchemaFormat,
12) -> Result<()> {
13 if let Value::Object(obj) = json {
14 obj.remove("$schema");
15 obj.remove("title");
16 }
17
18 match format {
19 LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
20 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
21 }
22}
23
24fn preprocess_json_schema(json: &mut Value) -> Result<()> {
25 // `additionalProperties` defaults to `false` unless explicitly specified.
26 // This prevents models from hallucinating tool parameters.
27 if let Value::Object(obj) = json
28 && matches!(obj.get("type"), Some(Value::String(s)) if s == "object") {
29 if !obj.contains_key("additionalProperties") {
30 obj.insert("additionalProperties".to_string(), Value::Bool(false));
31 }
32
33 // OpenAI API requires non-missing `properties`
34 if !obj.contains_key("properties") {
35 obj.insert("properties".to_string(), Value::Object(Default::default()));
36 }
37 }
38 Ok(())
39}
40
41/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
42fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
43 if let Value::Object(obj) = json {
44 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
45
46 for key in UNSUPPORTED_KEYS {
47 anyhow::ensure!(
48 !obj.contains_key(key),
49 "Schema cannot be made compatible because it contains \"{key}\""
50 );
51 }
52
53 const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
54 ("format", |value| value.is_string()),
55 ("additionalProperties", |value| value.is_boolean()),
56 ("exclusiveMinimum", |value| value.is_number()),
57 ("exclusiveMaximum", |value| value.is_number()),
58 ("optional", |value| value.is_boolean()),
59 ];
60 for (key, predicate) in KEYS_TO_REMOVE {
61 if let Some(value) = obj.get(key)
62 && predicate(value) {
63 obj.remove(key);
64 }
65 }
66
67 // If a type is not specified for an input parameter, add a default type
68 if matches!(obj.get("description"), Some(Value::String(_)))
69 && !obj.contains_key("type")
70 && !(obj.contains_key("anyOf")
71 || obj.contains_key("oneOf")
72 || obj.contains_key("allOf"))
73 {
74 obj.insert("type".to_string(), Value::String("string".to_string()));
75 }
76
77 // Handle oneOf -> anyOf conversion
78 if let Some(subschemas) = obj.get_mut("oneOf")
79 && subschemas.is_array() {
80 let subschemas_clone = subschemas.clone();
81 obj.remove("oneOf");
82 obj.insert("anyOf".to_string(), subschemas_clone);
83 }
84
85 // Recursively process all nested objects and arrays
86 for (_, value) in obj.iter_mut() {
87 if let Value::Object(_) | Value::Array(_) = value {
88 adapt_to_json_schema_subset(value)?;
89 }
90 }
91 } else if let Value::Array(arr) = json {
92 for item in arr.iter_mut() {
93 adapt_to_json_schema_subset(item)?;
94 }
95 }
96 Ok(())
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs `adapt_to_json_schema_subset` on `schema` and returns the rewritten value.
    ///
    /// Panics (at the caller's location) if the adaptation fails.
    #[track_caller]
    fn adapted(mut schema: Value) -> Value {
        adapt_to_json_schema_subset(&mut schema).unwrap();
        schema
    }

    /// Runs `preprocess_json_schema` on `schema` and returns the rewritten value.
    ///
    /// Panics (at the caller's location) if preprocessing fails.
    #[track_caller]
    fn preprocessed(mut schema: Value) -> Value {
        preprocess_json_schema(&mut schema).unwrap();
        schema
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        assert_eq!(
            adapted(json!({
                "description": "A test field without type"
            })),
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // Ensure that we do not add a type if it is an object
        assert_eq!(
            adapted(json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })),
            json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })
        );
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        assert_eq!(
            adapted(json!({
                "description": "A test field",
                "type": "integer",
                "format": "uint32",
                "exclusiveMinimum": 0,
                "exclusiveMaximum": 100,
                "additionalProperties": false,
                "optional": true
            })),
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // Ensure that we do not remove keys that are actually supported (e.g. "format" can
        // just be used as another property)
        assert_eq!(
            adapted(json!({
                "description": "A test field",
                "type": "integer",
                "format": {},
            })),
            json!({
                "description": "A test field",
                "type": "integer",
                "format": {},
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        assert_eq!(
            adapted(json!({
                "description": "A test field",
                "oneOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })),
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        assert_eq!(
            adapted(json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "oneOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ],
                        "format": "email"
                    }
                }
            })),
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let rejected = [
            json!({
                "type": "object",
                "properties": {
                    "$ref": "#/definitions/User",
                }
            }),
            json!({
                "type": "object",
                "properties": {
                    "if": "...",
                }
            }),
            json!({
                "type": "object",
                "properties": {
                    "then": "...",
                }
            }),
            json!({
                "type": "object",
                "properties": {
                    "else": "...",
                }
            }),
        ];

        for mut schema in rejected {
            assert!(
                adapt_to_json_schema_subset(&mut schema).is_err(),
                "expected schema to be rejected: {schema}"
            );
        }
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        assert_eq!(
            preprocessed(json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            })),
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        assert_eq!(
            preprocessed(json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": true
            })),
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": true
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_adds_missing_properties() {
        // An object schema without `properties` gets an empty one alongside the
        // `additionalProperties` default.
        assert_eq!(
            preprocessed(json!({
                "type": "object"
            })),
            json!({
                "type": "object",
                "properties": {},
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_ignores_non_object_types() {
        // Schemas whose `type` is not "object" are left untouched.
        assert_eq!(
            preprocessed(json!({
                "type": "string"
            })),
            json!({
                "type": "string"
            })
        );

        // Non-object JSON values are left untouched as well.
        assert_eq!(preprocessed(json!("just a string")), json!("just a string"));
    }

    #[test]
    fn test_adapt_schema_to_format_strips_top_level_metadata() {
        // Only the top-level `$schema` / `title` are removed; a property that happens to be
        // named `title` is kept.
        let mut schema = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Example",
            "type": "object",
            "properties": {
                "title": {
                    "type": "string"
                }
            }
        });

        adapt_schema_to_format(&mut schema, LanguageModelToolSchemaFormat::JsonSchema).unwrap();

        assert_eq!(
            schema,
            json!({
                "type": "object",
                "properties": {
                    "title": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );

        let mut schema = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Example",
            "type": "string",
            "format": "email"
        });

        adapt_schema_to_format(&mut schema, LanguageModelToolSchemaFormat::JsonSchemaSubset)
            .unwrap();

        assert_eq!(
            schema,
            json!({
                "type": "string"
            })
        );
    }
}