Detailed changes
@@ -373,6 +373,7 @@ dependencies = [
"anyhow",
"assets",
"assistant_slash_command",
+ "assistant_tool",
"async-watch",
"cargo_toml",
"chrono",
@@ -454,6 +455,20 @@ dependencies = [
"workspace",
]
+[[package]]
+name = "assistant_tool"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "collections",
+ "derive_more",
+ "gpui",
+ "parking_lot",
+ "serde",
+ "serde_json",
+ "workspace",
+]
+
[[package]]
name = "async-attributes"
version = "1.1.2"
@@ -6,6 +6,7 @@ members = [
"crates/assets",
"crates/assistant",
"crates/assistant_slash_command",
+ "crates/assistant_tool",
"crates/audio",
"crates/auto_update",
"crates/breadcrumbs",
@@ -181,6 +182,7 @@ anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
+assistant_tool = { path = "crates/assistant_tool" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
breadcrumbs = { path = "crates/breadcrumbs" }
@@ -25,6 +25,7 @@ anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
assets.workspace = true
assistant_slash_command.workspace = true
+assistant_tool.workspace = true
async-watch.workspace = true
cargo_toml.workspace = true
chrono.workspace = true
@@ -13,11 +13,13 @@ pub(crate) mod slash_command_picker;
pub mod slash_command_settings;
mod streaming_diff;
mod terminal_inline_assistant;
+mod tools;
mod workflow;
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::AssistantSettings;
use assistant_slash_command::SlashCommandRegistry;
+use assistant_tool::ToolRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub use context::*;
@@ -214,6 +216,7 @@ pub fn init(
prompt_library::init(cx);
init_language_model_settings(cx);
assistant_slash_command::init(cx);
+ assistant_tool::init(cx);
assistant_panel::init(cx);
context_servers::init(cx);
@@ -228,6 +231,7 @@ pub fn init(
.map(Arc::new)
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
register_slash_commands(Some(prompt_builder.clone()), cx);
+ register_tools(cx);
inline_assistant::init(
fs.clone(),
prompt_builder.clone(),
@@ -401,6 +405,11 @@ fn update_slash_commands_from_settings(cx: &mut AppContext) {
}
}
+fn register_tools(cx: &mut AppContext) {
+ let tool_registry = ToolRegistry::global(cx);
+ tool_registry.register_tool(tools::now_tool::NowTool);
+}
+
pub fn humanize_token_count(count: usize) -> String {
match count {
0..=999 => count.to_string(),
@@ -9,9 +9,11 @@ use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
};
+use assistant_tool::ToolRegistry;
use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
+use feature_flags::{FeatureFlag, FeatureFlagAppExt};
use fs::{Fs, RemoveOptions};
use futures::{
future::{self, Shared},
@@ -27,7 +29,7 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
use language_model::{
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
- MessageContent, Role,
+ LanguageModelRequestTool, MessageContent, Role,
};
use open_ai::Model as OpenAiModel;
use paths::{context_images_dir, contexts_dir};
@@ -1942,7 +1944,21 @@ impl Context {
// Compute which messages to cache, including the last one.
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
- let request = self.to_completion_request(cx);
+ let mut request = self.to_completion_request(cx);
+
+ if cx.has_flag::<ToolUseFeatureFlag>() {
+ let tool_registry = ToolRegistry::global(cx);
+ request.tools = tool_registry
+ .tools()
+ .into_iter()
+ .map(|tool| LanguageModelRequestTool {
+ name: tool.name(),
+ description: tool.description(),
+ input_schema: tool.input_schema(),
+ })
+ .collect();
+ }
+
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@@ -2788,6 +2804,16 @@ pub enum PendingSlashCommandStatus {
Error(String),
}
+pub(crate) struct ToolUseFeatureFlag;
+
+impl FeatureFlag for ToolUseFeatureFlag {
+ const NAME: &'static str = "assistant-tool-use";
+
+ fn enabled_for_staff() -> bool {
+ false
+ }
+}
+
#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: String,
@@ -0,0 +1 @@
+pub mod now_tool;
@@ -0,0 +1,60 @@
+use std::sync::Arc;
+
+use anyhow::{anyhow, Result};
+use assistant_tool::Tool;
+use chrono::{Local, Utc};
+use gpui::{Task, WeakView, WindowContext};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+#[serde(rename_all = "snake_case")]
+pub enum Timezone {
+ /// Use UTC for the datetime.
+ Utc,
+ /// Use local time for the datetime.
+ Local,
+}
+
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct FileToolInput {
+ /// The timezone to use for the datetime.
+ timezone: Timezone,
+}
+
+pub struct NowTool;
+
+impl Tool for NowTool {
+ fn name(&self) -> String {
+ "now".into()
+ }
+
+ fn description(&self) -> String {
+ "Returns the current datetime in RFC 3339 format.".into()
+ }
+
+ fn input_schema(&self) -> serde_json::Value {
+ let schema = schemars::schema_for!(FileToolInput);
+ serde_json::to_value(&schema).unwrap()
+ }
+
+ fn run(
+ self: Arc<Self>,
+ input: serde_json::Value,
+ _workspace: WeakView<workspace::Workspace>,
+ _cx: &mut WindowContext,
+ ) -> Task<Result<String>> {
+ let input: FileToolInput = match serde_json::from_value(input) {
+ Ok(input) => input,
+ Err(err) => return Task::ready(Err(anyhow!(err))),
+ };
+
+ let now = match input.timezone {
+ Timezone::Utc => Utc::now().to_rfc3339(),
+ Timezone::Local => Local::now().to_rfc3339(),
+ };
+ let text = format!("The current datetime is {now}.");
+
+ Task::ready(Ok(text))
+ }
+}
@@ -0,0 +1,22 @@
+[package]
+name = "assistant_tool"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/assistant_tool.rs"
+
+[dependencies]
+anyhow.workspace = true
+collections.workspace = true
+derive_more.workspace = true
+gpui.workspace = true
+parking_lot.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+workspace.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,35 @@
+mod tool_registry;
+
+use std::sync::Arc;
+
+use anyhow::Result;
+use gpui::{AppContext, Task, WeakView, WindowContext};
+use workspace::Workspace;
+
+pub use tool_registry::*;
+
+pub fn init(cx: &mut AppContext) {
+ ToolRegistry::default_global(cx);
+}
+
+/// A tool that can be used by a language model.
+pub trait Tool: 'static + Send + Sync {
+ /// Returns the name of the tool.
+ fn name(&self) -> String;
+
+ /// Returns the description of the tool.
+ fn description(&self) -> String;
+
+ /// Returns the JSON schema that describes the tool's input.
+ fn input_schema(&self) -> serde_json::Value {
+ serde_json::Value::Object(serde_json::Map::default())
+ }
+
+ /// Runs the tool with the provided input.
+ fn run(
+ self: Arc<Self>,
+ input: serde_json::Value,
+ workspace: WeakView<Workspace>,
+ cx: &mut WindowContext,
+ ) -> Task<Result<String>>;
+}
@@ -0,0 +1,69 @@
+use std::sync::Arc;
+
+use collections::HashMap;
+use derive_more::{Deref, DerefMut};
+use gpui::Global;
+use gpui::{AppContext, ReadGlobal};
+use parking_lot::RwLock;
+
+use crate::Tool;
+
+#[derive(Default, Deref, DerefMut)]
+struct GlobalToolRegistry(Arc<ToolRegistry>);
+
+impl Global for GlobalToolRegistry {}
+
+#[derive(Default)]
+struct ToolRegistryState {
+ tools: HashMap<Arc<str>, Arc<dyn Tool>>,
+}
+
+#[derive(Default)]
+pub struct ToolRegistry {
+ state: RwLock<ToolRegistryState>,
+}
+
+impl ToolRegistry {
+ /// Returns the global [`ToolRegistry`].
+ pub fn global(cx: &AppContext) -> Arc<Self> {
+ GlobalToolRegistry::global(cx).0.clone()
+ }
+
+ /// Returns the global [`ToolRegistry`].
+ ///
+ /// Inserts a default [`ToolRegistry`] if one does not yet exist.
+ pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
+ cx.default_global::<GlobalToolRegistry>().0.clone()
+ }
+
+ pub fn new() -> Arc<Self> {
+ Arc::new(Self {
+ state: RwLock::new(ToolRegistryState {
+ tools: HashMap::default(),
+ }),
+ })
+ }
+
+ /// Registers the provided [`Tool`].
+ pub fn register_tool(&self, tool: impl Tool) {
+ let mut state = self.state.write();
+ let tool_name: Arc<str> = tool.name().into();
+ state.tools.insert(tool_name, Arc::new(tool));
+ }
+
+ /// Unregisters the provided [`Tool`].
+ pub fn unregister_tool(&self, tool: impl Tool) {
+ self.unregister_tool_by_name(tool.name().as_str())
+ }
+
+ /// Unregisters the tool with the given name.
+ pub fn unregister_tool_by_name(&self, tool_name: &str) {
+ let mut state = self.state.write();
+ state.tools.remove(tool_name);
+ }
+
+ /// Returns the list of tools in the registry.
+ pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
+ self.state.read().tools.values().cloned().collect()
+ }
+}