From af5a9fabc61285883e8a6a808afc149587cfb0a2 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Tue, 23 Apr 2024 20:49:29 -0700 Subject: [PATCH] Include root schema as parameters for tool calling (#10914) Allows `LanguageModelTool`s to include nested structures, by exposing the definitions section of their JSON Schema. Release Notes: - N/A --- Cargo.lock | 1 + crates/assistant2/Cargo.toml | 1 + .../examples/chat-with-functions.rs | 218 ++++++++++++++++++ crates/assistant_tooling/src/registry.rs | 3 +- crates/assistant_tooling/src/tool.rs | 23 +- 5 files changed, 241 insertions(+), 5 deletions(-) create mode 100644 crates/assistant2/examples/chat-with-functions.rs diff --git a/Cargo.lock b/Cargo.lock index 85bc6dfd0bb7e47559b3d040a0b0231c9dc35547..0caa43fc13500fa4501630314ad973629350cc47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -391,6 +391,7 @@ dependencies = [ "node_runtime", "open_ai", "project", + "rand 0.8.5", "release_channel", "rich_text", "schemars", diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index 060dbaa98b2054c527843e2c862ad5e96877179e..886a84c863b32e0ab704a0651b8a48866d6afee9 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -46,6 +46,7 @@ language = { workspace = true, features = ["test-support"] } languages.workspace = true node_runtime.workspace = true project = { workspace = true, features = ["test-support"] } +rand.workspace = true release_channel.workspace = true settings = { workspace = true, features = ["test-support"] } theme = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant2/examples/chat-with-functions.rs b/crates/assistant2/examples/chat-with-functions.rs new file mode 100644 index 0000000000000000000000000000000000000000..15d3c968a4e4285c89a79648ad51a7dc5b5c5931 --- /dev/null +++ b/crates/assistant2/examples/chat-with-functions.rs @@ -0,0 +1,218 @@ +use anyhow::Context as _; +use assets::Assets; +use assistant2::AssistantPanel; +use assistant_tooling::{LanguageModelTool, ToolRegistry}; +use client::Client; +use gpui::{actions, AnyElement, App, AppContext, KeyBinding, Task, View, WindowOptions}; +use language::LanguageRegistry; +use project::Project; +use rand::Rng; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{KeymapFile, DEFAULT_KEYMAP_PATH}; +use std::sync::Arc; +use theme::LoadThemes; +use ui::{div, prelude::*, Render}; +use util::ResultExt as _; + +actions!(example, [Quit]); + +struct RollDiceTool {} + +impl RollDiceTool { + fn new() -> Self { + Self {} + } +} + +#[derive(Serialize, Deserialize, JsonSchema, Clone)] +#[serde(rename_all = "snake_case")] +enum Die { + D6 = 6, + D20 = 20, +} + +impl Die { + fn into_str(&self) -> &'static str { + match self { + Die::D6 => "d6", + Die::D20 => "d20", + } + } +} + +#[derive(Serialize, Deserialize, JsonSchema, Clone)] +struct DiceParams { + /// The number of dice to roll. + num_dice: u8, + /// Which die to roll. Defaults to a d6 if not provided. + die_type: Option, +} + +#[derive(Serialize, Deserialize)] +struct DieRoll { + die: Die, + roll: u8, +} + +impl DieRoll { + fn render(&self) -> AnyElement { + match self.die { + Die::D6 => { + let face = match self.roll { + 6 => div().child("⚅"), + 5 => div().child("⚄"), + 4 => div().child("⚃"), + 3 => div().child("⚂"), + 2 => div().child("⚁"), + 1 => div().child("⚀"), + _ => div().child("😅"), + }; + face.text_3xl().into_any_element() + } + _ => div() + .child(format!("{}", self.roll)) + .text_3xl() + .into_any_element(), + } + } +} + +#[derive(Serialize, Deserialize)] +struct DiceRoll { + rolls: Vec, +} + +impl LanguageModelTool for RollDiceTool { + type Input = DiceParams; + type Output = DiceRoll; + + fn name(&self) -> String { + "roll_dice".to_string() + } + + fn description(&self) -> String { + "Rolls N many dice and returns the results.".to_string() + } + + fn execute(&self, input: &Self::Input, _cx: &AppContext) -> Task> { + let rolls = (0..input.num_dice) + .map(|_| { + let die_type = input.die_type.as_ref().unwrap_or(&Die::D6).clone(); + + DieRoll { + die: die_type.clone(), + roll: rand::thread_rng().gen_range(1..=die_type as u8), + } + }) + .collect(); + + return Task::ready(Ok(DiceRoll { rolls })); + } + + fn render( + _tool_call_id: &str, + _input: &Self::Input, + output: &Self::Output, + _cx: &mut WindowContext, + ) -> gpui::AnyElement { + h_flex() + .children( + output + .rolls + .iter() + .map(|roll| div().p_2().child(roll.render())), + ) + .into_any_element() + } + + fn format(_input: &Self::Input, output: &Self::Output) -> String { + let mut result = String::new(); + for roll in &output.rolls { + let die = &roll.die; + result.push_str(&format!("{}: {}\n", die.into_str(), roll.roll)); + } + result + } +} + +fn main() { + env_logger::init(); + App::new().with_assets(Assets).run(|cx| { + cx.bind_keys(Some(KeyBinding::new("cmd-q", Quit, None))); + cx.on_action(|_: &Quit, cx: &mut AppContext| { + cx.quit(); + }); + + settings::init(cx); + language::init(cx); + Project::init_settings(cx); + editor::init(cx); + theme::init(LoadThemes::JustBase, cx); + Assets.load_fonts(cx).unwrap(); + KeymapFile::load_asset(DEFAULT_KEYMAP_PATH, cx).unwrap(); + client::init_settings(cx); + release_channel::init("0.130.0", cx); + + let client = Client::production(cx); + { + let client = client.clone(); + cx.spawn(|cx| async move { client.authenticate_and_connect(false, &cx).await }) + .detach_and_log_err(cx); + } + assistant2::init(client.clone(), cx); + + let language_registry = Arc::new(LanguageRegistry::new( + Task::ready(()), + cx.background_executor().clone(), + )); + let node_runtime = node_runtime::RealNodeRuntime::new(client.http_client()); + languages::init(language_registry.clone(), node_runtime, cx); + + cx.spawn(|cx| async move { + cx.update(|cx| { + let mut tool_registry = ToolRegistry::new(); + tool_registry + .register(RollDiceTool::new()) + .context("failed to register DummyTool") + .log_err(); + + let tool_registry = Arc::new(tool_registry); + + println!("Tools registered"); + for definition in tool_registry.definitions() { + println!("{}", definition); + } + + cx.open_window(WindowOptions::default(), |cx| { + cx.new_view(|cx| Example::new(language_registry, tool_registry, cx)) + }); + cx.activate(true); + }) + }) + .detach_and_log_err(cx); + }) +} + +struct Example { + assistant_panel: View, +} + +impl Example { + fn new( + language_registry: Arc, + tool_registry: Arc, + cx: &mut ViewContext, + ) -> Self { + Self { + assistant_panel: cx + .new_view(|cx| AssistantPanel::new(language_registry, tool_registry, cx)), + } + } +} + +impl Render for Example { + fn render(&mut self, _cx: &mut ViewContext) -> impl ui::prelude::IntoElement { + div().size_full().child(self.assistant_panel.clone()) + } +} diff --git a/crates/assistant_tooling/src/registry.rs b/crates/assistant_tooling/src/registry.rs index 8c969c0d800b618ad9100f7f64b09b74e17ff022..ac5930cac403d06a4501d1979ed88e40da83fe23 100644 --- a/crates/assistant_tooling/src/registry.rs +++ b/crates/assistant_tooling/src/registry.rs @@ -256,7 +256,7 @@ mod test { let expected = ToolFunctionDefinition { name: "get_current_weather".to_string(), description: "Fetches the current weather for a given location.".to_string(), - parameters: schema_for!(WeatherQuery).schema, + parameters: schema_for!(WeatherQuery), }; assert_eq!(tools[0].name, expected.name); @@ -267,6 +267,7 @@ mod test { assert_eq!( expected_schema, json!({ + "$schema": "http://json-schema.org/draft-07/schema#", "title": "WeatherQuery", "type": "object", "properties": { diff --git a/crates/assistant_tooling/src/tool.rs b/crates/assistant_tooling/src/tool.rs index b63e2901c63caa68a04573026a1d44c2c8d9a287..a3b021a04e0a896a6dc9edfa4e312dad78826165 100644 --- a/crates/assistant_tooling/src/tool.rs +++ b/crates/assistant_tooling/src/tool.rs @@ -1,8 +1,11 @@ use anyhow::Result; use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext}; -use schemars::{schema::SchemaObject, schema_for, JsonSchema}; +use schemars::{schema::RootSchema, schema_for, JsonSchema}; use serde::Deserialize; -use std::{any::Any, fmt::Debug}; +use std::{ + any::Any, + fmt::{Debug, Display}, +}; #[derive(Default, Deserialize)] pub struct ToolFunctionCall { @@ -89,7 +92,17 @@ impl ToolFunctionCallResult { pub struct ToolFunctionDefinition { pub name: String, pub description: String, - pub parameters: SchemaObject, + pub parameters: RootSchema, +} + +impl Display for ToolFunctionDefinition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let schema = serde_json::to_string(&self.parameters).ok(); + let schema = schema.unwrap_or("None".to_string()); + write!(f, "Name: {}:\n", self.name)?; + write!(f, "Description: {}\n", self.description)?; + write!(f, "Parameters: {}", schema) + } } impl Debug for ToolFunctionDefinition { @@ -124,10 +137,12 @@ pub trait LanguageModelTool { /// The OpenAI Function definition for the tool, for direct use with OpenAI's API. fn definition(&self) -> ToolFunctionDefinition { + let root_schema = schema_for!(Self::Input); + ToolFunctionDefinition { name: self.name(), description: self.description(), - parameters: schema_for!(Self::Input).schema, + parameters: root_schema, } }