1use anyhow::Result;
2use serde_json::Value;
3
4use crate::LanguageModelToolSchemaFormat;
5
6/// Tries to adapt a JSON schema representation to be compatible with the specified format.
7///
8/// If the json cannot be made compatible with the specified format, an error is returned.
9pub fn adapt_schema_to_format(
10 json: &mut Value,
11 format: LanguageModelToolSchemaFormat,
12) -> Result<()> {
13 if let Value::Object(obj) = json {
14 obj.remove("$schema");
15 obj.remove("title");
16 }
17
18 match format {
19 LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
20 LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
21 }
22}
23
24/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
25fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
26 if let Value::Object(obj) = json {
27 const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
28
29 for key in UNSUPPORTED_KEYS {
30 anyhow::ensure!(
31 !obj.contains_key(key),
32 "Schema cannot be made compatible because it contains \"{key}\""
33 );
34 }
35
36 const KEYS_TO_REMOVE: [&str; 5] = [
37 "format",
38 "additionalProperties",
39 "exclusiveMinimum",
40 "exclusiveMaximum",
41 "optional",
42 ];
43 for key in KEYS_TO_REMOVE {
44 obj.remove(key);
45 }
46
47 // If a type is not specified for an input parameter, add a default type
48 if matches!(obj.get("description"), Some(Value::String(_)))
49 && !obj.contains_key("type")
50 && !(obj.contains_key("anyOf")
51 || obj.contains_key("oneOf")
52 || obj.contains_key("allOf"))
53 {
54 obj.insert("type".to_string(), Value::String("string".to_string()));
55 }
56
57 // Handle oneOf -> anyOf conversion
58 if let Some(subschemas) = obj.get_mut("oneOf") {
59 if subschemas.is_array() {
60 let subschemas_clone = subschemas.clone();
61 obj.remove("oneOf");
62 obj.insert("anyOf".to_string(), subschemas_clone);
63 }
64 }
65
66 // Recursively process all nested objects and arrays
67 for (_, value) in obj.iter_mut() {
68 if let Value::Object(_) | Value::Array(_) = value {
69 adapt_to_json_schema_subset(value)?;
70 }
71 }
72 } else if let Value::Array(arr) = json {
73 for item in arr.iter_mut() {
74 adapt_to_json_schema_subset(item)?;
75 }
76 }
77 Ok(())
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs the subset adaptation on `schema`, panicking on failure, and returns the result.
    fn adapted(mut schema: Value) -> Value {
        adapt_to_json_schema_subset(&mut schema).unwrap();
        schema
    }

    #[test]
    fn test_transform_adds_type_when_missing() {
        assert_eq!(
            adapted(json!({ "description": "A test field without type" })),
            json!({
                "description": "A test field without type",
                "type": "string"
            })
        );

        // A non-string `description` (here an object) must not trigger the default type.
        let object_description = json!({
            "description": {
                "value": "abc",
                "type": "string"
            }
        });
        assert_eq!(adapted(object_description.clone()), object_description);
    }

    #[test]
    fn test_transform_removes_unsupported_keys() {
        let schema = json!({
            "description": "A test field",
            "type": "integer",
            "format": "uint32",
            "exclusiveMinimum": 0,
            "exclusiveMaximum": 100,
            "additionalProperties": false,
            "optional": true
        });
        let expected = json!({
            "description": "A test field",
            "type": "integer"
        });

        assert_eq!(adapted(schema), expected);
    }

    #[test]
    fn test_transform_one_of_to_any_of() {
        let schema = json!({
            "description": "A test field",
            "oneOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });
        let expected = json!({
            "description": "A test field",
            "anyOf": [
                { "type": "string" },
                { "type": "integer" }
            ]
        });

        assert_eq!(adapted(schema), expected);
    }

    #[test]
    fn test_transform_nested_objects() {
        let schema = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "oneOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ],
                    "format": "email"
                }
            }
        });
        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "anyOf": [
                        { "type": "string" },
                        { "type": "null" }
                    ]
                }
            }
        });

        assert_eq!(adapted(schema), expected);
    }

    #[test]
    fn test_transform_fails_if_unsupported_keys_exist() {
        let cases = [
            ("$ref", "#/definitions/User"),
            ("if", "..."),
            ("then", "..."),
            ("else", "..."),
        ];

        for (key, value) in cases {
            let mut schema = json!({
                "type": "object",
                "properties": { key: value }
            });
            assert!(
                adapt_to_json_schema_subset(&mut schema).is_err(),
                "expected an error for \"{key}\""
            );
        }
    }
}