1use anyhow::Result;
2use serde_json::Value;
3
4use crate::LanguageModelToolSchemaFormat;
5
6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
7///
8/// If the json cannot be made compatible with the specified format, an error is returned.
9pub fn adapt_schema_to_format(
10 json: &mut Value,
11 format: LanguageModelToolSchemaFormat,
12) -> Result<()> {
13 match format {
14 LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
15 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
16 }
17}
18
19/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
20fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
21 if let Value::Object(obj) = json {
22 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
23
24 for key in UNSUPPORTED_KEYS {
25 if obj.contains_key(key) {
26 return Err(anyhow::anyhow!(
27 "Schema cannot be made compatible because it contains \"{}\" ",
28 key
29 ));
30 }
31 }
32
33 const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
34 for key in KEYS_TO_REMOVE {
35 obj.remove(key);
36 }
37
38 if let Some(default) = obj.get("default") {
39 let is_null = default.is_null();
40 // Default is not supported, so we need to remove it
41 obj.remove("default");
42 if is_null {
43 obj.insert("nullable".to_string(), Value::Bool(true));
44 }
45 }
46
47 // If a type is not specified for an input parameter, add a default type
48 if obj.contains_key("description")
49 && !obj.contains_key("type")
50 && !(obj.contains_key("anyOf")
51 || obj.contains_key("oneOf")
52 || obj.contains_key("allOf"))
53 {
54 obj.insert("type".to_string(), Value::String("string".to_string()));
55 }
56
57 // Handle oneOf -> anyOf conversion
58 if let Some(subschemas) = obj.get_mut("oneOf") {
59 if subschemas.is_array() {
60 let subschemas_clone = subschemas.clone();
61 obj.remove("oneOf");
62 obj.insert("anyOf".to_string(), subschemas_clone);
63 }
64 }
65
66 // Recursively process all nested objects and arrays
67 for (_, value) in obj.iter_mut() {
68 if let Value::Object(_) | Value::Array(_) = value {
69 adapt_to_json_schema_subset(value)?;
70 }
71 }
72 } else if let Value::Array(arr) = json {
73 for item in arr.iter_mut() {
74 adapt_to_json_schema_subset(item)?;
75 }
76 }
77 Ok(())
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs the subset adaptation on `schema` and hands back the rewritten value,
    /// panicking if the adaptation reports an error.
    fn adapted(mut schema: Value) -> Value {
        adapt_to_json_schema_subset(&mut schema).unwrap();
        schema
    }

    /// Wraps `properties` in an object schema and reports whether the adaptation
    /// refuses it.
    fn is_rejected(properties: Value) -> bool {
        let mut schema = json!({
            "type": "object",
            "properties": properties
        });
        adapt_to_json_schema_subset(&mut schema).is_err()
    }

    #[test]
    fn test_transform_default_null_to_nullable() {
        let actual = adapted(json!({
            "description": "A test field",
            "type": "string",
            "default": null
        }));

        let expected = json!({
            "description": "A test field",
            "type": "string",
            "nullable": true
        });
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        let actual = adapted(json!({
            "description": "A test field without type"
        }));

        let expected = json!({
            "description": "A test field without type",
            "type": "string"
        });
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_transform_removes_format() {
        let actual = adapted(json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32"
        }));

        let expected = json!({
            "description": "A test field",
            "type": "integer"
        });
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let actual = adapted(json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        }));

        let expected = json!({
            "description": "A test field",
            "anyOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_transform_nested_objects() {
        let actual = adapted(json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        }));

        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "anyOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ]
                }
            }
        });
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        // Each unsupported key is checked on its own so a failure pinpoints the key.
        assert!(is_rejected(json!({ "$ref": "#/definitions/User" })));
        assert!(is_rejected(json!({ "if": "..." })));
        assert!(is_rejected(json!({ "then": "..." })));
        assert!(is_rejected(json!({ "else": "..." })));
    }
}