Detailed changes
@@ -411,10 +411,17 @@ name = "assistant_tooling"
version = "0.1.0"
dependencies = [
"anyhow",
+ "collections",
+ "futures 0.3.28",
"gpui",
+ "project",
"schemars",
"serde",
"serde_json",
+ "settings",
+ "sum_tree",
+ "unindent",
+ "util",
]
[[package]]
@@ -4,10 +4,16 @@ mod completion_provider;
mod tools;
pub mod ui;
+use crate::{
+ attachments::ActiveEditorAttachmentTool,
+ tools::{CreateBufferTool, ProjectIndexTool},
+ ui::UserOrAssistant,
+};
use ::ui::{div, prelude::*, Color, ViewContext};
use anyhow::{Context, Result};
-use assistant_tooling::{ToolFunctionCall, ToolRegistry};
-use attachments::{ActiveEditorAttachmentTool, UserAttachment, UserAttachmentStore};
+use assistant_tooling::{
+ AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
+};
use client::{proto, Client, UserStore};
use collections::HashMap;
use completion_provider::*;
@@ -34,9 +40,6 @@ use workspace::{
pub use assistant_settings::AssistantSettings;
-use crate::tools::{CreateBufferTool, ProjectIndexTool};
-use crate::ui::UserOrAssistant;
-
const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
@@ -85,10 +88,9 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
});
workspace.register_action(|workspace, _: &DebugProjectIndex, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
- if let Some(index) = panel.read(cx).chat.read(cx).project_index.clone() {
- let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx));
- workspace.add_item_to_center(Box::new(view), cx);
- }
+ let index = panel.read(cx).chat.read(cx).project_index.clone();
+ let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx));
+ workspace.add_item_to_center(Box::new(view), cx);
}
});
},
@@ -122,10 +124,7 @@ impl AssistantPanel {
let mut tool_registry = ToolRegistry::new();
tool_registry
- .register(
- ProjectIndexTool::new(project_index.clone(), project.read(cx).fs().clone()),
- cx,
- )
+ .register(ProjectIndexTool::new(project_index.clone()), cx)
.context("failed to register ProjectIndexTool")
.log_err();
tool_registry
@@ -136,7 +135,7 @@ impl AssistantPanel {
.context("failed to register CreateBufferTool")
.log_err();
- let mut attachment_store = UserAttachmentStore::new();
+ let mut attachment_store = AttachmentRegistry::new();
attachment_store.register(ActiveEditorAttachmentTool::new(workspace.clone(), cx));
Self::new(
@@ -144,7 +143,7 @@ impl AssistantPanel {
Arc::new(tool_registry),
Arc::new(attachment_store),
app_state.user_store.clone(),
- Some(project_index),
+ project_index,
workspace,
cx,
)
@@ -155,9 +154,9 @@ impl AssistantPanel {
pub fn new(
language_registry: Arc<LanguageRegistry>,
tool_registry: Arc<ToolRegistry>,
- attachment_store: Arc<UserAttachmentStore>,
+ attachment_store: Arc<AttachmentRegistry>,
user_store: Model<UserStore>,
- project_index: Option<Model<ProjectIndex>>,
+ project_index: Model<ProjectIndex>,
workspace: WeakView<Workspace>,
cx: &mut ViewContext<Self>,
) -> Self {
@@ -241,16 +240,16 @@ pub struct AssistantChat {
list_state: ListState,
language_registry: Arc<LanguageRegistry>,
composer_editor: View<Editor>,
- project_index_button: Option<View<ProjectIndexButton>>,
+ project_index_button: View<ProjectIndexButton>,
active_file_button: Option<View<ActiveFileButton>>,
user_store: Model<UserStore>,
next_message_id: MessageId,
collapsed_messages: HashMap<MessageId, bool>,
editing_message: Option<EditingMessage>,
pending_completion: Option<Task<()>>,
- attachment_store: Arc<UserAttachmentStore>,
tool_registry: Arc<ToolRegistry>,
- project_index: Option<Model<ProjectIndex>>,
+ attachment_registry: Arc<AttachmentRegistry>,
+ project_index: Model<ProjectIndex>,
}
struct EditingMessage {
@@ -263,9 +262,9 @@ impl AssistantChat {
fn new(
language_registry: Arc<LanguageRegistry>,
tool_registry: Arc<ToolRegistry>,
- attachment_store: Arc<UserAttachmentStore>,
+ attachment_registry: Arc<AttachmentRegistry>,
user_store: Model<UserStore>,
- project_index: Option<Model<ProjectIndex>>,
+ project_index: Model<ProjectIndex>,
workspace: WeakView<Workspace>,
cx: &mut ViewContext<Self>,
) -> Self {
@@ -281,14 +280,14 @@ impl AssistantChat {
},
);
- let project_index_button = project_index.clone().map(|project_index| {
- cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx))
+ let project_index_button = cx.new_view(|cx| {
+ ProjectIndexButton::new(project_index.clone(), tool_registry.clone(), cx)
});
let active_file_button = match workspace.upgrade() {
Some(workspace) => {
Some(cx.new_view(
- |cx| ActiveFileButton::new(attachment_store.clone(), workspace, cx), //
+ |cx| ActiveFileButton::new(attachment_registry.clone(), workspace, cx), //
))
}
_ => None,
@@ -313,7 +312,7 @@ impl AssistantChat {
editing_message: None,
collapsed_messages: HashMap::default(),
pending_completion: None,
- attachment_store,
+ attachment_registry,
tool_registry,
}
}
@@ -395,7 +394,7 @@ impl AssistantChat {
let mode = *mode;
self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
let attachments_task = this.update(&mut cx, |this, cx| {
- let attachment_store = this.attachment_store.clone();
+ let attachment_store = this.attachment_registry.clone();
attachment_store.call_all_attachment_tools(cx)
});
@@ -443,7 +442,7 @@ impl AssistantChat {
let mut call_count = 0;
loop {
let complete = async {
- let completion = this.update(cx, |this, cx| {
+ let (tool_definitions, model_name, messages) = this.update(cx, |this, cx| {
this.push_new_assistant_message(cx);
let definitions = if call_count < limit
@@ -455,14 +454,22 @@ impl AssistantChat {
};
call_count += 1;
- let messages = this.completion_messages(cx);
+ (
+ definitions,
+ this.model.clone(),
+ this.completion_messages(cx),
+ )
+ })?;
+ let messages = messages.await?;
+
+ let completion = cx.update(|cx| {
CompletionProvider::get(cx).complete(
- this.model.clone(),
+ model_name,
messages,
Vec::new(),
1.0,
- definitions,
+ tool_definitions,
)
});
@@ -765,7 +772,12 @@ impl AssistantChat {
}
}
- fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> {
+ fn completion_messages(&self, cx: &mut WindowContext) -> Task<Result<Vec<CompletionMessage>>> {
+ let project_index = self.project_index.read(cx);
+ let project = project_index.project();
+ let fs = project_index.fs();
+
+ let mut project_context = ProjectContext::new(project, fs);
let mut completion_messages = Vec::new();
for message in &self.messages {
@@ -773,12 +785,11 @@ impl AssistantChat {
ChatMessage::User(UserMessage {
body, attachments, ..
}) => {
- completion_messages.extend(
- attachments
- .into_iter()
- .filter_map(|attachment| attachment.message.clone())
- .map(|content| CompletionMessage::System { content }),
- );
+ for attachment in attachments {
+ if let Some(content) = attachment.generate(&mut project_context, cx) {
+ completion_messages.push(CompletionMessage::System { content });
+ }
+ }
// Show user's message last so that the assistant is grounded in the user's request
completion_messages.push(CompletionMessage::User {
@@ -815,7 +826,9 @@ impl AssistantChat {
for tool_call in tool_calls {
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
let content = match &tool_call.result {
- Some(result) => result.format(&tool_call.name),
+ Some(result) => {
+ result.generate(&tool_call.name, &mut project_context, cx)
+ }
None => "".to_string(),
};
@@ -828,7 +841,13 @@ impl AssistantChat {
}
}
- completion_messages
+ let system_message = project_context.generate_system_message(cx);
+
+ cx.background_executor().spawn(async move {
+ let content = system_message.await?;
+ completion_messages.insert(0, CompletionMessage::System { content });
+ Ok(completion_messages)
+ })
}
}
@@ -1,137 +1,18 @@
-use std::{
- any::TypeId,
- sync::{
- atomic::{AtomicBool, Ordering::SeqCst},
- Arc,
- },
-};
+pub mod active_file;
use anyhow::{anyhow, Result};
-use collections::HashMap;
+use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput};
use editor::Editor;
-use futures::future::join_all;
-use gpui::{AnyView, Render, Task, View, WeakView};
+use gpui::{Render, Task, View, WeakModel, WeakView};
+use language::Buffer;
+use project::ProjectPath;
use ui::{prelude::*, ButtonLike, Tooltip, WindowContext};
-use util::{maybe, ResultExt};
+use util::maybe;
use workspace::Workspace;
-/// A collected attachment from running an attachment tool
-pub struct UserAttachment {
- pub message: Option<String>,
- pub view: AnyView,
-}
-
-pub struct UserAttachmentStore {
- attachment_tools: HashMap<TypeId, DynamicAttachment>,
-}
-
-/// Internal representation of an attachment tool to allow us to treat them dynamically
-struct DynamicAttachment {
- enabled: AtomicBool,
- call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
-}
-
-impl UserAttachmentStore {
- pub fn new() -> Self {
- Self {
- attachment_tools: HashMap::default(),
- }
- }
-
- pub fn register<A: AttachmentTool + 'static>(&mut self, attachment: A) {
- let call = Box::new(move |cx: &mut WindowContext| {
- let result = attachment.run(cx);
-
- cx.spawn(move |mut cx| async move {
- let result: Result<A::Output> = result.await;
- let message = A::format(&result);
- let view = cx.update(|cx| A::view(result, cx))?;
-
- Ok(UserAttachment {
- message,
- view: view.into(),
- })
- })
- });
-
- self.attachment_tools.insert(
- TypeId::of::<A>(),
- DynamicAttachment {
- call,
- enabled: AtomicBool::new(true),
- },
- );
- }
-
- pub fn set_attachment_tool_enabled<A: AttachmentTool + 'static>(&self, is_enabled: bool) {
- if let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) {
- attachment.enabled.store(is_enabled, SeqCst);
- }
- }
-
- pub fn is_attachment_tool_enabled<A: AttachmentTool + 'static>(&self) -> bool {
- if let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) {
- attachment.enabled.load(SeqCst)
- } else {
- false
- }
- }
-
- pub fn call<A: AttachmentTool + 'static>(
- &self,
- cx: &mut WindowContext,
- ) -> Task<Result<UserAttachment>> {
- let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) else {
- return Task::ready(Err(anyhow!("no attachment tool")));
- };
-
- (attachment.call)(cx)
- }
-
- pub fn call_all_attachment_tools(
- self: Arc<Self>,
- cx: &mut WindowContext<'_>,
- ) -> Task<Result<Vec<UserAttachment>>> {
- let this = self.clone();
- cx.spawn(|mut cx| async move {
- let attachment_tasks = cx.update(|cx| {
- let mut tasks = Vec::new();
- for attachment in this
- .attachment_tools
- .values()
- .filter(|attachment| attachment.enabled.load(SeqCst))
- {
- tasks.push((attachment.call)(cx))
- }
-
- tasks
- })?;
-
- let attachments = join_all(attachment_tasks.into_iter()).await;
-
- Ok(attachments
- .into_iter()
- .filter_map(|attachment| attachment.log_err())
- .collect())
- })
- }
-}
-
-pub trait AttachmentTool {
- type Output: 'static;
- type View: Render;
-
- fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
-
- fn format(output: &Result<Self::Output>) -> Option<String>;
-
- fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
-}
-
pub struct ActiveEditorAttachment {
- filename: Arc<str>,
- language: Arc<str>,
- text: Arc<str>,
+ buffer: WeakModel<Buffer>,
+ path: Option<ProjectPath>,
}
pub struct FileAttachmentView {
@@ -142,7 +23,13 @@ impl Render for FileAttachmentView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
match &self.output {
Ok(attachment) => {
- let filename = attachment.filename.clone();
+ let filename: SharedString = attachment
+ .path
+ .as_ref()
+ .and_then(|p| p.path.file_name()?.to_str())
+ .unwrap_or("Untitled")
+ .to_string()
+ .into();
// todo!(): make the button link to the actual file to open
ButtonLike::new("file-attachment")
@@ -152,7 +39,7 @@ impl Render for FileAttachmentView {
.bg(cx.theme().colors().editor_background)
.rounded_md()
.child(ui::Icon::new(IconName::File))
- .child(filename.to_string()),
+ .child(filename.clone()),
)
.tooltip({
move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx)
@@ -164,6 +51,20 @@ impl Render for FileAttachmentView {
}
}
+impl ToolOutput for FileAttachmentView {
+ fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
+ if let Ok(result) = &self.output {
+ if let Some(path) = &result.path {
+ project.add_file(path.clone());
+ return format!("current file: {}", path.path.display());
+ } else if let Some(buffer) = result.buffer.upgrade() {
+ return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
+ }
+ }
+ String::new()
+ }
+}
+
pub struct ActiveEditorAttachmentTool {
workspace: WeakView<Workspace>,
}
@@ -174,7 +75,7 @@ impl ActiveEditorAttachmentTool {
}
}
-impl AttachmentTool for ActiveEditorAttachmentTool {
+impl LanguageModelAttachment for ActiveEditorAttachmentTool {
type Output = ActiveEditorAttachment;
type View = FileAttachmentView;
@@ -191,47 +92,22 @@ impl AttachmentTool for ActiveEditorAttachmentTool {
let buffer = active_buffer.read(cx);
- if let Some(singleton) = buffer.as_singleton() {
- let singleton = singleton.read(cx);
-
- let filename = singleton
- .file()
- .map(|file| file.path().to_string_lossy())
- .unwrap_or("Untitled".into());
-
- let text = singleton.text();
-
- let language = singleton
- .language()
- .map(|l| {
- let name = l.code_fence_block_name();
- name.to_string()
- })
- .unwrap_or_default();
-
+ if let Some(buffer) = buffer.as_singleton() {
+ let path =
+ project::File::from_dyn(buffer.read(cx).file()).map(|file| ProjectPath {
+ worktree_id: file.worktree_id(cx),
+ path: file.path.clone(),
+ });
return Ok(ActiveEditorAttachment {
- filename: filename.into(),
- language: language.into(),
- text: text.into(),
+ buffer: buffer.downgrade(),
+ path,
});
+ } else {
+ Err(anyhow!("no active buffer"))
}
-
- Err(anyhow!("no active buffer"))
}))
}
- fn format(output: &Result<Self::Output>) -> Option<String> {
- let output = output.as_ref().ok()?;
-
- let filename = &output.filename;
- let language = &output.language;
- let text = &output.text;
-
- Some(format!(
- "User's active file `{filename}`:\n\n```{language}\n{text}```\n\n"
- ))
- }
-
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View> {
cx.new_view(|_cx| FileAttachmentView { output })
}
@@ -0,0 +1 @@
+
@@ -1,5 +1,5 @@
use anyhow::Result;
-use assistant_tooling::LanguageModelTool;
+use assistant_tooling::{LanguageModelTool, ProjectContext, ToolOutput};
use editor::Editor;
use gpui::{prelude::*, Model, Task, View, WeakView};
use project::Project;
@@ -31,11 +31,9 @@ pub struct CreateBufferInput {
language: String,
}
-pub struct CreateBufferOutput {}
-
impl LanguageModelTool for CreateBufferTool {
type Input = CreateBufferInput;
- type Output = CreateBufferOutput;
+ type Output = ();
type View = CreateBufferView;
fn name(&self) -> String {
@@ -83,32 +81,39 @@ impl LanguageModelTool for CreateBufferTool {
})
.log_err();
- Ok(CreateBufferOutput {})
+ Ok(())
}
})
}
- fn format(input: &Self::Input, output: &Result<Self::Output>) -> String {
- match output {
- Ok(_) => format!("Created a new {} buffer", input.language),
- Err(err) => format!("Failed to create buffer: {err:?}"),
- }
- }
-
fn output_view(
- _tool_call_id: String,
- _input: Self::Input,
- _output: Result<Self::Output>,
+ input: Self::Input,
+ output: Result<Self::Output>,
cx: &mut WindowContext,
) -> View<Self::View> {
- cx.new_view(|_cx| CreateBufferView {})
+ cx.new_view(|_cx| CreateBufferView {
+ language: input.language,
+ output,
+ })
}
}
-pub struct CreateBufferView {}
+pub struct CreateBufferView {
+ language: String,
+ output: Result<()>,
+}
impl Render for CreateBufferView {
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
div().child("Opening a buffer")
}
}
+
+impl ToolOutput for CreateBufferView {
+ fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String {
+ match &self.output {
+ Ok(_) => format!("Created a new {} buffer", self.language),
+ Err(err) => format!("Failed to create buffer: {err:?}"),
+ }
+ }
+}
@@ -1,25 +1,18 @@
use anyhow::Result;
-use assistant_tooling::LanguageModelTool;
+use assistant_tooling::{LanguageModelTool, ToolOutput};
+use collections::BTreeMap;
use gpui::{prelude::*, Model, Task};
-use project::Fs;
+use project::ProjectPath;
use schemars::JsonSchema;
use semantic_index::{ProjectIndex, Status};
use serde::Deserialize;
-use std::{collections::HashSet, sync::Arc};
-
-use ui::{
- div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
- WindowContext,
-};
-use util::ResultExt as _;
+use std::{fmt::Write as _, ops::Range};
+use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
const DEFAULT_SEARCH_LIMIT: usize = 20;
-#[derive(Clone)]
-pub struct CodebaseExcerpt {
- path: SharedString,
- text: SharedString,
- score: f32,
+pub struct ProjectIndexTool {
+ project_index: Model<ProjectIndex>,
}
// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
@@ -40,6 +33,11 @@ pub struct ProjectIndexView {
expanded_header: bool,
}
+pub struct ProjectIndexOutput {
+ status: Status,
+ excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
+}
+
impl ProjectIndexView {
fn new(input: CodebaseQuery, output: Result<ProjectIndexOutput>) -> Self {
let element_id = ElementId::Name(nanoid::nanoid!().into());
@@ -71,19 +69,15 @@ impl Render for ProjectIndexView {
Ok(output) => output,
};
- let num_files_searched = output.files_searched.len();
+ let file_count = output.excerpts.len();
let header = h_flex()
.gap_2()
.child(Icon::new(IconName::File))
.child(format!(
"Read {} {}",
- num_files_searched,
- if num_files_searched == 1 {
- "file"
- } else {
- "files"
- }
+ file_count,
+ if file_count == 1 { "file" } else { "files" }
));
v_flex().gap_3().child(
@@ -102,36 +96,50 @@ impl Render for ProjectIndexView {
.child(Icon::new(IconName::MagnifyingGlass))
.child(Label::new(format!("`{}`", query)).color(Color::Muted)),
)
- .child(v_flex().gap_2().children(output.files_searched.iter().map(
- |path| {
- h_flex()
- .gap_2()
- .child(Icon::new(IconName::File))
- .child(Label::new(path.clone()).color(Color::Muted))
- },
- ))),
+ .child(
+ v_flex()
+ .gap_2()
+ .children(output.excerpts.keys().map(|path| {
+ h_flex().gap_2().child(Icon::new(IconName::File)).child(
+ Label::new(path.path.to_string_lossy().to_string())
+ .color(Color::Muted),
+ )
+ })),
+ ),
),
)
}
}
-pub struct ProjectIndexTool {
- project_index: Model<ProjectIndex>,
- fs: Arc<dyn Fs>,
-}
+impl ToolOutput for ProjectIndexView {
+ fn generate(
+ &self,
+ context: &mut assistant_tooling::ProjectContext,
+ _: &mut WindowContext,
+ ) -> String {
+ match &self.output {
+ Ok(output) => {
+ let mut body = "found results in the following paths:\n".to_string();
-pub struct ProjectIndexOutput {
- excerpts: Vec<CodebaseExcerpt>,
- status: Status,
- files_searched: HashSet<SharedString>,
+ for (project_path, ranges) in &output.excerpts {
+ context.add_excerpts(project_path.clone(), ranges);
+ writeln!(&mut body, "* {}", &project_path.path.display()).unwrap();
+ }
+
+ if output.status != Status::Idle {
+ body.push_str("Still indexing. Results may be incomplete.\n");
+ }
+
+ body
+ }
+ Err(err) => format!("Error: {}", err),
+ }
+ }
}
impl ProjectIndexTool {
- pub fn new(project_index: Model<ProjectIndex>, fs: Arc<dyn Fs>) -> Self {
- // Listen for project index status and update the ProjectIndexTool directly
-
- // TODO: setup a better description based on the user's current codebase.
- Self { project_index, fs }
+ pub fn new(project_index: Model<ProjectIndex>) -> Self {
+ Self { project_index }
}
}
@@ -151,64 +159,42 @@ impl LanguageModelTool for ProjectIndexTool {
fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
let project_index = self.project_index.read(cx);
let status = project_index.status();
- let results = project_index.search(
+ let search = project_index.search(
query.query.clone(),
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
cx,
);
- let fs = self.fs.clone();
-
- cx.spawn(|cx| async move {
- let results = results.await?;
-
- let excerpts = results.into_iter().map(|result| {
- let abs_path = result
- .worktree
- .read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path));
- let fs = fs.clone();
-
- async move {
- let path = result.path.clone();
- let text = fs.load(&abs_path?).await?;
-
- let mut start = result.range.start;
- let mut end = result.range.end.min(text.len());
- while !text.is_char_boundary(start) {
- start += 1;
- }
- while !text.is_char_boundary(end) {
- end -= 1;
- }
-
- anyhow::Ok(CodebaseExcerpt {
- path: path.to_string_lossy().to_string().into(),
- text: SharedString::from(text[start..end].to_string()),
- score: result.score,
- })
+ cx.spawn(|mut cx| async move {
+ let search_results = search.await?;
+
+ cx.update(|cx| {
+ let mut output = ProjectIndexOutput {
+ status,
+ excerpts: Default::default(),
+ };
+
+ for search_result in search_results {
+ let path = ProjectPath {
+ worktree_id: search_result.worktree.read(cx).id(),
+ path: search_result.path.clone(),
+ };
+
+ let excerpts_for_path = output.excerpts.entry(path).or_default();
+ let ix = match excerpts_for_path
+ .binary_search_by_key(&search_result.range.start, |r| r.start)
+ {
+ Ok(ix) | Err(ix) => ix,
+ };
+ excerpts_for_path.insert(ix, search_result.range);
}
- });
-
- let mut files_searched = HashSet::new();
- let excerpts = futures::future::join_all(excerpts)
- .await
- .into_iter()
- .filter_map(|result| result.log_err())
- .inspect(|excerpt| {
- files_searched.insert(excerpt.path.clone());
- })
- .collect::<Vec<_>>();
-
- anyhow::Ok(ProjectIndexOutput {
- excerpts,
- status,
- files_searched,
+
+ output
})
})
}
fn output_view(
- _tool_call_id: String,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
@@ -220,34 +206,4 @@ impl LanguageModelTool for ProjectIndexTool {
CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false)
.start_slot("Searching code base")
}
-
- fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
- match &output {
- Ok(output) => {
- let mut body = "Semantic search results:\n".to_string();
-
- if output.status != Status::Idle {
- body.push_str("Still indexing. Results may be incomplete.\n");
- }
-
- if output.excerpts.is_empty() {
- body.push_str("No results found");
- return body;
- }
-
- for excerpt in &output.excerpts {
- body.push_str("Excerpt from ");
- body.push_str(excerpt.path.as_ref());
- body.push_str(", score ");
- body.push_str(&excerpt.score.to_string());
- body.push_str(":\n");
- body.push_str("~~~\n");
- body.push_str(excerpt.text.as_ref());
- body.push_str("~~~\n");
- }
- body
- }
- Err(err) => format!("Error: {}", err),
- }
- }
}
@@ -1,4 +1,5 @@
-use crate::attachments::{ActiveEditorAttachmentTool, UserAttachmentStore};
+use crate::attachments::ActiveEditorAttachmentTool;
+use assistant_tooling::AttachmentRegistry;
use editor::Editor;
use gpui::{prelude::*, Subscription, View};
use std::sync::Arc;
@@ -13,7 +14,7 @@ enum Status {
}
pub struct ActiveFileButton {
- attachment_store: Arc<UserAttachmentStore>,
+ attachment_registry: Arc<AttachmentRegistry>,
status: Status,
#[allow(dead_code)]
workspace_subscription: Subscription,
@@ -21,7 +22,7 @@ pub struct ActiveFileButton {
impl ActiveFileButton {
pub fn new(
- attachment_store: Arc<UserAttachmentStore>,
+ attachment_store: Arc<AttachmentRegistry>,
workspace: View<Workspace>,
cx: &mut ViewContext<Self>,
) -> Self {
@@ -30,14 +31,14 @@ impl ActiveFileButton {
cx.defer(move |this, cx| this.update_active_buffer(workspace.clone(), cx));
Self {
- attachment_store,
+ attachment_registry: attachment_store,
status: Status::NoFile,
workspace_subscription,
}
}
pub fn set_enabled(&mut self, enabled: bool) {
- self.attachment_store
+ self.attachment_registry
.set_attachment_tool_enabled::<ActiveEditorAttachmentTool>(enabled);
}
@@ -79,7 +80,7 @@ impl ActiveFileButton {
impl Render for ActiveFileButton {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let is_enabled = self
- .attachment_store
+ .attachment_registry
.is_attachment_tool_enabled::<ActiveEditorAttachmentTool>();
let icon = if is_enabled {
@@ -11,7 +11,7 @@ use ui::{popover_menu, prelude::*, ButtonLike, ContextMenu, Divider, TextSize, T
#[derive(IntoElement)]
pub struct Composer {
editor: View<Editor>,
- project_index_button: Option<View<ProjectIndexButton>>,
+ project_index_button: View<ProjectIndexButton>,
active_file_button: Option<View<ActiveFileButton>>,
model_selector: AnyElement,
}
@@ -19,7 +19,7 @@ pub struct Composer {
impl Composer {
pub fn new(
editor: View<Editor>,
- project_index_button: Option<View<ProjectIndexButton>>,
+ project_index_button: View<ProjectIndexButton>,
active_file_button: Option<View<ActiveFileButton>>,
model_selector: AnyElement,
) -> Self {
@@ -32,11 +32,7 @@ impl Composer {
}
fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
- h_flex().children(
- self.project_index_button
- .clone()
- .map(|view| view.into_any_element()),
- )
+ h_flex().child(self.project_index_button.clone())
}
fn render_attachment_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
@@ -13,10 +13,18 @@ path = "src/assistant_tooling.rs"
[dependencies]
anyhow.workspace = true
+collections.workspace = true
+futures.workspace = true
gpui.workspace = true
+project.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
+sum_tree.workspace = true
+util.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
+project = { workspace = true, features = ["test-support"] }
+settings = { workspace = true, features = ["test-support"] }
+unindent.workspace = true
@@ -1,5 +1,9 @@
-pub mod registry;
-pub mod tool;
+mod attachment_registry;
+mod project_context;
+mod tool_registry;
-pub use crate::registry::ToolRegistry;
-pub use crate::tool::{LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition};
+pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment};
+pub use project_context::ProjectContext;
+pub use tool_registry::{
+ LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition, ToolOutput, ToolRegistry,
+};
@@ -0,0 +1,148 @@
+use crate::{ProjectContext, ToolOutput};
+use anyhow::{anyhow, Result};
+use collections::HashMap;
+use futures::future::join_all;
+use gpui::{AnyView, Render, Task, View, WindowContext};
+use std::{
+ any::TypeId,
+ sync::{
+ atomic::{AtomicBool, Ordering::SeqCst},
+ Arc,
+ },
+};
+use util::ResultExt as _;
+
+pub struct AttachmentRegistry {
+ registered_attachments: HashMap<TypeId, RegisteredAttachment>,
+}
+
+pub trait LanguageModelAttachment {
+ type Output: 'static;
+ type View: Render + ToolOutput;
+
+ fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
+
+ fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
+}
+
+/// A collected attachment from running an attachment tool
+pub struct UserAttachment {
+ pub view: AnyView,
+ generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
+}
+
+/// Internal representation of an attachment tool to allow us to treat them dynamically
+struct RegisteredAttachment {
+ enabled: AtomicBool,
+ call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
+}
+
+impl AttachmentRegistry {
+ pub fn new() -> Self {
+ Self {
+ registered_attachments: HashMap::default(),
+ }
+ }
+
+ pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
+ let call = Box::new(move |cx: &mut WindowContext| {
+ let result = attachment.run(cx);
+
+ cx.spawn(move |mut cx| async move {
+ let result: Result<A::Output> = result.await;
+ let view = cx.update(|cx| A::view(result, cx))?;
+
+ Ok(UserAttachment {
+ view: view.into(),
+ generate_fn: generate::<A>,
+ })
+ })
+ });
+
+ self.registered_attachments.insert(
+ TypeId::of::<A>(),
+ RegisteredAttachment {
+ call,
+ enabled: AtomicBool::new(true),
+ },
+ );
+ return;
+
+ fn generate<T: LanguageModelAttachment>(
+ view: AnyView,
+ project: &mut ProjectContext,
+ cx: &mut WindowContext,
+ ) -> String {
+ view.downcast::<T::View>()
+ .unwrap()
+ .update(cx, |view, cx| T::View::generate(view, project, cx))
+ }
+ }
+
+ pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
+ &self,
+ is_enabled: bool,
+ ) {
+ if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
+ attachment.enabled.store(is_enabled, SeqCst);
+ }
+ }
+
+ pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
+ if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
+ attachment.enabled.load(SeqCst)
+ } else {
+ false
+ }
+ }
+
+ pub fn call<A: LanguageModelAttachment + 'static>(
+ &self,
+ cx: &mut WindowContext,
+ ) -> Task<Result<UserAttachment>> {
+ let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
+ return Task::ready(Err(anyhow!("no attachment tool")));
+ };
+
+ (attachment.call)(cx)
+ }
+
+ pub fn call_all_attachment_tools(
+ self: Arc<Self>,
+ cx: &mut WindowContext<'_>,
+ ) -> Task<Result<Vec<UserAttachment>>> {
+ let this = self.clone();
+ cx.spawn(|mut cx| async move {
+ let attachment_tasks = cx.update(|cx| {
+ let mut tasks = Vec::new();
+ for attachment in this
+ .registered_attachments
+ .values()
+ .filter(|attachment| attachment.enabled.load(SeqCst))
+ {
+ tasks.push((attachment.call)(cx))
+ }
+
+ tasks
+ })?;
+
+ let attachments = join_all(attachment_tasks.into_iter()).await;
+
+ Ok(attachments
+ .into_iter()
+ .filter_map(|attachment| attachment.log_err())
+ .collect())
+ })
+ }
+}
+
+impl UserAttachment {
+ pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
+ let result = (self.generate_fn)(self.view.clone(), output, cx);
+ if result.is_empty() {
+ None
+ } else {
+ Some(result)
+ }
+ }
+}
@@ -0,0 +1,296 @@
+use anyhow::{anyhow, Result};
+use gpui::{AppContext, Model, Task, WeakModel};
+use project::{Fs, Project, ProjectPath, Worktree};
+use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc};
+use sum_tree::TreeMap;
+
+pub struct ProjectContext {
+ files: TreeMap<ProjectPath, PathState>,
+ project: WeakModel<Project>,
+ fs: Arc<dyn Fs>,
+}
+
+#[derive(Debug, Clone)]
+enum PathState {
+ PathOnly,
+ EntireFile,
+ Excerpts { ranges: Vec<Range<usize>> },
+}
+
+impl ProjectContext {
+ pub fn new(project: WeakModel<Project>, fs: Arc<dyn Fs>) -> Self {
+ Self {
+ files: TreeMap::default(),
+ fs,
+ project,
+ }
+ }
+
+ pub fn add_path(&mut self, project_path: ProjectPath) {
+ if self.files.get(&project_path).is_none() {
+ self.files.insert(project_path, PathState::PathOnly);
+ }
+ }
+
+ pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range<usize>]) {
+ let previous_state = self
+ .files
+ .get(&project_path)
+ .unwrap_or(&PathState::PathOnly);
+
+ let mut ranges = match previous_state {
+ PathState::EntireFile => return,
+ PathState::PathOnly => Vec::new(),
+ PathState::Excerpts { ranges } => ranges.to_vec(),
+ };
+
+ for new_range in new_ranges {
+ let ix = ranges.binary_search_by(|probe| {
+ if probe.end < new_range.start {
+ Ordering::Less
+ } else if probe.start > new_range.end {
+ Ordering::Greater
+ } else {
+ Ordering::Equal
+ }
+ });
+
+ match ix {
+ Ok(mut ix) => {
+ let existing = &mut ranges[ix];
+ existing.start = existing.start.min(new_range.start);
+ existing.end = existing.end.max(new_range.end);
+ while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end {
+ ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end);
+ ranges.remove(ix + 1);
+ }
+ while ix > 0 && ranges[ix - 1].end >= ranges[ix].start {
+ ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start);
+ ranges.remove(ix - 1);
+ ix -= 1;
+ }
+ }
+ Err(ix) => {
+ ranges.insert(ix, new_range.clone());
+ }
+ }
+ }
+
+ self.files
+ .insert(project_path, PathState::Excerpts { ranges });
+ }
+
+ pub fn add_file(&mut self, project_path: ProjectPath) {
+ self.files.insert(project_path, PathState::EntireFile);
+ }
+
+ pub fn generate_system_message(&self, cx: &mut AppContext) -> Task<Result<String>> {
+ let project = self
+ .project
+ .upgrade()
+ .ok_or_else(|| anyhow!("project dropped"));
+ let files = self.files.clone();
+ let fs = self.fs.clone();
+ cx.spawn(|cx| async move {
+ let project = project?;
+ let mut result = "project structure:\n".to_string();
+
+ let mut last_worktree: Option<Model<Worktree>> = None;
+ for (project_path, path_state) in files.iter() {
+ if let Some(worktree) = &last_worktree {
+ if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id {
+ last_worktree = None;
+ }
+ }
+
+ let worktree;
+ if let Some(last_worktree) = &last_worktree {
+ worktree = last_worktree.clone();
+ } else if let Some(tree) = project.read_with(&cx, |project, cx| {
+ project.worktree_for_id(project_path.worktree_id, cx)
+ })? {
+ worktree = tree;
+ last_worktree = Some(worktree.clone());
+ let worktree_name =
+ worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?;
+ writeln!(&mut result, "# {}", worktree_name).unwrap();
+ } else {
+ continue;
+ }
+
+ let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?;
+ let path = &project_path.path;
+ writeln!(&mut result, "## {}", path.display()).unwrap();
+
+ match path_state {
+ PathState::PathOnly => {}
+ PathState::EntireFile => {
+ let text = fs.load(&worktree_abs_path.join(&path)).await?;
+ writeln!(&mut result, "~~~\n{text}\n~~~").unwrap();
+ }
+ PathState::Excerpts { ranges } => {
+ let text = fs.load(&worktree_abs_path.join(&path)).await?;
+
+ writeln!(&mut result, "~~~").unwrap();
+
+ // Assumption: ranges are in order, not overlapping
+ let mut prev_range_end = 0;
+ for range in ranges {
+ if range.start > prev_range_end {
+ writeln!(&mut result, "...").unwrap();
+ prev_range_end = range.end;
+ }
+
+ let mut start = range.start;
+ let mut end = range.end.min(text.len());
+ while !text.is_char_boundary(start) {
+ start += 1;
+ }
+ while !text.is_char_boundary(end) {
+ end -= 1;
+ }
+ result.push_str(&text[start..end]);
+ if !result.ends_with('\n') {
+ result.push('\n');
+ }
+ }
+
+ if prev_range_end < text.len() {
+ writeln!(&mut result, "...").unwrap();
+ }
+
+ writeln!(&mut result, "~~~").unwrap();
+ }
+ }
+ }
+ Ok(result)
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::path::Path;
+
+ use super::*;
+ use gpui::TestAppContext;
+ use project::FakeFs;
+ use serde_json::json;
+ use settings::SettingsStore;
+
+ use unindent::Unindent as _;
+
+ #[gpui::test]
+ async fn test_system_message_generation(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let file_3_contents = r#"
+ fn test1() {}
+ fn test2() {}
+ fn test3() {}
+ "#
+ .unindent();
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/code",
+ json!({
+ "root1": {
+ "lib": {
+ "file1.rs": "mod example;",
+ "file2.rs": "",
+ },
+ "test": {
+ "file3.rs": file_3_contents,
+ }
+ },
+ "root2": {
+ "src": {
+ "main.rs": ""
+ }
+ }
+ }),
+ )
+ .await;
+
+ let project = Project::test(
+ fs.clone(),
+ ["/code/root1".as_ref(), "/code/root2".as_ref()],
+ cx,
+ )
+ .await;
+
+ let worktree_ids = project.read_with(cx, |project, cx| {
+ project
+ .worktrees()
+ .map(|worktree| worktree.read(cx).id())
+ .collect::<Vec<_>>()
+ });
+
+ let mut ax = ProjectContext::new(project.downgrade(), fs);
+
+ ax.add_file(ProjectPath {
+ worktree_id: worktree_ids[0],
+ path: Path::new("lib/file1.rs").into(),
+ });
+
+ let message = cx
+ .update(|cx| ax.generate_system_message(cx))
+ .await
+ .unwrap();
+ assert_eq!(
+ r#"
+ project structure:
+ # root1
+ ## lib/file1.rs
+ ~~~
+ mod example;
+ ~~~
+ "#
+ .unindent(),
+ message
+ );
+
+ ax.add_excerpts(
+ ProjectPath {
+ worktree_id: worktree_ids[0],
+ path: Path::new("test/file3.rs").into(),
+ },
+ &[
+ file_3_contents.find("fn test2").unwrap()
+ ..file_3_contents.find("fn test3").unwrap(),
+ ],
+ );
+
+ let message = cx
+ .update(|cx| ax.generate_system_message(cx))
+ .await
+ .unwrap();
+ assert_eq!(
+ r#"
+ project structure:
+ # root1
+ ## lib/file1.rs
+ ~~~
+ mod example;
+ ~~~
+ ## test/file3.rs
+ ~~~
+ ...
+ fn test2() {}
+ ...
+ ~~~
+ "#
+ .unindent(),
+ message
+ );
+ }
+
+ fn init_test(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ Project::init_settings(cx);
+ });
+ }
+}
@@ -1,111 +0,0 @@
-use anyhow::Result;
-use gpui::{div, AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
-use schemars::{schema::RootSchema, schema_for, JsonSchema};
-use serde::Deserialize;
-use std::fmt::Display;
-
-#[derive(Default, Deserialize)]
-pub struct ToolFunctionCall {
- pub id: String,
- pub name: String,
- pub arguments: String,
- #[serde(skip)]
- pub result: Option<ToolFunctionCallResult>,
-}
-
-pub enum ToolFunctionCallResult {
- NoSuchTool,
- ParsingFailed,
- Finished { for_model: String, view: AnyView },
-}
-
-impl ToolFunctionCallResult {
- pub fn format(&self, name: &String) -> String {
- match self {
- ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
- ToolFunctionCallResult::ParsingFailed => {
- format!("Unable to parse arguments for {name}")
- }
- ToolFunctionCallResult::Finished { for_model, .. } => for_model.clone(),
- }
- }
-
- pub fn into_any_element(&self, name: &String) -> AnyElement {
- match self {
- ToolFunctionCallResult::NoSuchTool => {
- format!("Language Model attempted to call {name}").into_any_element()
- }
- ToolFunctionCallResult::ParsingFailed => {
- format!("Language Model called {name} with bad arguments").into_any_element()
- }
- ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
- }
- }
-}
-
-#[derive(Clone)]
-pub struct ToolFunctionDefinition {
- pub name: String,
- pub description: String,
- pub parameters: RootSchema,
-}
-
-impl Display for ToolFunctionDefinition {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- let schema = serde_json::to_string(&self.parameters).ok();
- let schema = schema.unwrap_or("None".to_string());
- write!(f, "Name: {}:\n", self.name)?;
- write!(f, "Description: {}\n", self.description)?;
- write!(f, "Parameters: {}", schema)
- }
-}
-
-pub trait LanguageModelTool {
- /// The input type that will be passed in to `execute` when the tool is called
- /// by the language model.
- type Input: for<'de> Deserialize<'de> + JsonSchema;
-
- /// The output returned by executing the tool.
- type Output: 'static;
-
- type View: Render;
-
- /// Returns the name of the tool.
- ///
- /// This name is exposed to the language model to allow the model to pick
- /// which tools to use. As this name is used to identify the tool within a
- /// tool registry, it should be unique.
- fn name(&self) -> String;
-
- /// Returns the description of the tool.
- ///
- /// This can be used to _prompt_ the model as to what the tool does.
- fn description(&self) -> String;
-
- /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
- fn definition(&self) -> ToolFunctionDefinition {
- let root_schema = schema_for!(Self::Input);
-
- ToolFunctionDefinition {
- name: self.name(),
- description: self.description(),
- parameters: root_schema,
- }
- }
-
- /// Executes the tool with the given input.
- fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
-
- fn format(input: &Self::Input, output: &Result<Self::Output>) -> String;
-
- fn output_view(
- tool_call_id: String,
- input: Self::Input,
- output: Result<Self::Output>,
- cx: &mut WindowContext,
- ) -> View<Self::View>;
-
- fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
- div()
- }
-}
@@ -1,54 +1,115 @@
use anyhow::{anyhow, Result};
-use gpui::{div, AnyElement, IntoElement as _, ParentElement, Styled, Task, WindowContext};
+use gpui::{
+ div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
+};
+use schemars::{schema::RootSchema, schema_for, JsonSchema};
+use serde::Deserialize;
use std::{
any::TypeId,
collections::HashMap,
+ fmt::Display,
sync::atomic::{AtomicBool, Ordering::SeqCst},
};
-use crate::tool::{
- LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
-};
+use crate::ProjectContext;
-// Internal Tool representation for the registry
-pub struct Tool {
- enabled: AtomicBool,
- type_id: TypeId,
- call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
- render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>,
- definition: ToolFunctionDefinition,
+pub struct ToolRegistry {
+ registered_tools: HashMap<String, RegisteredTool>,
}
-impl Tool {
- fn new(
- type_id: TypeId,
- call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
- render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>,
- definition: ToolFunctionDefinition,
- ) -> Self {
- Self {
- enabled: AtomicBool::new(true),
- type_id,
- call,
- render_running,
- definition,
+#[derive(Default, Deserialize)]
+pub struct ToolFunctionCall {
+ pub id: String,
+ pub name: String,
+ pub arguments: String,
+ #[serde(skip)]
+ pub result: Option<ToolFunctionCallResult>,
+}
+
+pub enum ToolFunctionCallResult {
+ NoSuchTool,
+ ParsingFailed,
+ Finished {
+ view: AnyView,
+ generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
+ },
+}
+
+#[derive(Clone)]
+pub struct ToolFunctionDefinition {
+ pub name: String,
+ pub description: String,
+ pub parameters: RootSchema,
+}
+
+pub trait LanguageModelTool {
+ /// The input type that will be passed in to `execute` when the tool is called
+ /// by the language model.
+ type Input: for<'de> Deserialize<'de> + JsonSchema;
+
+ /// The output returned by executing the tool.
+ type Output: 'static;
+
+ type View: Render + ToolOutput;
+
+ /// Returns the name of the tool.
+ ///
+ /// This name is exposed to the language model to allow the model to pick
+ /// which tools to use. As this name is used to identify the tool within a
+ /// tool registry, it should be unique.
+ fn name(&self) -> String;
+
+ /// Returns the description of the tool.
+ ///
+ /// This can be used to _prompt_ the model as to what the tool does.
+ fn description(&self) -> String;
+
+ /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
+ fn definition(&self) -> ToolFunctionDefinition {
+ let root_schema = schema_for!(Self::Input);
+
+ ToolFunctionDefinition {
+ name: self.name(),
+ description: self.description(),
+ parameters: root_schema,
}
}
+
+ /// Executes the tool with the given input.
+ fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
+
+ fn output_view(
+ input: Self::Input,
+ output: Result<Self::Output>,
+ cx: &mut WindowContext,
+ ) -> View<Self::View>;
+
+ fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
+ div()
+ }
}
-pub struct ToolRegistry {
- tools: HashMap<String, Tool>,
+pub trait ToolOutput: Sized {
+ fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
+}
+
+struct RegisteredTool {
+ enabled: AtomicBool,
+ type_id: TypeId,
+ call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
+ render_running: fn(&mut WindowContext) -> gpui::AnyElement,
+ definition: ToolFunctionDefinition,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
- tools: HashMap::new(),
+ registered_tools: HashMap::new(),
}
}
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
- for tool in self.tools.values() {
+ for tool in self.registered_tools.values() {
if tool.type_id == TypeId::of::<T>() {
tool.enabled.store(is_enabled, SeqCst);
return;
@@ -57,7 +118,7 @@ impl ToolRegistry {
}
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
- for tool in self.tools.values() {
+ for tool in self.registered_tools.values() {
if tool.type_id == TypeId::of::<T>() {
return tool.enabled.load(SeqCst);
}
@@ -66,7 +127,7 @@ impl ToolRegistry {
}
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
- self.tools
+ self.registered_tools
.values()
.filter(|tool| tool.enabled.load(SeqCst))
.map(|tool| tool.definition.clone())
@@ -84,7 +145,7 @@ impl ToolRegistry {
.child(result.into_any_element(&tool_call.name))
.into_any_element(),
None => self
- .tools
+ .registered_tools
.get(&tool_call.name)
.map(|tool| (tool.render_running)(cx))
.unwrap_or_else(|| div().into_any_element()),
@@ -96,13 +157,12 @@ impl ToolRegistry {
tool: T,
_cx: &mut WindowContext,
) -> Result<()> {
- let definition = tool.definition();
-
let name = tool.name();
-
- let registered_tool = Tool::new(
- TypeId::of::<T>(),
- Box::new(
+ let registered_tool = RegisteredTool {
+ type_id: TypeId::of::<T>(),
+ definition: tool.definition(),
+ enabled: AtomicBool::new(true),
+ call: Box::new(
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
@@ -121,8 +181,7 @@ impl ToolRegistry {
cx.spawn(move |mut cx| async move {
let result: Result<T::Output> = result.await;
- let for_model = T::format(&input, &result);
- let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
+ let view = cx.update(|cx| T::output_view(input, result, cx))?;
Ok(ToolFunctionCall {
id,
@@ -130,23 +189,35 @@ impl ToolRegistry {
arguments,
result: Some(ToolFunctionCallResult::Finished {
view: view.into(),
- for_model,
+ generate_fn: generate::<T>,
}),
})
})
},
),
- Box::new(|cx| T::render_running(cx).into_any_element()),
- definition,
- );
-
- let previous = self.tools.insert(name.clone(), registered_tool);
+ render_running: render_running::<T>,
+ };
+ let previous = self.registered_tools.insert(name.clone(), registered_tool);
if previous.is_some() {
return Err(anyhow!("already registered a tool with name {}", name));
}
- Ok(())
+ return Ok(());
+
+ fn render_running<T: LanguageModelTool>(cx: &mut WindowContext) -> AnyElement {
+ T::render_running(cx).into_any_element()
+ }
+
+ fn generate<T: LanguageModelTool>(
+ view: AnyView,
+ project: &mut ProjectContext,
+ cx: &mut WindowContext,
+ ) -> String {
+ view.downcast::<T::View>()
+ .unwrap()
+ .update(cx, |view, cx| T::View::generate(view, project, cx))
+ }
}
/// Task yields an error if the window for the given WindowContext is closed before the task completes.
@@ -159,7 +230,7 @@ impl ToolRegistry {
let arguments = tool_call.arguments.clone();
let id = tool_call.id.clone();
- let tool = match self.tools.get(&name) {
+ let tool = match self.registered_tools.get(&name) {
Some(tool) => tool,
None => {
let name = name.clone();
@@ -176,6 +247,47 @@ impl ToolRegistry {
}
}
+impl ToolFunctionCallResult {
+ pub fn generate(
+ &self,
+ name: &String,
+ project: &mut ProjectContext,
+ cx: &mut WindowContext,
+ ) -> String {
+ match self {
+ ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
+ ToolFunctionCallResult::ParsingFailed => {
+ format!("Unable to parse arguments for {name}")
+ }
+ ToolFunctionCallResult::Finished { generate_fn, view } => {
+ (generate_fn)(view.clone(), project, cx)
+ }
+ }
+ }
+
+ fn into_any_element(&self, name: &String) -> AnyElement {
+ match self {
+ ToolFunctionCallResult::NoSuchTool => {
+ format!("Language Model attempted to call {name}").into_any_element()
+ }
+ ToolFunctionCallResult::ParsingFailed => {
+ format!("Language Model called {name} with bad arguments").into_any_element()
+ }
+ ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
+ }
+ }
+}
+
+impl Display for ToolFunctionDefinition {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let schema = serde_json::to_string(&self.parameters).ok();
+ let schema = schema.unwrap_or("None".to_string());
+ write!(f, "Name: {}:\n", self.name)?;
+ write!(f, "Description: {}\n", self.description)?;
+ write!(f, "Parameters: {}", schema)
+ }
+}
+
#[cfg(test)]
mod test {
use super::*;
@@ -213,6 +325,12 @@ mod test {
}
}
+ impl ToolOutput for WeatherView {
+ fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
+ serde_json::to_string(&self.result).unwrap()
+ }
+ }
+
impl LanguageModelTool for WeatherTool {
type Input = WeatherQuery;
type Output = WeatherResult;
@@ -240,7 +358,6 @@ mod test {
}
fn output_view(
- _tool_call_id: String,
_input: Self::Input,
result: Result<Self::Output>,
cx: &mut WindowContext,
@@ -250,10 +367,6 @@ mod test {
WeatherView { result }
})
}
-
- fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
- serde_json::to_string(&output.as_ref().unwrap()).unwrap()
- }
}
#[gpui::test]
@@ -163,6 +163,10 @@ impl ProjectIndex {
self.project.clone()
}
+ pub fn fs(&self) -> Arc<dyn Fs> {
+ self.fs.clone()
+ }
+
fn handle_project_event(
&mut self,
_: Model<Project>,