1use anyhow::Result;
2use serde_json::Value;
3
4use crate::LanguageModelToolSchemaFormat;
5
6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
7///
8/// If the json cannot be made compatible with the specified format, an error is returned.
9pub fn adapt_schema_to_format(
10 json: &mut Value,
11 format: LanguageModelToolSchemaFormat,
12) -> Result<()> {
13 if let Value::Object(obj) = json {
14 obj.remove("$schema");
15 obj.remove("title");
16 }
17
18 match format {
19 LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
20 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
21 }
22}
23
24/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
25fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
26 if let Value::Object(obj) = json {
27 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
28
29 for key in UNSUPPORTED_KEYS {
30 if obj.contains_key(key) {
31 return Err(anyhow::anyhow!(
32 "Schema cannot be made compatible because it contains \"{}\" ",
33 key
34 ));
35 }
36 }
37
38 const KEYS_TO_REMOVE: [&str; 4] = [
39 "format",
40 "additionalProperties",
41 "exclusiveMinimum",
42 "exclusiveMaximum",
43 ];
44 for key in KEYS_TO_REMOVE {
45 obj.remove(key);
46 }
47
48 if let Some(default) = obj.get("default") {
49 let is_null = default.is_null();
50 // Default is not supported, so we need to remove it
51 obj.remove("default");
52 if is_null {
53 obj.insert("nullable".to_string(), Value::Bool(true));
54 }
55 }
56
57 // If a type is not specified for an input parameter, add a default type
58 if matches!(obj.get("description"), Some(Value::String(_)))
59 && !obj.contains_key("type")
60 && !(obj.contains_key("anyOf")
61 || obj.contains_key("oneOf")
62 || obj.contains_key("allOf"))
63 {
64 obj.insert("type".to_string(), Value::String("string".to_string()));
65 }
66
67 // Handle oneOf -> anyOf conversion
68 if let Some(subschemas) = obj.get_mut("oneOf") {
69 if subschemas.is_array() {
70 let subschemas_clone = subschemas.clone();
71 obj.remove("oneOf");
72 obj.insert("anyOf".to_string(), subschemas_clone);
73 }
74 }
75
76 // Recursively process all nested objects and arrays
77 for (_, value) in obj.iter_mut() {
78 if let Value::Object(_) | Value::Array(_) = value {
79 adapt_to_json_schema_subset(value)?;
80 }
81 }
82 } else if let Value::Array(arr) = json {
83 for item in arr.iter_mut() {
84 adapt_to_json_schema_subset(item)?;
85 }
86 }
87 Ok(())
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_transform_default_null_to_nullable() {
        let mut json = json!({
            "description": "A test field",
            "type": "string",
            "default": null
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "string",
                "nullable": true
            })
        );
    }

    #[test]
    fn test_transform_removes_non_null_default_without_nullable() {
        let mut json = json!({
            "type": "integer",
            "default": 5
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(json, json!({ "type": "integer" }));
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        let mut json = json!({
            "description": "A test field without type"
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // Ensure that we do not add a type if it is an object
        let mut json = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": {
                    "value": "abc",
                    "type": "string"
                }
            })
        );
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let mut json = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "type": "integer"
            })
        );
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let mut json = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "description": "A test field",
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_nested_objects() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "null" }
                        ]
                    }
                }
            })
        );
    }

    #[test]
    fn test_transform_recurses_into_arrays() {
        let mut json = json!({
            "anyOf": [
                { "type": "string", "format": "date" },
                { "type": "integer", "exclusiveMinimum": 0 }
            ]
        });

        adapt_to_json_schema_subset(&mut json).unwrap();

        assert_eq!(
            json,
            json!({
                "anyOf": [
                    { "type": "string" },
                    { "type": "integer" }
                ]
            })
        );
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let mut json = json!({
            "type": "object",
            "properties": {
                "$ref": "#/definitions/User",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "if": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "then": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        let mut json = json!({
            "type": "object",
            "properties": {
                "else": "...",
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    #[test]
    fn test_transform_fails_if_unsupported_keyword_in_schema() {
        // `$ref` used as a keyword on a nested property schema.
        let mut json = json!({
            "type": "object",
            "properties": {
                "user": { "$ref": "#/definitions/User" }
            }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());

        // `if`/`then` used as keywords on the root schema.
        let mut json = json!({
            "type": "object",
            "if": { "properties": { "kind": { "const": "a" } } },
            "then": { "required": ["a"] }
        });

        assert!(adapt_to_json_schema_subset(&mut json).is_err());
    }

    #[test]
    fn test_adapt_schema_strips_top_level_metadata_only() {
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Params",
            "type": "object",
            "properties": {
                "title": { "type": "string" }
            }
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchema).unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "title": { "type": "string" }
                }
            })
        );
    }

    #[test]
    fn test_adapt_schema_applies_subset_transform() {
        let mut json = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "type": "object",
            "properties": {
                "value": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "integer" }
                    ]
                }
            }
        });

        adapt_schema_to_format(&mut json, LanguageModelToolSchemaFormat::JsonSchemaSubset)
            .unwrap();

        assert_eq!(
            json,
            json!({
                "type": "object",
                "properties": {
                    "value": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "integer" }
                        ]
                    }
                }
            })
        );
    }
}