1use anyhow::Result;
2use language_model::LanguageModelToolSchemaFormat;
3use schemars::{
4 JsonSchema,
5 schema::{RootSchema, Schema, SchemaObject},
6};
7
8pub fn json_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> serde_json::Value {
9 let schema = root_schema_for::<T>(format);
10 schema_to_json(&schema, format).expect("Failed to convert tool calling schema to JSON")
11}
12
13pub fn schema_to_json(
14 schema: &RootSchema,
15 format: LanguageModelToolSchemaFormat,
16) -> Result<serde_json::Value> {
17 let mut value = serde_json::to_value(schema)?;
18 match format {
19 LanguageModelToolSchemaFormat::JsonSchema => Ok(value),
20 LanguageModelToolSchemaFormat::JsonSchemaSubset => {
21 transform_fields_to_json_schema_subset(&mut value);
22 Ok(value)
23 }
24 }
25}
26
27fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> RootSchema {
28 let mut generator = match format {
29 LanguageModelToolSchemaFormat::JsonSchema => schemars::SchemaGenerator::default(),
30 LanguageModelToolSchemaFormat::JsonSchemaSubset => {
31 schemars::r#gen::SchemaSettings::default()
32 .with(|settings| {
33 settings.meta_schema = None;
34 settings.inline_subschemas = true;
35 settings
36 .visitors
37 .push(Box::new(TransformToJsonSchemaSubsetVisitor));
38 })
39 .into_generator()
40 }
41 };
42 generator.root_schema_for::<T>()
43}
44
45#[derive(Debug, Clone)]
46struct TransformToJsonSchemaSubsetVisitor;
47
48impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor {
49 fn visit_root_schema(&mut self, root: &mut RootSchema) {
50 schemars::visit::visit_root_schema(self, root)
51 }
52
53 fn visit_schema(&mut self, schema: &mut Schema) {
54 schemars::visit::visit_schema(self, schema)
55 }
56
57 fn visit_schema_object(&mut self, schema: &mut SchemaObject) {
58 // Ensure that the type field is not an array, this happens when we use
59 // Option<T>, the type will be [T, "null"].
60 if let Some(instance_type) = schema.instance_type.take() {
61 schema.instance_type = match instance_type {
62 schemars::schema::SingleOrVec::Single(t) => {
63 Some(schemars::schema::SingleOrVec::Single(t))
64 }
65 schemars::schema::SingleOrVec::Vec(items) => items
66 .into_iter()
67 .next()
68 .map(schemars::schema::SingleOrVec::from),
69 };
70 }
71
72 // One of is not supported, use anyOf instead.
73 if let Some(subschema) = schema.subschemas.as_mut() {
74 if let Some(one_of) = subschema.one_of.take() {
75 subschema.any_of = Some(one_of);
76 }
77 }
78
79 schemars::visit::visit_schema_object(self, schema)
80 }
81}
82
83fn transform_fields_to_json_schema_subset(json: &mut serde_json::Value) {
84 if let serde_json::Value::Object(obj) = json {
85 if let Some(default) = obj.get("default") {
86 let is_null = default.is_null();
87 //Default is not supported, so we need to remove it.
88 obj.remove("default");
89 if is_null {
90 obj.insert("nullable".to_string(), serde_json::Value::Bool(true));
91 }
92 }
93
94 // If a type is not specified for an input parameter we need to add it.
95 if obj.contains_key("description")
96 && !obj.contains_key("type")
97 && !(obj.contains_key("anyOf")
98 || obj.contains_key("oneOf")
99 || obj.contains_key("allOf"))
100 {
101 obj.insert(
102 "type".to_string(),
103 serde_json::Value::String("string".to_string()),
104 );
105 }
106
107 //Format field is only partially supported (e.g. not uint compatibility)
108 obj.remove("format");
109
110 for (_, value) in obj.iter_mut() {
111 if let serde_json::Value::Object(_) | serde_json::Value::Array(_) = value {
112 transform_fields_to_json_schema_subset(value);
113 }
114 }
115 } else if let serde_json::Value::Array(arr) = json {
116 for item in arr.iter_mut() {
117 transform_fields_to_json_schema_subset(item);
118 }
119 }
120}