1use anyhow::Result;
2use serde_json::Value;
3
4use crate::LanguageModelToolSchemaFormat;
5
6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
7///
8/// If the json cannot be made compatible with the specified format, an error is returned.
9pub fn adapt_schema_to_format(
10 json: &mut Value,
11 format: LanguageModelToolSchemaFormat,
12) -> Result<()> {
13 if let Value::Object(obj) = json {
14 obj.remove("$schema");
15 obj.remove("title");
16 }
17
18 match format {
19 LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
20 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
21 }
22}
23
24/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
25fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
26 if let Value::Object(obj) = json {
27 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
28
29 for key in UNSUPPORTED_KEYS {
30 if obj.contains_key(key) {
31 return Err(anyhow::anyhow!(
32 "Schema cannot be made compatible because it contains \"{}\" ",
33 key
34 ));
35 }
36 }
37
38 const KEYS_TO_REMOVE: [&str; 5] = [
39 "format",
40 "additionalProperties",
41 "exclusiveMinimum",
42 "exclusiveMaximum",
43 "optional",
44 ];
45 for key in KEYS_TO_REMOVE {
46 obj.remove(key);
47 }
48
49 // If a type is not specified for an input parameter, add a default type
50 if matches!(obj.get("description"), Some(Value::String(_)))
51 && !obj.contains_key("type")
52 && !(obj.contains_key("anyOf")
53 || obj.contains_key("oneOf")
54 || obj.contains_key("allOf"))
55 {
56 obj.insert("type".to_string(), Value::String("string".to_string()));
57 }
58
59 // Handle oneOf -> anyOf conversion
60 if let Some(subschemas) = obj.get_mut("oneOf") {
61 if subschemas.is_array() {
62 let subschemas_clone = subschemas.clone();
63 obj.remove("oneOf");
64 obj.insert("anyOf".to_string(), subschemas_clone);
65 }
66 }
67
68 // Recursively process all nested objects and arrays
69 for (_, value) in obj.iter_mut() {
70 if let Value::Object(_) | Value::Array(_) = value {
71 adapt_to_json_schema_subset(value)?;
72 }
73 }
74 } else if let Value::Array(arr) = json {
75 for item in arr.iter_mut() {
76 adapt_to_json_schema_subset(item)?;
77 }
78 }
79 Ok(())
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Runs the subset adaptation on `schema`, panicking if it fails, and returns the result.
    fn adapted(mut schema: Value) -> Value {
        adapt_to_json_schema_subset(&mut schema).unwrap();
        schema
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        let expected = json!({
            "description": "A test field without type",
            "type": "string"
        });
        assert_eq!(
            adapted(json!({ "description": "A test field without type" })),
            expected
        );

        // A `description` that is not a string (here an object) must not trigger the default
        // type, so the schema comes back unchanged.
        let object_description = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });
        assert_eq!(adapted(object_description.clone()), object_description);
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let input = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });
        let expected = json!({
            "description": "A test field",
            "type": "integer"
        });

        assert_eq!(adapted(input), expected);
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let variants = json!([
            { "type": "string" },
            { "type": "integer" }
        ]);
        let input = json!({
            "description": "A test field",
            "oneOf": variants.clone()
        });
        let expected = json!({
            "description": "A test field",
            "anyOf": variants
        });

        assert_eq!(adapted(input), expected);
    }

    #[test]
    fn test_transform_nested_objects() {
        let variants = json!([
            { "type": "string" },
            { "type": "null" }
        ]);
        let input = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": variants.clone(),
                    "format": "email"
                }
            }
        });
        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "anyOf": variants
                }
            }
        });

        assert_eq!(adapted(input), expected);
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let cases = [
            ("$ref", "#/definitions/User"),
            ("if", "..."),
            ("then", "..."),
            ("else", "..."),
        ];

        for (key, value) in cases {
            let mut schema = json!({
                "type": "object",
                "properties": {
                    key: value,
                }
            });
            assert!(
                adapt_to_json_schema_subset(&mut schema).is_err(),
                "expected a schema containing {:?} to be rejected",
                key
            );
        }
    }
}