1use anyhow::Result;
2use serde_json::Value;
3
4use crate::LanguageModelToolSchemaFormat;
5
6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
7///
8/// If the json cannot be made compatible with the specified format, an error is returned.
9pub fn adapt_schema_to_format(
10 json: &mut Value,
11 format: LanguageModelToolSchemaFormat,
12) -> Result<()> {
13 if let Value::Object(obj) = json {
14 obj.remove("$schema");
15 obj.remove("title");
16 }
17
18 match format {
19 LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
20 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
21 }
22}
23
24fn preprocess_json_schema(json: &mut Value) -> Result<()> {
25 // `additionalProperties` defaults to `false` unless explicitly specified.
26 // This prevents models from hallucinating tool parameters.
27 if let Value::Object(obj) = json {
28 if let Some(Value::String(type_str)) = obj.get("type") {
29 if type_str == "object" && !obj.contains_key("additionalProperties") {
30 obj.insert("additionalProperties".to_string(), Value::Bool(false));
31 }
32 }
33 }
34 Ok(())
35}
36
37/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
38fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
39 if let Value::Object(obj) = json {
40 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
41
42 for key in UNSUPPORTED_KEYS {
43 anyhow::ensure!(
44 !obj.contains_key(key),
45 "Schema cannot be made compatible because it contains \"{key}\""
46 );
47 }
48
49 const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 5] = [
50 ("format", |value| value.is_string()),
51 ("additionalProperties", |value| value.is_boolean()),
52 ("exclusiveMinimum", |value| value.is_number()),
53 ("exclusiveMaximum", |value| value.is_number()),
54 ("optional", |value| value.is_boolean()),
55 ];
56 for (key, predicate) in KEYS_TO_REMOVE {
57 if let Some(value) = obj.get(key) {
58 if predicate(value) {
59 obj.remove(key);
60 }
61 }
62 }
63
64 // If a type is not specified for an input parameter, add a default type
65 if matches!(obj.get("description"), Some(Value::String(_)))
66 && !obj.contains_key("type")
67 && !(obj.contains_key("anyOf")
68 || obj.contains_key("oneOf")
69 || obj.contains_key("allOf"))
70 {
71 obj.insert("type".to_string(), Value::String("string".to_string()));
72 }
73
74 // Handle oneOf -> anyOf conversion
75 if let Some(subschemas) = obj.get_mut("oneOf") {
76 if subschemas.is_array() {
77 let subschemas_clone = subschemas.clone();
78 obj.remove("oneOf");
79 obj.insert("anyOf".to_string(), subschemas_clone);
80 }
81 }
82
83 // Recursively process all nested objects and arrays
84 for (_, value) in obj.iter_mut() {
85 if let Value::Object(_) | Value::Array(_) = value {
86 adapt_to_json_schema_subset(value)?;
87 }
88 }
89 } else if let Value::Array(arr) = json {
90 for item in arr.iter_mut() {
91 adapt_to_json_schema_subset(item)?;
92 }
93 }
94 Ok(())
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_adapt_schema_to_format_json_schema_strips_metadata() {
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Params",
            "type": "object",
            "properties": {
                "path": { "type": "string" }
            }
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchema).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_adapt_schema_to_format_subset_strips_only_root_metadata() {
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Params",
            "type": "object",
            "additionalProperties": false,
            "properties": {
                "mode": {
                    "title": "Mode",
                    "oneOf": [
                        { "type": "string" },
                        { "type": "integer" }
                    ]
                }
            }
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchemaSubset)
            .unwrap();

        // The nested `title` is kept: only the root-level metadata is stripped.
        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "mode": {
                        "title": "Mode",
                        "anyOf": [
                            { "type": "string" },
                            { "type": "integer" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        let mut json = json!({
            "description": "A test field without type"
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // Ensure that we do not add a type if it is an object
        let mut json = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })
        );
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );

        // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": {},
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer",
                "format": {},
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let mut json = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_recurses_into_arrays() {
        let mut json = json!({
            "type": "array",
            "items": {
                "anyOf": [
                    {
                        "oneOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    },
                    { "type": "integer", "format": "int64" }
                ]
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "array",
                "items": {
                    "anyOf": [
                        {
                            "anyOf": [
                                { "type": "string" },
                                { "type": "null" }
                            ]
                        },
                        { "type": "integer" }
                    ]
                }
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        for key in ["$ref", "if", "then", "else"] {
            // Unsupported key nested inside another object.
            let mut json = json!({
                "type": "object",
                "properties": {
                    key: "...",
                }
            });
            let error = adapt_to_json_schema_subset(&mut json)
                .expect_err("a nested unsupported key must be rejected");
            assert!(
                error.to_string().contains(key),
                "error for {key:?} should name the key, got: {error}"
            );

            // Unsupported key at the root of the schema.
            let mut json = json!({ key: "..." });
            assert!(
                adapt_to_json_schema_subset(&mut json).is_err(),
                "a root-level {key:?} must be rejected"
            );
        }
    }

    #[test]
    fn test_preprocess_json_schema_adds_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": false
            })
        );
    }

    #[test]
    fn test_preprocess_json_schema_preserves_additional_properties() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            },
            "additionalProperties": true
        });

        preprocess_json_schema(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                },
                "additionalProperties": true
            })
        );
    }
}