1use anyhow::Result;
2use serde_json::Value;
3
4use crate::LanguageModelToolSchemaFormat;
5
6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
7///
8/// If the json cannot be made compatible with the specified format, an error is returned.
9pub fn adapt_schema_to_format(
10 json: &mut Value,
11 format: LanguageModelToolSchemaFormat,
12) -> Result<()> {
13 if let Value::Object(obj) = json {
14 obj.remove("$schema");
15 obj.remove("title");
16 }
17
18 match format {
19 LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
20 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
21 }
22}
23
24/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
25fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
26 if let Value::Object(obj) = json {
27 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
28
29 for key in UNSUPPORTED_KEYS {
30 if obj.contains_key(key) {
31 return Err(anyhow::anyhow!(
32 "Schema cannot be made compatible because it contains \"{}\" ",
33 key
34 ));
35 }
36 }
37
38 obj.remove("format");
39
40 if let Some(default) = obj.get("default") {
41 let is_null = default.is_null();
42 // Default is not supported, so we need to remove it
43 obj.remove("default");
44 if is_null {
45 obj.insert("nullable".to_string(), Value::Bool(true));
46 }
47 }
48
49 // If a type is not specified for an input parameter, add a default type
50 if obj.contains_key("description")
51 && !obj.contains_key("type")
52 && !(obj.contains_key("anyOf")
53 || obj.contains_key("oneOf")
54 || obj.contains_key("allOf"))
55 {
56 obj.insert("type".to_string(), Value::String("string".to_string()));
57 }
58
59 // Handle oneOf -> anyOf conversion
60 if let Some(subschemas) = obj.get_mut("oneOf") {
61 if subschemas.is_array() {
62 let subschemas_clone = subschemas.clone();
63 obj.remove("oneOf");
64 obj.insert("anyOf".to_string(), subschemas_clone);
65 }
66 }
67
68 // Recursively process all nested objects and arrays
69 for (_, value) in obj.iter_mut() {
70 if let Value::Object(_) | Value::Array(_) = value {
71 adapt_to_json_schema_subset(value)?;
72 }
73 }
74 } else if let Value::Array(arr) = json {
75 for item in arr.iter_mut() {
76 adapt_to_json_schema_subset(item)?;
77 }
78 }
79 Ok(())
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs the subset adaptation on `schema` and returns the rewritten value.
    ///
    /// Panics if the schema is rejected.
    fn adapted(mut schema: serde_json::Value) -> serde_json::Value {
        adapt_to_json_schema_subset(&mut schema).unwrap();
        schema
    }

    /// A `null` default is replaced by `nullable: true`.
    #[test]
    fn test_transform_default_null_to_nullable() {
        let actual = adapted(json!({
            "description": "A test field",
            "type": "string",
            "default": null
        }));

        let expected = json!({
            "description": "A test field",
            "type": "string",
            "nullable": true
        });

        assert_eq!(actual, expected);
    }

    /// A described field without a type is given the `string` type.
    #[test]
    fn test_transform_adds_type_when_missing() {
        let actual = adapted(json!({
            "description": "A test field without type"
        }));

        let expected = json!({
            "description": "A test field without type",
            "type": "string"
        });

        assert_eq!(actual, expected);
    }

    /// The `format` keyword is stripped.
    #[test]
    fn test_transform_removes_format() {
        let actual = adapted(json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32"
        }));

        let expected = json!({
            "description": "A test field",
            "type": "integer"
        });

        assert_eq!(actual, expected);
    }

    /// `oneOf` is renamed to `anyOf`, and no default type is added alongside it.
    #[test]
    fn test_transform_one_of_to_any_of() {
        let actual = adapted(json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        }));

        let expected = json!({
            "description": "A test field",
            "anyOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        assert_eq!(actual, expected);
    }

    /// The rewrite rules also apply to schemas nested under `properties`.
    #[test]
    fn test_transform_nested_objects() {
        let actual = adapted(json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        }));

        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "anyOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ]
                }
            }
        });

        assert_eq!(actual, expected);
    }

    /// Each unsupported key causes the adaptation to fail, even when nested.
    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let rejected_entries = [
            ("$ref", "#/definitions/User"),
            ("if", "..."),
            ("then", "..."),
            ("else", "..."),
        ];

        for (key, value) in rejected_entries {
            let mut schema = json!({
                "type": "object",
                "properties": {
                    key: value
                }
            });

            assert!(adapt_to_json_schema_subset(&mut schema).is_err());
        }
    }
}