assistant: Add tool registry (#17331)

Marshall Bowers created

This PR adds a tool registry to hold tools that can be called by the
Assistant.

Currently we just have a `now` tool for retrieving the current datetime.

This is all behind the `assistant-tool-use` feature flag, which currently
must be explicitly opted into before the LLM can see the tools.

Release Notes:

- N/A

Change summary

Cargo.lock                                  | 15 +++++
Cargo.toml                                  |  2 
crates/assistant/Cargo.toml                 |  1 
crates/assistant/src/assistant.rs           |  9 +++
crates/assistant/src/context.rs             | 30 +++++++++
crates/assistant/src/tools.rs               |  1 
crates/assistant/src/tools/now_tool.rs      | 60 ++++++++++++++++++++
crates/assistant_tool/Cargo.toml            | 22 +++++++
crates/assistant_tool/LICENSE-GPL           |  1 
crates/assistant_tool/src/assistant_tool.rs | 35 +++++++++++
crates/assistant_tool/src/tool_registry.rs  | 69 +++++++++++++++++++++++
11 files changed, 243 insertions(+), 2 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -373,6 +373,7 @@ dependencies = [
  "anyhow",
  "assets",
  "assistant_slash_command",
+ "assistant_tool",
  "async-watch",
  "cargo_toml",
  "chrono",
@@ -454,6 +455,20 @@ dependencies = [
  "workspace",
 ]
 
+[[package]]
+name = "assistant_tool"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "collections",
+ "derive_more",
+ "gpui",
+ "parking_lot",
+ "serde",
+ "serde_json",
+ "workspace",
+]
+
 [[package]]
 name = "async-attributes"
 version = "1.1.2"

Cargo.toml 🔗

@@ -6,6 +6,7 @@ members = [
     "crates/assets",
     "crates/assistant",
     "crates/assistant_slash_command",
+    "crates/assistant_tool",
     "crates/audio",
     "crates/auto_update",
     "crates/breadcrumbs",
@@ -181,6 +182,7 @@ anthropic = { path = "crates/anthropic" }
 assets = { path = "crates/assets" }
 assistant = { path = "crates/assistant" }
 assistant_slash_command = { path = "crates/assistant_slash_command" }
+assistant_tool = { path = "crates/assistant_tool" }
 audio = { path = "crates/audio" }
 auto_update = { path = "crates/auto_update" }
 breadcrumbs = { path = "crates/breadcrumbs" }

crates/assistant/Cargo.toml 🔗

@@ -25,6 +25,7 @@ anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
 assets.workspace = true
 assistant_slash_command.workspace = true
+assistant_tool.workspace = true
 async-watch.workspace = true
 cargo_toml.workspace = true
 chrono.workspace = true

crates/assistant/src/assistant.rs 🔗

@@ -13,11 +13,13 @@ pub(crate) mod slash_command_picker;
 pub mod slash_command_settings;
 mod streaming_diff;
 mod terminal_inline_assistant;
+mod tools;
 mod workflow;
 
 pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 use assistant_settings::AssistantSettings;
 use assistant_slash_command::SlashCommandRegistry;
+use assistant_tool::ToolRegistry;
 use client::{proto, Client};
 use command_palette_hooks::CommandPaletteFilter;
 pub use context::*;
@@ -214,6 +216,7 @@ pub fn init(
     prompt_library::init(cx);
     init_language_model_settings(cx);
     assistant_slash_command::init(cx);
+    assistant_tool::init(cx);
     assistant_panel::init(cx);
     context_servers::init(cx);
 
@@ -228,6 +231,7 @@ pub fn init(
     .map(Arc::new)
     .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
     register_slash_commands(Some(prompt_builder.clone()), cx);
+    register_tools(cx);
     inline_assistant::init(
         fs.clone(),
         prompt_builder.clone(),
@@ -401,6 +405,11 @@ fn update_slash_commands_from_settings(cx: &mut AppContext) {
     }
 }
 
+/// Registers the built-in assistant tools with the global [`ToolRegistry`].
+///
+/// At present the only tool registered is [`tools::now_tool::NowTool`].
+fn register_tools(cx: &mut AppContext) {
+    ToolRegistry::global(cx).register_tool(tools::now_tool::NowTool);
+}
+
 pub fn humanize_token_count(count: usize) -> String {
     match count {
         0..=999 => count.to_string(),

crates/assistant/src/context.rs 🔗

@@ -9,9 +9,11 @@ use anyhow::{anyhow, Context as _, Result};
 use assistant_slash_command::{
     SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
 };
+use assistant_tool::ToolRegistry;
 use client::{self, proto, telemetry::Telemetry};
 use clock::ReplicaId;
 use collections::{HashMap, HashSet};
+use feature_flags::{FeatureFlag, FeatureFlagAppExt};
 use fs::{Fs, RemoveOptions};
 use futures::{
     future::{self, Shared},
@@ -27,7 +29,7 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
 use language_model::{
     LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
     LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    MessageContent, Role,
+    LanguageModelRequestTool, MessageContent, Role,
 };
 use open_ai::Model as OpenAiModel;
 use paths::{context_images_dir, contexts_dir};
@@ -1942,7 +1944,21 @@ impl Context {
         // Compute which messages to cache, including the last one.
         self.mark_cache_anchors(&model.cache_configuration(), false, cx);
 
-        let request = self.to_completion_request(cx);
+        let mut request = self.to_completion_request(cx);
+
+        if cx.has_flag::<ToolUseFeatureFlag>() {
+            let tool_registry = ToolRegistry::global(cx);
+            request.tools = tool_registry
+                .tools()
+                .into_iter()
+                .map(|tool| LanguageModelRequestTool {
+                    name: tool.name(),
+                    description: tool.description(),
+                    input_schema: tool.input_schema(),
+                })
+                .collect();
+        }
+
         let assistant_message = self
             .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
             .unwrap();
@@ -2788,6 +2804,16 @@ pub enum PendingSlashCommandStatus {
     Error(String),
 }
 
+/// Feature flag that gates whether the tools in the `ToolRegistry` are
+/// attached to outgoing completion requests.
+pub(crate) struct ToolUseFeatureFlag;
+
+impl FeatureFlag for ToolUseFeatureFlag {
+    /// The name under which the flag is looked up.
+    const NAME: &'static str = "assistant-tool-use";
+
+    /// Returns `false`: staff do not get this flag implicitly, so it has to
+    /// be opted into explicitly.
+    fn enabled_for_staff() -> bool {
+        false
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct PendingToolUse {
     pub id: String,

crates/assistant/src/tools/now_tool.rs 🔗

@@ -0,0 +1,60 @@
+use std::sync::Arc;
+
+use anyhow::{anyhow, Result};
+use assistant_tool::Tool;
+use chrono::{Local, Utc};
+use gpui::{Task, WeakView, WindowContext};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+// Selects which clock `NowTool` reads when producing the datetime.
+// Variants are (de)serialized in snake_case (`utc`, `local`).
+//
+// NOTE(review): plain `//` comments are used here on purpose. This type derives
+// `JsonSchema`, and `///` doc comments presumably end up as descriptions in the
+// generated schema — confirm before converting these to doc comments.
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+#[serde(rename_all = "snake_case")]
+pub enum Timezone {
+    /// Use UTC for the datetime.
+    Utc,
+    /// Use local time for the datetime.
+    Local,
+}
+
+// Input payload accepted by `NowTool::run`; also the type the tool's JSON
+// input schema is generated from.
+//
+// NOTE(review): the name `FileToolInput` looks like a leftover — this is the
+// input of the `now` tool. Consider renaming (e.g. `NowToolInput`) once callers
+// are checked.
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct FileToolInput {
+    /// The timezone to use for the datetime.
+    timezone: Timezone,
+}
+
+/// A tool that tells the language model the current datetime.
+pub struct NowTool;
+
+impl Tool for NowTool {
+    /// The name this tool is exposed under.
+    fn name(&self) -> String {
+        String::from("now")
+    }
+
+    /// A one-sentence description of what the tool returns.
+    fn description(&self) -> String {
+        String::from("Returns the current datetime in RFC 3339 format.")
+    }
+
+    /// Builds the JSON schema for [`FileToolInput`], the input this tool accepts.
+    fn input_schema(&self) -> serde_json::Value {
+        serde_json::to_value(&schemars::schema_for!(FileToolInput)).unwrap()
+    }
+
+    /// Parses `input` as a [`FileToolInput`] and reports the current datetime
+    /// in the requested timezone, formatted as RFC 3339.
+    ///
+    /// The workspace and window context are unused. The returned task is
+    /// already resolved: it holds either the formatted message or the
+    /// deserialization error for a malformed `input`.
+    fn run(
+        self: Arc<Self>,
+        input: serde_json::Value,
+        _workspace: WeakView<workspace::Workspace>,
+        _cx: &mut WindowContext,
+    ) -> Task<Result<String>> {
+        let outcome = serde_json::from_value::<FileToolInput>(input)
+            .map_err(|err| anyhow!(err))
+            .map(|request| {
+                let now = match request.timezone {
+                    Timezone::Utc => Utc::now().to_rfc3339(),
+                    Timezone::Local => Local::now().to_rfc3339(),
+                };
+                format!("The current datetime is {now}.")
+            });
+
+        Task::ready(outcome)
+    }
+}

crates/assistant_tool/Cargo.toml 🔗

@@ -0,0 +1,22 @@
+[package]
+name = "assistant_tool"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/assistant_tool.rs"
+
+[dependencies]
+anyhow.workspace = true
+collections.workspace = true
+derive_more.workspace = true
+gpui.workspace = true
+parking_lot.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+workspace.workspace = true

crates/assistant_tool/src/assistant_tool.rs 🔗

@@ -0,0 +1,35 @@
+mod tool_registry;
+
+use std::sync::Arc;
+
+use anyhow::Result;
+use gpui::{AppContext, Task, WeakView, WindowContext};
+use workspace::Workspace;
+
+pub use tool_registry::*;
+
+/// Initializes the crate's global state by making sure the global
+/// [`ToolRegistry`] exists, inserting a default one if it does not.
+pub fn init(cx: &mut AppContext) {
+    ToolRegistry::default_global(cx);
+}
+
+/// A tool that can be used by a language model.
+///
+/// Implementors must be `'static + Send + Sync`; the registry stores them as
+/// `Arc<dyn Tool>`.
+pub trait Tool: 'static + Send + Sync {
+    /// Returns the name of the tool.
+    ///
+    /// The registry uses this name as the key the tool is stored under.
+    fn name(&self) -> String;
+
+    /// Returns the description of the tool.
+    fn description(&self) -> String;
+
+    /// Returns the JSON schema that describes the tool's input.
+    ///
+    /// The default implementation returns an empty JSON object.
+    fn input_schema(&self) -> serde_json::Value {
+        serde_json::Value::Object(serde_json::Map::default())
+    }
+
+    /// Runs the tool with the provided input.
+    ///
+    /// `input` is the raw JSON value supplied for this invocation, `workspace`
+    /// is a weak handle to the workspace, and `cx` is the window context the
+    /// tool runs in. Returns a task resolving to the tool's textual output,
+    /// or to an error if the tool failed.
+    fn run(
+        self: Arc<Self>,
+        input: serde_json::Value,
+        workspace: WeakView<Workspace>,
+        cx: &mut WindowContext,
+    ) -> Task<Result<String>>;
+}

crates/assistant_tool/src/tool_registry.rs 🔗

@@ -0,0 +1,69 @@
+use std::sync::Arc;
+
+use collections::HashMap;
+use derive_more::{Deref, DerefMut};
+use gpui::Global;
+use gpui::{AppContext, ReadGlobal};
+use parking_lot::RwLock;
+
+use crate::Tool;
+
+/// Newtype around the shared [`ToolRegistry`] so it can be stored as a gpui
+/// global. Dereferences to the inner `Arc<ToolRegistry>`.
+#[derive(Default, Deref, DerefMut)]
+struct GlobalToolRegistry(Arc<ToolRegistry>);
+
+// Marker impl that lets `GlobalToolRegistry` be stored and looked up as a global.
+impl Global for GlobalToolRegistry {}
+
+/// The mutable contents of a [`ToolRegistry`], kept behind its lock.
+#[derive(Default)]
+struct ToolRegistryState {
+    /// Registered tools, keyed by the name each tool reports.
+    tools: HashMap<Arc<str>, Arc<dyn Tool>>,
+}
+
+/// A registry of the [`Tool`]s that can be offered to a language model,
+/// keyed by tool name and guarded by a read-write lock.
+#[derive(Default)]
+pub struct ToolRegistry {
+    state: RwLock<ToolRegistryState>,
+}
+
+impl ToolRegistry {
+    /// Returns a handle to the global [`ToolRegistry`].
+    pub fn global(cx: &AppContext) -> Arc<Self> {
+        Arc::clone(&GlobalToolRegistry::global(cx).0)
+    }
+
+    /// Returns a handle to the global [`ToolRegistry`], first inserting a
+    /// default one if none exists yet.
+    pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
+        Arc::clone(&cx.default_global::<GlobalToolRegistry>().0)
+    }
+
+    /// Creates a new, empty registry wrapped in an [`Arc`].
+    pub fn new() -> Arc<Self> {
+        Arc::new(Self::default())
+    }
+
+    /// Registers the provided [`Tool`] under the name it reports.
+    ///
+    /// A tool already stored under the same name is replaced.
+    pub fn register_tool(&self, tool: impl Tool) {
+        let key = Arc::<str>::from(tool.name());
+        let mut registry = self.state.write();
+        registry.tools.insert(key, Arc::new(tool));
+    }
+
+    /// Unregisters the provided [`Tool`], looking it up by the name it reports.
+    pub fn unregister_tool(&self, tool: impl Tool) {
+        let name = tool.name();
+        self.unregister_tool_by_name(&name)
+    }
+
+    /// Unregisters the tool with the given name, if one is registered.
+    pub fn unregister_tool_by_name(&self, tool_name: &str) {
+        self.state.write().tools.remove(tool_name);
+    }
+
+    /// Returns every tool currently in the registry, in no particular order.
+    pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
+        let registry = self.state.read();
+        registry.tools.values().map(Arc::clone).collect()
+    }
+}