Cargo.lock 🔗
@@ -391,6 +391,7 @@ dependencies = [
"node_runtime",
"open_ai",
"project",
+ "rand 0.8.5",
"release_channel",
"rich_text",
"schemars",
Kyle Kelley created
Allows `LanguageModelTool`s to include nested structures, by exposing
the definitions section of their JSON Schema.
Release Notes:
- N/A
Cargo.lock | 1
crates/assistant2/Cargo.toml | 1
crates/assistant2/examples/chat-with-functions.rs | 218 +++++++++++++++++
crates/assistant_tooling/src/registry.rs | 3
crates/assistant_tooling/src/tool.rs | 23 +
5 files changed, 241 insertions(+), 5 deletions(-)
@@ -391,6 +391,7 @@ dependencies = [
"node_runtime",
"open_ai",
"project",
+ "rand 0.8.5",
"release_channel",
"rich_text",
"schemars",
@@ -46,6 +46,7 @@ language = { workspace = true, features = ["test-support"] }
languages.workspace = true
node_runtime.workspace = true
project = { workspace = true, features = ["test-support"] }
+rand.workspace = true
release_channel.workspace = true
settings = { workspace = true, features = ["test-support"] }
theme = { workspace = true, features = ["test-support"] }
@@ -0,0 +1,218 @@
+use anyhow::Context as _;
+use assets::Assets;
+use assistant2::AssistantPanel;
+use assistant_tooling::{LanguageModelTool, ToolRegistry};
+use client::Client;
+use gpui::{actions, AnyElement, App, AppContext, KeyBinding, Task, View, WindowOptions};
+use language::LanguageRegistry;
+use project::Project;
+use rand::Rng;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::{KeymapFile, DEFAULT_KEYMAP_PATH};
+use std::sync::Arc;
+use theme::LoadThemes;
+use ui::{div, prelude::*, Render};
+use util::ResultExt as _;
+
+actions!(example, [Quit]);
+
+struct RollDiceTool {}
+
+impl RollDiceTool {
+ fn new() -> Self {
+ Self {}
+ }
+}
+
+#[derive(Serialize, Deserialize, JsonSchema, Clone)]
+#[serde(rename_all = "snake_case")]
+enum Die {
+ D6 = 6,
+ D20 = 20,
+}
+
+impl Die {
+ fn into_str(&self) -> &'static str {
+ match self {
+ Die::D6 => "d6",
+ Die::D20 => "d20",
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize, JsonSchema, Clone)]
+struct DiceParams {
+ /// The number of dice to roll.
+ num_dice: u8,
+ /// Which die to roll. Defaults to a d6 if not provided.
+ die_type: Option<Die>,
+}
+
+#[derive(Serialize, Deserialize)]
+struct DieRoll {
+ die: Die,
+ roll: u8,
+}
+
+impl DieRoll {
+ fn render(&self) -> AnyElement {
+ match self.die {
+ Die::D6 => {
+ let face = match self.roll {
+ 6 => div().child("⚅"),
+ 5 => div().child("⚄"),
+ 4 => div().child("⚃"),
+ 3 => div().child("⚂"),
+ 2 => div().child("⚁"),
+ 1 => div().child("⚀"),
+ _ => div().child("😅"),
+ };
+ face.text_3xl().into_any_element()
+ }
+ _ => div()
+ .child(format!("{}", self.roll))
+ .text_3xl()
+ .into_any_element(),
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize)]
+struct DiceRoll {
+ rolls: Vec<DieRoll>,
+}
+
+impl LanguageModelTool for RollDiceTool {
+ type Input = DiceParams;
+ type Output = DiceRoll;
+
+ fn name(&self) -> String {
+ "roll_dice".to_string()
+ }
+
+ fn description(&self) -> String {
+ "Rolls N many dice and returns the results.".to_string()
+ }
+
+ fn execute(&self, input: &Self::Input, _cx: &AppContext) -> Task<gpui::Result<Self::Output>> {
+ let rolls = (0..input.num_dice)
+ .map(|_| {
+ let die_type = input.die_type.as_ref().unwrap_or(&Die::D6).clone();
+
+ DieRoll {
+ die: die_type.clone(),
+ roll: rand::thread_rng().gen_range(1..=die_type as u8),
+ }
+ })
+ .collect();
+
+ return Task::ready(Ok(DiceRoll { rolls }));
+ }
+
+ fn render(
+ _tool_call_id: &str,
+ _input: &Self::Input,
+ output: &Self::Output,
+ _cx: &mut WindowContext,
+ ) -> gpui::AnyElement {
+ h_flex()
+ .children(
+ output
+ .rolls
+ .iter()
+ .map(|roll| div().p_2().child(roll.render())),
+ )
+ .into_any_element()
+ }
+
+ fn format(_input: &Self::Input, output: &Self::Output) -> String {
+ let mut result = String::new();
+ for roll in &output.rolls {
+ let die = &roll.die;
+ result.push_str(&format!("{}: {}\n", die.into_str(), roll.roll));
+ }
+ result
+ }
+}
+
+fn main() {
+ env_logger::init();
+ App::new().with_assets(Assets).run(|cx| {
+ cx.bind_keys(Some(KeyBinding::new("cmd-q", Quit, None)));
+ cx.on_action(|_: &Quit, cx: &mut AppContext| {
+ cx.quit();
+ });
+
+ settings::init(cx);
+ language::init(cx);
+ Project::init_settings(cx);
+ editor::init(cx);
+ theme::init(LoadThemes::JustBase, cx);
+ Assets.load_fonts(cx).unwrap();
+ KeymapFile::load_asset(DEFAULT_KEYMAP_PATH, cx).unwrap();
+ client::init_settings(cx);
+ release_channel::init("0.130.0", cx);
+
+ let client = Client::production(cx);
+ {
+ let client = client.clone();
+ cx.spawn(|cx| async move { client.authenticate_and_connect(false, &cx).await })
+ .detach_and_log_err(cx);
+ }
+ assistant2::init(client.clone(), cx);
+
+ let language_registry = Arc::new(LanguageRegistry::new(
+ Task::ready(()),
+ cx.background_executor().clone(),
+ ));
+ let node_runtime = node_runtime::RealNodeRuntime::new(client.http_client());
+ languages::init(language_registry.clone(), node_runtime, cx);
+
+ cx.spawn(|cx| async move {
+ cx.update(|cx| {
+ let mut tool_registry = ToolRegistry::new();
+ tool_registry
+ .register(RollDiceTool::new())
+ .context("failed to register DummyTool")
+ .log_err();
+
+ let tool_registry = Arc::new(tool_registry);
+
+ println!("Tools registered");
+ for definition in tool_registry.definitions() {
+ println!("{}", definition);
+ }
+
+ cx.open_window(WindowOptions::default(), |cx| {
+ cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
+ });
+ cx.activate(true);
+ })
+ })
+ .detach_and_log_err(cx);
+ })
+}
+
+struct Example {
+ assistant_panel: View<AssistantPanel>,
+}
+
+impl Example {
+ fn new(
+ language_registry: Arc<LanguageRegistry>,
+ tool_registry: Arc<ToolRegistry>,
+ cx: &mut ViewContext<Self>,
+ ) -> Self {
+ Self {
+ assistant_panel: cx
+ .new_view(|cx| AssistantPanel::new(language_registry, tool_registry, cx)),
+ }
+ }
+}
+
+impl Render for Example {
+ fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl ui::prelude::IntoElement {
+ div().size_full().child(self.assistant_panel.clone())
+ }
+}
@@ -256,7 +256,7 @@ mod test {
let expected = ToolFunctionDefinition {
name: "get_current_weather".to_string(),
description: "Fetches the current weather for a given location.".to_string(),
- parameters: schema_for!(WeatherQuery).schema,
+ parameters: schema_for!(WeatherQuery),
};
assert_eq!(tools[0].name, expected.name);
@@ -267,6 +267,7 @@ mod test {
assert_eq!(
expected_schema,
json!({
+ "$schema": "http://json-schema.org/draft-07/schema#",
"title": "WeatherQuery",
"type": "object",
"properties": {
@@ -1,8 +1,11 @@
use anyhow::Result;
use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext};
-use schemars::{schema::SchemaObject, schema_for, JsonSchema};
+use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::Deserialize;
-use std::{any::Any, fmt::Debug};
+use std::{
+ any::Any,
+ fmt::{Debug, Display},
+};
#[derive(Default, Deserialize)]
pub struct ToolFunctionCall {
@@ -89,7 +92,17 @@ impl ToolFunctionCallResult {
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
- pub parameters: SchemaObject,
+ pub parameters: RootSchema,
+}
+
+impl Display for ToolFunctionDefinition {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let schema = serde_json::to_string(&self.parameters).ok();
+ let schema = schema.unwrap_or("None".to_string());
+ write!(f, "Name: {}:\n", self.name)?;
+ write!(f, "Description: {}\n", self.description)?;
+ write!(f, "Parameters: {}", schema)
+ }
}
impl Debug for ToolFunctionDefinition {
@@ -124,10 +137,12 @@ pub trait LanguageModelTool {
/// The OpenAI Function definition for the tool, for direct use with OpenAI's API.
fn definition(&self) -> ToolFunctionDefinition {
+ let root_schema = schema_for!(Self::Input);
+
ToolFunctionDefinition {
name: self.name(),
description: self.description(),
- parameters: schema_for!(Self::Input).schema,
+ parameters: root_schema,
}
}