schema.rs

  1use anyhow::Result;
  2use language_model::LanguageModelToolSchemaFormat;
  3use schemars::{
  4    JsonSchema,
  5    schema::{RootSchema, Schema, SchemaObject},
  6};
  7
  8pub fn json_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> serde_json::Value {
  9    let schema = root_schema_for::<T>(format);
 10    schema_to_json(&schema, format).expect("Failed to convert tool calling schema to JSON")
 11}
 12
 13pub fn schema_to_json(
 14    schema: &RootSchema,
 15    format: LanguageModelToolSchemaFormat,
 16) -> Result<serde_json::Value> {
 17    let mut value = serde_json::to_value(schema)?;
 18    match format {
 19        LanguageModelToolSchemaFormat::JsonSchema => Ok(value),
 20        LanguageModelToolSchemaFormat::JsonSchemaSubset => {
 21            transform_fields_to_json_schema_subset(&mut value);
 22            Ok(value)
 23        }
 24    }
 25}
 26
 27fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> RootSchema {
 28    let mut generator = match format {
 29        LanguageModelToolSchemaFormat::JsonSchema => schemars::SchemaGenerator::default(),
 30        LanguageModelToolSchemaFormat::JsonSchemaSubset => {
 31            schemars::r#gen::SchemaSettings::default()
 32                .with(|settings| {
 33                    settings.meta_schema = None;
 34                    settings.inline_subschemas = true;
 35                    settings
 36                        .visitors
 37                        .push(Box::new(TransformToJsonSchemaSubsetVisitor));
 38                })
 39                .into_generator()
 40        }
 41    };
 42    generator.root_schema_for::<T>()
 43}
 44
 45#[derive(Debug, Clone)]
 46struct TransformToJsonSchemaSubsetVisitor;
 47
 48impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor {
 49    fn visit_root_schema(&mut self, root: &mut RootSchema) {
 50        schemars::visit::visit_root_schema(self, root)
 51    }
 52
 53    fn visit_schema(&mut self, schema: &mut Schema) {
 54        schemars::visit::visit_schema(self, schema)
 55    }
 56
 57    fn visit_schema_object(&mut self, schema: &mut SchemaObject) {
 58        // Ensure that the type field is not an array, this happens when we use
 59        // Option<T>, the type will be [T, "null"].
 60        if let Some(instance_type) = schema.instance_type.take() {
 61            schema.instance_type = match instance_type {
 62                schemars::schema::SingleOrVec::Single(t) => {
 63                    Some(schemars::schema::SingleOrVec::Single(t))
 64                }
 65                schemars::schema::SingleOrVec::Vec(items) => items
 66                    .into_iter()
 67                    .next()
 68                    .map(schemars::schema::SingleOrVec::from),
 69            };
 70        }
 71
 72        // One of is not supported, use anyOf instead.
 73        if let Some(subschema) = schema.subschemas.as_mut() {
 74            if let Some(one_of) = subschema.one_of.take() {
 75                subschema.any_of = Some(one_of);
 76            }
 77        }
 78
 79        schemars::visit::visit_schema_object(self, schema)
 80    }
 81}
 82
 83fn transform_fields_to_json_schema_subset(json: &mut serde_json::Value) {
 84    if let serde_json::Value::Object(obj) = json {
 85        if let Some(default) = obj.get("default") {
 86            let is_null = default.is_null();
 87            //Default is not supported, so we need to remove it.
 88            obj.remove("default");
 89            if is_null {
 90                obj.insert("nullable".to_string(), serde_json::Value::Bool(true));
 91            }
 92        }
 93
 94        // If a type is not specified for an input parameter we need to add it.
 95        if obj.contains_key("description")
 96            && !obj.contains_key("type")
 97            && !(obj.contains_key("anyOf")
 98                || obj.contains_key("oneOf")
 99                || obj.contains_key("allOf"))
100        {
101            obj.insert(
102                "type".to_string(),
103                serde_json::Value::String("string".to_string()),
104            );
105        }
106
107        //Format field is only partially supported (e.g. not uint compatibility)
108        obj.remove("format");
109
110        for (_, value) in obj.iter_mut() {
111            if let serde_json::Value::Object(_) | serde_json::Value::Array(_) = value {
112                transform_fields_to_json_schema_subset(value);
113            }
114        }
115    } else if let serde_json::Value::Array(arr) = json {
116        for item in arr.iter_mut() {
117            transform_fields_to_json_schema_subset(item);
118        }
119    }
120}