Centralize project context provided to the assistant (#11471)

Max Brunsfeld and Kyle created

This PR restructures the way that tools and attachments add information
about the current project to a conversation with the assistant. Rather
than each tool call or attachment generating a new tool or system
message containing information about the project, they can all
collectively mutate a new type called a `ProjectContext`, which stores
all of the project data that should be sent to the assistant. That data
is then formatted in a single place, and passed to the assistant in one
system message.

This prevents multiple tools/attachments from including redundant
context.

Release Notes:

- N/A

---------

Co-authored-by: Kyle <kylek@zed.dev>

Change summary

Cargo.lock                                          |   7 
crates/assistant2/src/assistant2.rs                 |  99 ++--
crates/assistant2/src/attachments.rs                | 206 ++--------
crates/assistant2/src/attachments/active_file.rs    |   1 
crates/assistant2/src/tools/create_buffer.rs        |  39 +
crates/assistant2/src/tools/project_index.rs        | 194 +++------
crates/assistant2/src/ui/active_file_button.rs      |  13 
crates/assistant2/src/ui/composer.rs                |  10 
crates/assistant_tooling/Cargo.toml                 |   8 
crates/assistant_tooling/src/assistant_tooling.rs   |  12 
crates/assistant_tooling/src/attachment_registry.rs | 148 +++++++
crates/assistant_tooling/src/project_context.rs     | 296 +++++++++++++++
crates/assistant_tooling/src/tool.rs                | 111 -----
crates/assistant_tooling/src/tool_registry.rs       | 217 ++++++++--
crates/semantic_index/src/semantic_index.rs         |   4 
15 files changed, 844 insertions(+), 521 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -411,10 +411,17 @@ name = "assistant_tooling"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "collections",
+ "futures 0.3.28",
  "gpui",
+ "project",
  "schemars",
  "serde",
  "serde_json",
+ "settings",
+ "sum_tree",
+ "unindent",
+ "util",
 ]
 
 [[package]]

crates/assistant2/src/assistant2.rs 🔗

@@ -4,10 +4,16 @@ mod completion_provider;
 mod tools;
 pub mod ui;
 
+use crate::{
+    attachments::ActiveEditorAttachmentTool,
+    tools::{CreateBufferTool, ProjectIndexTool},
+    ui::UserOrAssistant,
+};
 use ::ui::{div, prelude::*, Color, ViewContext};
 use anyhow::{Context, Result};
-use assistant_tooling::{ToolFunctionCall, ToolRegistry};
-use attachments::{ActiveEditorAttachmentTool, UserAttachment, UserAttachmentStore};
+use assistant_tooling::{
+    AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
+};
 use client::{proto, Client, UserStore};
 use collections::HashMap;
 use completion_provider::*;
@@ -34,9 +40,6 @@ use workspace::{
 
 pub use assistant_settings::AssistantSettings;
 
-use crate::tools::{CreateBufferTool, ProjectIndexTool};
-use crate::ui::UserOrAssistant;
-
 const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
 
 #[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
@@ -85,10 +88,9 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
             });
             workspace.register_action(|workspace, _: &DebugProjectIndex, cx| {
                 if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
-                    if let Some(index) = panel.read(cx).chat.read(cx).project_index.clone() {
-                        let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx));
-                        workspace.add_item_to_center(Box::new(view), cx);
-                    }
+                    let index = panel.read(cx).chat.read(cx).project_index.clone();
+                    let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx));
+                    workspace.add_item_to_center(Box::new(view), cx);
                 }
             });
         },
@@ -122,10 +124,7 @@ impl AssistantPanel {
 
                 let mut tool_registry = ToolRegistry::new();
                 tool_registry
-                    .register(
-                        ProjectIndexTool::new(project_index.clone(), project.read(cx).fs().clone()),
-                        cx,
-                    )
+                    .register(ProjectIndexTool::new(project_index.clone()), cx)
                     .context("failed to register ProjectIndexTool")
                     .log_err();
                 tool_registry
@@ -136,7 +135,7 @@ impl AssistantPanel {
                     .context("failed to register CreateBufferTool")
                     .log_err();
 
-                let mut attachment_store = UserAttachmentStore::new();
+                let mut attachment_store = AttachmentRegistry::new();
                 attachment_store.register(ActiveEditorAttachmentTool::new(workspace.clone(), cx));
 
                 Self::new(
@@ -144,7 +143,7 @@ impl AssistantPanel {
                     Arc::new(tool_registry),
                     Arc::new(attachment_store),
                     app_state.user_store.clone(),
-                    Some(project_index),
+                    project_index,
                     workspace,
                     cx,
                 )
@@ -155,9 +154,9 @@ impl AssistantPanel {
     pub fn new(
         language_registry: Arc<LanguageRegistry>,
         tool_registry: Arc<ToolRegistry>,
-        attachment_store: Arc<UserAttachmentStore>,
+        attachment_store: Arc<AttachmentRegistry>,
         user_store: Model<UserStore>,
-        project_index: Option<Model<ProjectIndex>>,
+        project_index: Model<ProjectIndex>,
         workspace: WeakView<Workspace>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
@@ -241,16 +240,16 @@ pub struct AssistantChat {
     list_state: ListState,
     language_registry: Arc<LanguageRegistry>,
     composer_editor: View<Editor>,
-    project_index_button: Option<View<ProjectIndexButton>>,
+    project_index_button: View<ProjectIndexButton>,
     active_file_button: Option<View<ActiveFileButton>>,
     user_store: Model<UserStore>,
     next_message_id: MessageId,
     collapsed_messages: HashMap<MessageId, bool>,
     editing_message: Option<EditingMessage>,
     pending_completion: Option<Task<()>>,
-    attachment_store: Arc<UserAttachmentStore>,
     tool_registry: Arc<ToolRegistry>,
-    project_index: Option<Model<ProjectIndex>>,
+    attachment_registry: Arc<AttachmentRegistry>,
+    project_index: Model<ProjectIndex>,
 }
 
 struct EditingMessage {
@@ -263,9 +262,9 @@ impl AssistantChat {
     fn new(
         language_registry: Arc<LanguageRegistry>,
         tool_registry: Arc<ToolRegistry>,
-        attachment_store: Arc<UserAttachmentStore>,
+        attachment_registry: Arc<AttachmentRegistry>,
         user_store: Model<UserStore>,
-        project_index: Option<Model<ProjectIndex>>,
+        project_index: Model<ProjectIndex>,
         workspace: WeakView<Workspace>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
@@ -281,14 +280,14 @@ impl AssistantChat {
             },
         );
 
-        let project_index_button = project_index.clone().map(|project_index| {
-            cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx))
+        let project_index_button = cx.new_view(|cx| {
+            ProjectIndexButton::new(project_index.clone(), tool_registry.clone(), cx)
         });
 
         let active_file_button = match workspace.upgrade() {
             Some(workspace) => {
                 Some(cx.new_view(
-                    |cx| ActiveFileButton::new(attachment_store.clone(), workspace, cx), //
+                    |cx| ActiveFileButton::new(attachment_registry.clone(), workspace, cx), //
                 ))
             }
             _ => None,
@@ -313,7 +312,7 @@ impl AssistantChat {
             editing_message: None,
             collapsed_messages: HashMap::default(),
             pending_completion: None,
-            attachment_store,
+            attachment_registry,
             tool_registry,
         }
     }
@@ -395,7 +394,7 @@ impl AssistantChat {
         let mode = *mode;
         self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
             let attachments_task = this.update(&mut cx, |this, cx| {
-                let attachment_store = this.attachment_store.clone();
+                let attachment_store = this.attachment_registry.clone();
                 attachment_store.call_all_attachment_tools(cx)
             });
 
@@ -443,7 +442,7 @@ impl AssistantChat {
         let mut call_count = 0;
         loop {
             let complete = async {
-                let completion = this.update(cx, |this, cx| {
+                let (tool_definitions, model_name, messages) = this.update(cx, |this, cx| {
                     this.push_new_assistant_message(cx);
 
                     let definitions = if call_count < limit
@@ -455,14 +454,22 @@ impl AssistantChat {
                     };
                     call_count += 1;
 
-                    let messages = this.completion_messages(cx);
+                    (
+                        definitions,
+                        this.model.clone(),
+                        this.completion_messages(cx),
+                    )
+                })?;
 
+                let messages = messages.await?;
+
+                let completion = cx.update(|cx| {
                     CompletionProvider::get(cx).complete(
-                        this.model.clone(),
+                        model_name,
                         messages,
                         Vec::new(),
                         1.0,
-                        definitions,
+                        tool_definitions,
                     )
                 });
 
@@ -765,7 +772,12 @@ impl AssistantChat {
         }
     }
 
-    fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> {
+    fn completion_messages(&self, cx: &mut WindowContext) -> Task<Result<Vec<CompletionMessage>>> {
+        let project_index = self.project_index.read(cx);
+        let project = project_index.project();
+        let fs = project_index.fs();
+
+        let mut project_context = ProjectContext::new(project, fs);
         let mut completion_messages = Vec::new();
 
         for message in &self.messages {
@@ -773,12 +785,11 @@ impl AssistantChat {
                 ChatMessage::User(UserMessage {
                     body, attachments, ..
                 }) => {
-                    completion_messages.extend(
-                        attachments
-                            .into_iter()
-                            .filter_map(|attachment| attachment.message.clone())
-                            .map(|content| CompletionMessage::System { content }),
-                    );
+                    for attachment in attachments {
+                        if let Some(content) = attachment.generate(&mut project_context, cx) {
+                            completion_messages.push(CompletionMessage::System { content });
+                        }
+                    }
 
                     // Show user's message last so that the assistant is grounded in the user's request
                     completion_messages.push(CompletionMessage::User {
@@ -815,7 +826,9 @@ impl AssistantChat {
                     for tool_call in tool_calls {
                         // Every tool call _must_ have a result by ID, otherwise OpenAI will error.
                         let content = match &tool_call.result {
-                            Some(result) => result.format(&tool_call.name),
+                            Some(result) => {
+                                result.generate(&tool_call.name, &mut project_context, cx)
+                            }
                             None => "".to_string(),
                         };
 
@@ -828,7 +841,13 @@ impl AssistantChat {
             }
         }
 
-        completion_messages
+        let system_message = project_context.generate_system_message(cx);
+
+        cx.background_executor().spawn(async move {
+            let content = system_message.await?;
+            completion_messages.insert(0, CompletionMessage::System { content });
+            Ok(completion_messages)
+        })
     }
 }
 

crates/assistant2/src/attachments.rs 🔗

@@ -1,137 +1,18 @@
-use std::{
-    any::TypeId,
-    sync::{
-        atomic::{AtomicBool, Ordering::SeqCst},
-        Arc,
-    },
-};
+pub mod active_file;
 
 use anyhow::{anyhow, Result};
-use collections::HashMap;
+use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput};
 use editor::Editor;
-use futures::future::join_all;
-use gpui::{AnyView, Render, Task, View, WeakView};
+use gpui::{Render, Task, View, WeakModel, WeakView};
+use language::Buffer;
+use project::ProjectPath;
 use ui::{prelude::*, ButtonLike, Tooltip, WindowContext};
-use util::{maybe, ResultExt};
+use util::maybe;
 use workspace::Workspace;
 
-/// A collected attachment from running an attachment tool
-pub struct UserAttachment {
-    pub message: Option<String>,
-    pub view: AnyView,
-}
-
-pub struct UserAttachmentStore {
-    attachment_tools: HashMap<TypeId, DynamicAttachment>,
-}
-
-/// Internal representation of an attachment tool to allow us to treat them dynamically
-struct DynamicAttachment {
-    enabled: AtomicBool,
-    call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
-}
-
-impl UserAttachmentStore {
-    pub fn new() -> Self {
-        Self {
-            attachment_tools: HashMap::default(),
-        }
-    }
-
-    pub fn register<A: AttachmentTool + 'static>(&mut self, attachment: A) {
-        let call = Box::new(move |cx: &mut WindowContext| {
-            let result = attachment.run(cx);
-
-            cx.spawn(move |mut cx| async move {
-                let result: Result<A::Output> = result.await;
-                let message = A::format(&result);
-                let view = cx.update(|cx| A::view(result, cx))?;
-
-                Ok(UserAttachment {
-                    message,
-                    view: view.into(),
-                })
-            })
-        });
-
-        self.attachment_tools.insert(
-            TypeId::of::<A>(),
-            DynamicAttachment {
-                call,
-                enabled: AtomicBool::new(true),
-            },
-        );
-    }
-
-    pub fn set_attachment_tool_enabled<A: AttachmentTool + 'static>(&self, is_enabled: bool) {
-        if let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) {
-            attachment.enabled.store(is_enabled, SeqCst);
-        }
-    }
-
-    pub fn is_attachment_tool_enabled<A: AttachmentTool + 'static>(&self) -> bool {
-        if let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) {
-            attachment.enabled.load(SeqCst)
-        } else {
-            false
-        }
-    }
-
-    pub fn call<A: AttachmentTool + 'static>(
-        &self,
-        cx: &mut WindowContext,
-    ) -> Task<Result<UserAttachment>> {
-        let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) else {
-            return Task::ready(Err(anyhow!("no attachment tool")));
-        };
-
-        (attachment.call)(cx)
-    }
-
-    pub fn call_all_attachment_tools(
-        self: Arc<Self>,
-        cx: &mut WindowContext<'_>,
-    ) -> Task<Result<Vec<UserAttachment>>> {
-        let this = self.clone();
-        cx.spawn(|mut cx| async move {
-            let attachment_tasks = cx.update(|cx| {
-                let mut tasks = Vec::new();
-                for attachment in this
-                    .attachment_tools
-                    .values()
-                    .filter(|attachment| attachment.enabled.load(SeqCst))
-                {
-                    tasks.push((attachment.call)(cx))
-                }
-
-                tasks
-            })?;
-
-            let attachments = join_all(attachment_tasks.into_iter()).await;
-
-            Ok(attachments
-                .into_iter()
-                .filter_map(|attachment| attachment.log_err())
-                .collect())
-        })
-    }
-}
-
-pub trait AttachmentTool {
-    type Output: 'static;
-    type View: Render;
-
-    fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
-
-    fn format(output: &Result<Self::Output>) -> Option<String>;
-
-    fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
-}
-
 pub struct ActiveEditorAttachment {
-    filename: Arc<str>,
-    language: Arc<str>,
-    text: Arc<str>,
+    buffer: WeakModel<Buffer>,
+    path: Option<ProjectPath>,
 }
 
 pub struct FileAttachmentView {
@@ -142,7 +23,13 @@ impl Render for FileAttachmentView {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
         match &self.output {
             Ok(attachment) => {
-                let filename = attachment.filename.clone();
+                let filename: SharedString = attachment
+                    .path
+                    .as_ref()
+                    .and_then(|p| p.path.file_name()?.to_str())
+                    .unwrap_or("Untitled")
+                    .to_string()
+                    .into();
 
                 // todo!(): make the button link to the actual file to open
                 ButtonLike::new("file-attachment")
@@ -152,7 +39,7 @@ impl Render for FileAttachmentView {
                             .bg(cx.theme().colors().editor_background)
                             .rounded_md()
                             .child(ui::Icon::new(IconName::File))
-                            .child(filename.to_string()),
+                            .child(filename.clone()),
                     )
                     .tooltip({
                         move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx)
@@ -164,6 +51,20 @@ impl Render for FileAttachmentView {
     }
 }
 
+impl ToolOutput for FileAttachmentView {
+    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
+        if let Ok(result) = &self.output {
+            if let Some(path) = &result.path {
+                project.add_file(path.clone());
+                return format!("current file: {}", path.path.display());
+            } else if let Some(buffer) = result.buffer.upgrade() {
+                return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
+            }
+        }
+        String::new()
+    }
+}
+
 pub struct ActiveEditorAttachmentTool {
     workspace: WeakView<Workspace>,
 }
@@ -174,7 +75,7 @@ impl ActiveEditorAttachmentTool {
     }
 }
 
-impl AttachmentTool for ActiveEditorAttachmentTool {
+impl LanguageModelAttachment for ActiveEditorAttachmentTool {
     type Output = ActiveEditorAttachment;
     type View = FileAttachmentView;
 
@@ -191,47 +92,22 @@ impl AttachmentTool for ActiveEditorAttachmentTool {
 
             let buffer = active_buffer.read(cx);
 
-            if let Some(singleton) = buffer.as_singleton() {
-                let singleton = singleton.read(cx);
-
-                let filename = singleton
-                    .file()
-                    .map(|file| file.path().to_string_lossy())
-                    .unwrap_or("Untitled".into());
-
-                let text = singleton.text();
-
-                let language = singleton
-                    .language()
-                    .map(|l| {
-                        let name = l.code_fence_block_name();
-                        name.to_string()
-                    })
-                    .unwrap_or_default();
-
+            if let Some(buffer) = buffer.as_singleton() {
+                let path =
+                    project::File::from_dyn(buffer.read(cx).file()).map(|file| ProjectPath {
+                        worktree_id: file.worktree_id(cx),
+                        path: file.path.clone(),
+                    });
                 return Ok(ActiveEditorAttachment {
-                    filename: filename.into(),
-                    language: language.into(),
-                    text: text.into(),
+                    buffer: buffer.downgrade(),
+                    path,
                 });
+            } else {
+                Err(anyhow!("no active buffer"))
             }
-
-            Err(anyhow!("no active buffer"))
         }))
     }
 
-    fn format(output: &Result<Self::Output>) -> Option<String> {
-        let output = output.as_ref().ok()?;
-
-        let filename = &output.filename;
-        let language = &output.language;
-        let text = &output.text;
-
-        Some(format!(
-            "User's active file `{filename}`:\n\n```{language}\n{text}```\n\n"
-        ))
-    }
-
     fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View> {
         cx.new_view(|_cx| FileAttachmentView { output })
     }

crates/assistant2/src/tools/create_buffer.rs 🔗

@@ -1,5 +1,5 @@
 use anyhow::Result;
-use assistant_tooling::LanguageModelTool;
+use assistant_tooling::{LanguageModelTool, ProjectContext, ToolOutput};
 use editor::Editor;
 use gpui::{prelude::*, Model, Task, View, WeakView};
 use project::Project;
@@ -31,11 +31,9 @@ pub struct CreateBufferInput {
     language: String,
 }
 
-pub struct CreateBufferOutput {}
-
 impl LanguageModelTool for CreateBufferTool {
     type Input = CreateBufferInput;
-    type Output = CreateBufferOutput;
+    type Output = ();
     type View = CreateBufferView;
 
     fn name(&self) -> String {
@@ -83,32 +81,39 @@ impl LanguageModelTool for CreateBufferTool {
                     })
                     .log_err();
 
-                Ok(CreateBufferOutput {})
+                Ok(())
             }
         })
     }
 
-    fn format(input: &Self::Input, output: &Result<Self::Output>) -> String {
-        match output {
-            Ok(_) => format!("Created a new {} buffer", input.language),
-            Err(err) => format!("Failed to create buffer: {err:?}"),
-        }
-    }
-
     fn output_view(
-        _tool_call_id: String,
-        _input: Self::Input,
-        _output: Result<Self::Output>,
+        input: Self::Input,
+        output: Result<Self::Output>,
         cx: &mut WindowContext,
     ) -> View<Self::View> {
-        cx.new_view(|_cx| CreateBufferView {})
+        cx.new_view(|_cx| CreateBufferView {
+            language: input.language,
+            output,
+        })
     }
 }
 
-pub struct CreateBufferView {}
+pub struct CreateBufferView {
+    language: String,
+    output: Result<()>,
+}
 
 impl Render for CreateBufferView {
     fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
         div().child("Opening a buffer")
     }
 }
+
+impl ToolOutput for CreateBufferView {
+    fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String {
+        match &self.output {
+            Ok(_) => format!("Created a new {} buffer", self.language),
+            Err(err) => format!("Failed to create buffer: {err:?}"),
+        }
+    }
+}

crates/assistant2/src/tools/project_index.rs 🔗

@@ -1,25 +1,18 @@
 use anyhow::Result;
-use assistant_tooling::LanguageModelTool;
+use assistant_tooling::{LanguageModelTool, ToolOutput};
+use collections::BTreeMap;
 use gpui::{prelude::*, Model, Task};
-use project::Fs;
+use project::ProjectPath;
 use schemars::JsonSchema;
 use semantic_index::{ProjectIndex, Status};
 use serde::Deserialize;
-use std::{collections::HashSet, sync::Arc};
-
-use ui::{
-    div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
-    WindowContext,
-};
-use util::ResultExt as _;
+use std::{fmt::Write as _, ops::Range};
+use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
 
 const DEFAULT_SEARCH_LIMIT: usize = 20;
 
-#[derive(Clone)]
-pub struct CodebaseExcerpt {
-    path: SharedString,
-    text: SharedString,
-    score: f32,
+pub struct ProjectIndexTool {
+    project_index: Model<ProjectIndex>,
 }
 
 // Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
@@ -40,6 +33,11 @@ pub struct ProjectIndexView {
     expanded_header: bool,
 }
 
+pub struct ProjectIndexOutput {
+    status: Status,
+    excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
+}
+
 impl ProjectIndexView {
     fn new(input: CodebaseQuery, output: Result<ProjectIndexOutput>) -> Self {
         let element_id = ElementId::Name(nanoid::nanoid!().into());
@@ -71,19 +69,15 @@ impl Render for ProjectIndexView {
             Ok(output) => output,
         };
 
-        let num_files_searched = output.files_searched.len();
+        let file_count = output.excerpts.len();
 
         let header = h_flex()
             .gap_2()
             .child(Icon::new(IconName::File))
             .child(format!(
                 "Read {} {}",
-                num_files_searched,
-                if num_files_searched == 1 {
-                    "file"
-                } else {
-                    "files"
-                }
+                file_count,
+                if file_count == 1 { "file" } else { "files" }
             ));
 
         v_flex().gap_3().child(
@@ -102,36 +96,50 @@ impl Render for ProjectIndexView {
                                 .child(Icon::new(IconName::MagnifyingGlass))
                                 .child(Label::new(format!("`{}`", query)).color(Color::Muted)),
                         )
-                        .child(v_flex().gap_2().children(output.files_searched.iter().map(
-                            |path| {
-                                h_flex()
-                                    .gap_2()
-                                    .child(Icon::new(IconName::File))
-                                    .child(Label::new(path.clone()).color(Color::Muted))
-                            },
-                        ))),
+                        .child(
+                            v_flex()
+                                .gap_2()
+                                .children(output.excerpts.keys().map(|path| {
+                                    h_flex().gap_2().child(Icon::new(IconName::File)).child(
+                                        Label::new(path.path.to_string_lossy().to_string())
+                                            .color(Color::Muted),
+                                    )
+                                })),
+                        ),
                 ),
         )
     }
 }
 
-pub struct ProjectIndexTool {
-    project_index: Model<ProjectIndex>,
-    fs: Arc<dyn Fs>,
-}
+impl ToolOutput for ProjectIndexView {
+    fn generate(
+        &self,
+        context: &mut assistant_tooling::ProjectContext,
+        _: &mut WindowContext,
+    ) -> String {
+        match &self.output {
+            Ok(output) => {
+                let mut body = "found results in the following paths:\n".to_string();
 
-pub struct ProjectIndexOutput {
-    excerpts: Vec<CodebaseExcerpt>,
-    status: Status,
-    files_searched: HashSet<SharedString>,
+                for (project_path, ranges) in &output.excerpts {
+                    context.add_excerpts(project_path.clone(), ranges);
+                    writeln!(&mut body, "* {}", &project_path.path.display()).unwrap();
+                }
+
+                if output.status != Status::Idle {
+                    body.push_str("Still indexing. Results may be incomplete.\n");
+                }
+
+                body
+            }
+            Err(err) => format!("Error: {}", err),
+        }
+    }
 }
 
 impl ProjectIndexTool {
-    pub fn new(project_index: Model<ProjectIndex>, fs: Arc<dyn Fs>) -> Self {
-        // Listen for project index status and update the ProjectIndexTool directly
-
-        // TODO: setup a better description based on the user's current codebase.
-        Self { project_index, fs }
+    pub fn new(project_index: Model<ProjectIndex>) -> Self {
+        Self { project_index }
     }
 }
 
@@ -151,64 +159,42 @@ impl LanguageModelTool for ProjectIndexTool {
     fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
         let project_index = self.project_index.read(cx);
         let status = project_index.status();
-        let results = project_index.search(
+        let search = project_index.search(
             query.query.clone(),
             query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
             cx,
         );
 
-        let fs = self.fs.clone();
-
-        cx.spawn(|cx| async move {
-            let results = results.await?;
-
-            let excerpts = results.into_iter().map(|result| {
-                let abs_path = result
-                    .worktree
-                    .read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path));
-                let fs = fs.clone();
-
-                async move {
-                    let path = result.path.clone();
-                    let text = fs.load(&abs_path?).await?;
-
-                    let mut start = result.range.start;
-                    let mut end = result.range.end.min(text.len());
-                    while !text.is_char_boundary(start) {
-                        start += 1;
-                    }
-                    while !text.is_char_boundary(end) {
-                        end -= 1;
-                    }
-
-                    anyhow::Ok(CodebaseExcerpt {
-                        path: path.to_string_lossy().to_string().into(),
-                        text: SharedString::from(text[start..end].to_string()),
-                        score: result.score,
-                    })
+        cx.spawn(|mut cx| async move {
+            let search_results = search.await?;
+
+            cx.update(|cx| {
+                let mut output = ProjectIndexOutput {
+                    status,
+                    excerpts: Default::default(),
+                };
+
+                for search_result in search_results {
+                    let path = ProjectPath {
+                        worktree_id: search_result.worktree.read(cx).id(),
+                        path: search_result.path.clone(),
+                    };
+
+                    let excerpts_for_path = output.excerpts.entry(path).or_default();
+                    let ix = match excerpts_for_path
+                        .binary_search_by_key(&search_result.range.start, |r| r.start)
+                    {
+                        Ok(ix) | Err(ix) => ix,
+                    };
+                    excerpts_for_path.insert(ix, search_result.range);
                 }
-            });
-
-            let mut files_searched = HashSet::new();
-            let excerpts = futures::future::join_all(excerpts)
-                .await
-                .into_iter()
-                .filter_map(|result| result.log_err())
-                .inspect(|excerpt| {
-                    files_searched.insert(excerpt.path.clone());
-                })
-                .collect::<Vec<_>>();
-
-            anyhow::Ok(ProjectIndexOutput {
-                excerpts,
-                status,
-                files_searched,
+
+                output
             })
         })
     }
 
     fn output_view(
-        _tool_call_id: String,
         input: Self::Input,
         output: Result<Self::Output>,
         cx: &mut WindowContext,
@@ -220,34 +206,4 @@ impl LanguageModelTool for ProjectIndexTool {
         CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false)
             .start_slot("Searching code base")
     }
-
-    fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
-        match &output {
-            Ok(output) => {
-                let mut body = "Semantic search results:\n".to_string();
-
-                if output.status != Status::Idle {
-                    body.push_str("Still indexing. Results may be incomplete.\n");
-                }
-
-                if output.excerpts.is_empty() {
-                    body.push_str("No results found");
-                    return body;
-                }
-
-                for excerpt in &output.excerpts {
-                    body.push_str("Excerpt from ");
-                    body.push_str(excerpt.path.as_ref());
-                    body.push_str(", score ");
-                    body.push_str(&excerpt.score.to_string());
-                    body.push_str(":\n");
-                    body.push_str("~~~\n");
-                    body.push_str(excerpt.text.as_ref());
-                    body.push_str("~~~\n");
-                }
-                body
-            }
-            Err(err) => format!("Error: {}", err),
-        }
-    }
 }

crates/assistant2/src/ui/active_file_button.rs 🔗

@@ -1,4 +1,5 @@
-use crate::attachments::{ActiveEditorAttachmentTool, UserAttachmentStore};
+use crate::attachments::ActiveEditorAttachmentTool;
+use assistant_tooling::AttachmentRegistry;
 use editor::Editor;
 use gpui::{prelude::*, Subscription, View};
 use std::sync::Arc;
@@ -13,7 +14,7 @@ enum Status {
 }
 
 pub struct ActiveFileButton {
-    attachment_store: Arc<UserAttachmentStore>,
+    attachment_registry: Arc<AttachmentRegistry>,
     status: Status,
     #[allow(dead_code)]
     workspace_subscription: Subscription,
@@ -21,7 +22,7 @@ pub struct ActiveFileButton {
 
 impl ActiveFileButton {
     pub fn new(
-        attachment_store: Arc<UserAttachmentStore>,
+        attachment_store: Arc<AttachmentRegistry>,
         workspace: View<Workspace>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
@@ -30,14 +31,14 @@ impl ActiveFileButton {
         cx.defer(move |this, cx| this.update_active_buffer(workspace.clone(), cx));
 
         Self {
-            attachment_store,
+            attachment_registry: attachment_store,
             status: Status::NoFile,
             workspace_subscription,
         }
     }
 
     pub fn set_enabled(&mut self, enabled: bool) {
-        self.attachment_store
+        self.attachment_registry
             .set_attachment_tool_enabled::<ActiveEditorAttachmentTool>(enabled);
     }
 
@@ -79,7 +80,7 @@ impl ActiveFileButton {
 impl Render for ActiveFileButton {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
         let is_enabled = self
-            .attachment_store
+            .attachment_registry
             .is_attachment_tool_enabled::<ActiveEditorAttachmentTool>();
 
         let icon = if is_enabled {

crates/assistant2/src/ui/composer.rs 🔗

@@ -11,7 +11,7 @@ use ui::{popover_menu, prelude::*, ButtonLike, ContextMenu, Divider, TextSize, T
 #[derive(IntoElement)]
 pub struct Composer {
     editor: View<Editor>,
-    project_index_button: Option<View<ProjectIndexButton>>,
+    project_index_button: View<ProjectIndexButton>,
     active_file_button: Option<View<ActiveFileButton>>,
     model_selector: AnyElement,
 }
@@ -19,7 +19,7 @@ pub struct Composer {
 impl Composer {
     pub fn new(
         editor: View<Editor>,
-        project_index_button: Option<View<ProjectIndexButton>>,
+        project_index_button: View<ProjectIndexButton>,
         active_file_button: Option<View<ActiveFileButton>>,
         model_selector: AnyElement,
     ) -> Self {
@@ -32,11 +32,7 @@ impl Composer {
     }
 
     fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
-        h_flex().children(
-            self.project_index_button
-                .clone()
-                .map(|view| view.into_any_element()),
-        )
+        h_flex().child(self.project_index_button.clone())
     }
 
     fn render_attachment_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {

crates/assistant_tooling/Cargo.toml 🔗

@@ -13,10 +13,18 @@ path = "src/assistant_tooling.rs"
 
 [dependencies]
 anyhow.workspace = true
+collections.workspace = true
+futures.workspace = true
 gpui.workspace = true
+project.workspace = true
 schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+sum_tree.workspace = true
+util.workspace = true
 
 [dev-dependencies]
 gpui = { workspace = true, features = ["test-support"] }
+project = { workspace = true, features = ["test-support"] }
+settings = { workspace = true, features = ["test-support"] }
+unindent.workspace = true

crates/assistant_tooling/src/assistant_tooling.rs 🔗

@@ -1,5 +1,9 @@
-pub mod registry;
-pub mod tool;
+mod attachment_registry;
+mod project_context;
+mod tool_registry;
 
-pub use crate::registry::ToolRegistry;
-pub use crate::tool::{LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition};
+pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment};
+pub use project_context::ProjectContext;
+pub use tool_registry::{
+    LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition, ToolOutput, ToolRegistry,
+};

crates/assistant_tooling/src/attachment_registry.rs 🔗

@@ -0,0 +1,148 @@
+use crate::{ProjectContext, ToolOutput};
+use anyhow::{anyhow, Result};
+use collections::HashMap;
+use futures::future::join_all;
+use gpui::{AnyView, Render, Task, View, WindowContext};
+use std::{
+    any::TypeId,
+    sync::{
+        atomic::{AtomicBool, Ordering::SeqCst},
+        Arc,
+    },
+};
+use util::ResultExt as _;
+
+pub struct AttachmentRegistry {
+    registered_attachments: HashMap<TypeId, RegisteredAttachment>,
+}
+
+pub trait LanguageModelAttachment {
+    type Output: 'static;
+    type View: Render + ToolOutput;
+
+    fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
+
+    fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
+}
+
+/// A collected attachment from running an attachment tool
+pub struct UserAttachment {
+    pub view: AnyView,
+    generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
+}
+
+/// Internal representation of an attachment tool to allow us to treat them dynamically
+struct RegisteredAttachment {
+    enabled: AtomicBool,
+    call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
+}
+
+impl AttachmentRegistry {
+    pub fn new() -> Self {
+        Self {
+            registered_attachments: HashMap::default(),
+        }
+    }
+
+    pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
+        let call = Box::new(move |cx: &mut WindowContext| {
+            let result = attachment.run(cx);
+
+            cx.spawn(move |mut cx| async move {
+                let result: Result<A::Output> = result.await;
+                let view = cx.update(|cx| A::view(result, cx))?;
+
+                Ok(UserAttachment {
+                    view: view.into(),
+                    generate_fn: generate::<A>,
+                })
+            })
+        });
+
+        self.registered_attachments.insert(
+            TypeId::of::<A>(),
+            RegisteredAttachment {
+                call,
+                enabled: AtomicBool::new(true),
+            },
+        );
+        return;
+
+        fn generate<T: LanguageModelAttachment>(
+            view: AnyView,
+            project: &mut ProjectContext,
+            cx: &mut WindowContext,
+        ) -> String {
+            view.downcast::<T::View>()
+                .unwrap()
+                .update(cx, |view, cx| T::View::generate(view, project, cx))
+        }
+    }
+
+    pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
+        &self,
+        is_enabled: bool,
+    ) {
+        if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
+            attachment.enabled.store(is_enabled, SeqCst);
+        }
+    }
+
+    pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
+        if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
+            attachment.enabled.load(SeqCst)
+        } else {
+            false
+        }
+    }
+
+    pub fn call<A: LanguageModelAttachment + 'static>(
+        &self,
+        cx: &mut WindowContext,
+    ) -> Task<Result<UserAttachment>> {
+        let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
+            return Task::ready(Err(anyhow!("no attachment tool")));
+        };
+
+        (attachment.call)(cx)
+    }
+
+    pub fn call_all_attachment_tools(
+        self: Arc<Self>,
+        cx: &mut WindowContext<'_>,
+    ) -> Task<Result<Vec<UserAttachment>>> {
+        let this = self.clone();
+        cx.spawn(|mut cx| async move {
+            let attachment_tasks = cx.update(|cx| {
+                let mut tasks = Vec::new();
+                for attachment in this
+                    .registered_attachments
+                    .values()
+                    .filter(|attachment| attachment.enabled.load(SeqCst))
+                {
+                    tasks.push((attachment.call)(cx))
+                }
+
+                tasks
+            })?;
+
+            let attachments = join_all(attachment_tasks.into_iter()).await;
+
+            Ok(attachments
+                .into_iter()
+                .filter_map(|attachment| attachment.log_err())
+                .collect())
+        })
+    }
+}
+
+impl UserAttachment {
+    pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
+        let result = (self.generate_fn)(self.view.clone(), output, cx);
+        if result.is_empty() {
+            None
+        } else {
+            Some(result)
+        }
+    }
+}

crates/assistant_tooling/src/project_context.rs 🔗

@@ -0,0 +1,296 @@
+use anyhow::{anyhow, Result};
+use gpui::{AppContext, Model, Task, WeakModel};
+use project::{Fs, Project, ProjectPath, Worktree};
+use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc};
+use sum_tree::TreeMap;
+
+pub struct ProjectContext {
+    files: TreeMap<ProjectPath, PathState>,
+    project: WeakModel<Project>,
+    fs: Arc<dyn Fs>,
+}
+
+#[derive(Debug, Clone)]
+enum PathState {
+    PathOnly,
+    EntireFile,
+    Excerpts { ranges: Vec<Range<usize>> },
+}
+
+impl ProjectContext {
+    pub fn new(project: WeakModel<Project>, fs: Arc<dyn Fs>) -> Self {
+        Self {
+            files: TreeMap::default(),
+            fs,
+            project,
+        }
+    }
+
+    pub fn add_path(&mut self, project_path: ProjectPath) {
+        if self.files.get(&project_path).is_none() {
+            self.files.insert(project_path, PathState::PathOnly);
+        }
+    }
+
+    pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range<usize>]) {
+        let previous_state = self
+            .files
+            .get(&project_path)
+            .unwrap_or(&PathState::PathOnly);
+
+        let mut ranges = match previous_state {
+            PathState::EntireFile => return,
+            PathState::PathOnly => Vec::new(),
+            PathState::Excerpts { ranges } => ranges.to_vec(),
+        };
+
+        for new_range in new_ranges {
+            let ix = ranges.binary_search_by(|probe| {
+                if probe.end < new_range.start {
+                    Ordering::Less
+                } else if probe.start > new_range.end {
+                    Ordering::Greater
+                } else {
+                    Ordering::Equal
+                }
+            });
+
+            match ix {
+                Ok(mut ix) => {
+                    let existing = &mut ranges[ix];
+                    existing.start = existing.start.min(new_range.start);
+                    existing.end = existing.end.max(new_range.end);
+                    while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end {
+                        ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end);
+                        ranges.remove(ix + 1);
+                    }
+                    while ix > 0 && ranges[ix - 1].end >= ranges[ix].start {
+                        ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start);
+                        ranges.remove(ix - 1);
+                        ix -= 1;
+                    }
+                }
+                Err(ix) => {
+                    ranges.insert(ix, new_range.clone());
+                }
+            }
+        }
+
+        self.files
+            .insert(project_path, PathState::Excerpts { ranges });
+    }
+
+    pub fn add_file(&mut self, project_path: ProjectPath) {
+        self.files.insert(project_path, PathState::EntireFile);
+    }
+
+    pub fn generate_system_message(&self, cx: &mut AppContext) -> Task<Result<String>> {
+        let project = self
+            .project
+            .upgrade()
+            .ok_or_else(|| anyhow!("project dropped"));
+        let files = self.files.clone();
+        let fs = self.fs.clone();
+        cx.spawn(|cx| async move {
+            let project = project?;
+            let mut result = "project structure:\n".to_string();
+
+            let mut last_worktree: Option<Model<Worktree>> = None;
+            for (project_path, path_state) in files.iter() {
+                if let Some(worktree) = &last_worktree {
+                    if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id {
+                        last_worktree = None;
+                    }
+                }
+
+                let worktree;
+                if let Some(last_worktree) = &last_worktree {
+                    worktree = last_worktree.clone();
+                } else if let Some(tree) = project.read_with(&cx, |project, cx| {
+                    project.worktree_for_id(project_path.worktree_id, cx)
+                })? {
+                    worktree = tree;
+                    last_worktree = Some(worktree.clone());
+                    let worktree_name =
+                        worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?;
+                    writeln!(&mut result, "# {}", worktree_name).unwrap();
+                } else {
+                    continue;
+                }
+
+                let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?;
+                let path = &project_path.path;
+                writeln!(&mut result, "## {}", path.display()).unwrap();
+
+                match path_state {
+                    PathState::PathOnly => {}
+                    PathState::EntireFile => {
+                        let text = fs.load(&worktree_abs_path.join(&path)).await?;
+                        writeln!(&mut result, "~~~\n{text}\n~~~").unwrap();
+                    }
+                    PathState::Excerpts { ranges } => {
+                        let text = fs.load(&worktree_abs_path.join(&path)).await?;
+
+                        writeln!(&mut result, "~~~").unwrap();
+
+                        // Assumption: ranges are in order, not overlapping
+                        let mut prev_range_end = 0;
+                        for range in ranges {
+                            if range.start > prev_range_end {
+                                writeln!(&mut result, "...").unwrap();
+                                prev_range_end = range.end;
+                            }
+
+                            let mut start = range.start;
+                            let mut end = range.end.min(text.len());
+                            while !text.is_char_boundary(start) {
+                                start += 1;
+                            }
+                            while !text.is_char_boundary(end) {
+                                end -= 1;
+                            }
+                            result.push_str(&text[start..end]);
+                            if !result.ends_with('\n') {
+                                result.push('\n');
+                            }
+                        }
+
+                        if prev_range_end < text.len() {
+                            writeln!(&mut result, "...").unwrap();
+                        }
+
+                        writeln!(&mut result, "~~~").unwrap();
+                    }
+                }
+            }
+            Ok(result)
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::path::Path;
+
+    use super::*;
+    use gpui::TestAppContext;
+    use project::FakeFs;
+    use serde_json::json;
+    use settings::SettingsStore;
+
+    use unindent::Unindent as _;
+
+    #[gpui::test]
+    async fn test_system_message_generation(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let file_3_contents = r#"
+            fn test1() {}
+            fn test2() {}
+            fn test3() {}
+        "#
+        .unindent();
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/code",
+            json!({
+                "root1": {
+                    "lib": {
+                        "file1.rs": "mod example;",
+                        "file2.rs": "",
+                    },
+                    "test": {
+                        "file3.rs": file_3_contents,
+                    }
+                },
+                "root2": {
+                    "src": {
+                        "main.rs": ""
+                    }
+                }
+            }),
+        )
+        .await;
+
+        let project = Project::test(
+            fs.clone(),
+            ["/code/root1".as_ref(), "/code/root2".as_ref()],
+            cx,
+        )
+        .await;
+
+        let worktree_ids = project.read_with(cx, |project, cx| {
+            project
+                .worktrees()
+                .map(|worktree| worktree.read(cx).id())
+                .collect::<Vec<_>>()
+        });
+
+        let mut ax = ProjectContext::new(project.downgrade(), fs);
+
+        ax.add_file(ProjectPath {
+            worktree_id: worktree_ids[0],
+            path: Path::new("lib/file1.rs").into(),
+        });
+
+        let message = cx
+            .update(|cx| ax.generate_system_message(cx))
+            .await
+            .unwrap();
+        assert_eq!(
+            r#"
+            project structure:
+            # root1
+            ## lib/file1.rs
+            ~~~
+            mod example;
+            ~~~
+            "#
+            .unindent(),
+            message
+        );
+
+        ax.add_excerpts(
+            ProjectPath {
+                worktree_id: worktree_ids[0],
+                path: Path::new("test/file3.rs").into(),
+            },
+            &[
+                file_3_contents.find("fn test2").unwrap()
+                    ..file_3_contents.find("fn test3").unwrap(),
+            ],
+        );
+
+        let message = cx
+            .update(|cx| ax.generate_system_message(cx))
+            .await
+            .unwrap();
+        assert_eq!(
+            r#"
+            project structure:
+            # root1
+            ## lib/file1.rs
+            ~~~
+            mod example;
+            ~~~
+            ## test/file3.rs
+            ~~~
+            ...
+            fn test2() {}
+            ...
+            ~~~
+            "#
+            .unindent(),
+            message
+        );
+    }
+
+    fn init_test(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            Project::init_settings(cx);
+        });
+    }
+}

crates/assistant_tooling/src/tool.rs 🔗

@@ -1,111 +0,0 @@
-use anyhow::Result;
-use gpui::{div, AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
-use schemars::{schema::RootSchema, schema_for, JsonSchema};
-use serde::Deserialize;
-use std::fmt::Display;
-
-#[derive(Default, Deserialize)]
-pub struct ToolFunctionCall {
-    pub id: String,
-    pub name: String,
-    pub arguments: String,
-    #[serde(skip)]
-    pub result: Option<ToolFunctionCallResult>,
-}
-
-pub enum ToolFunctionCallResult {
-    NoSuchTool,
-    ParsingFailed,
-    Finished { for_model: String, view: AnyView },
-}
-
-impl ToolFunctionCallResult {
-    pub fn format(&self, name: &String) -> String {
-        match self {
-            ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
-            ToolFunctionCallResult::ParsingFailed => {
-                format!("Unable to parse arguments for {name}")
-            }
-            ToolFunctionCallResult::Finished { for_model, .. } => for_model.clone(),
-        }
-    }
-
-    pub fn into_any_element(&self, name: &String) -> AnyElement {
-        match self {
-            ToolFunctionCallResult::NoSuchTool => {
-                format!("Language Model attempted to call {name}").into_any_element()
-            }
-            ToolFunctionCallResult::ParsingFailed => {
-                format!("Language Model called {name} with bad arguments").into_any_element()
-            }
-            ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
-        }
-    }
-}
-
-#[derive(Clone)]
-pub struct ToolFunctionDefinition {
-    pub name: String,
-    pub description: String,
-    pub parameters: RootSchema,
-}
-
-impl Display for ToolFunctionDefinition {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let schema = serde_json::to_string(&self.parameters).ok();
-        let schema = schema.unwrap_or("None".to_string());
-        write!(f, "Name: {}:\n", self.name)?;
-        write!(f, "Description: {}\n", self.description)?;
-        write!(f, "Parameters: {}", schema)
-    }
-}
-
-pub trait LanguageModelTool {
-    /// The input type that will be passed in to `execute` when the tool is called
-    /// by the language model.
-    type Input: for<'de> Deserialize<'de> + JsonSchema;
-
-    /// The output returned by executing the tool.
-    type Output: 'static;
-
-    type View: Render;
-
-    /// Returns the name of the tool.
-    ///
-    /// This name is exposed to the language model to allow the model to pick
-    /// which tools to use. As this name is used to identify the tool within a
-    /// tool registry, it should be unique.
-    fn name(&self) -> String;
-
-    /// Returns the description of the tool.
-    ///
-    /// This can be used to _prompt_ the model as to what the tool does.
-    fn description(&self) -> String;
-
-    /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
-    fn definition(&self) -> ToolFunctionDefinition {
-        let root_schema = schema_for!(Self::Input);
-
-        ToolFunctionDefinition {
-            name: self.name(),
-            description: self.description(),
-            parameters: root_schema,
-        }
-    }
-
-    /// Executes the tool with the given input.
-    fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
-
-    fn format(input: &Self::Input, output: &Result<Self::Output>) -> String;
-
-    fn output_view(
-        tool_call_id: String,
-        input: Self::Input,
-        output: Result<Self::Output>,
-        cx: &mut WindowContext,
-    ) -> View<Self::View>;
-
-    fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
-        div()
-    }
-}

crates/assistant_tooling/src/registry.rs → crates/assistant_tooling/src/tool_registry.rs 🔗

@@ -1,54 +1,115 @@
 use anyhow::{anyhow, Result};
-use gpui::{div, AnyElement, IntoElement as _, ParentElement, Styled, Task, WindowContext};
+use gpui::{
+    div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
+};
+use schemars::{schema::RootSchema, schema_for, JsonSchema};
+use serde::Deserialize;
 use std::{
     any::TypeId,
     collections::HashMap,
+    fmt::Display,
     sync::atomic::{AtomicBool, Ordering::SeqCst},
 };
 
-use crate::tool::{
-    LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
-};
+use crate::ProjectContext;
 
-// Internal Tool representation for the registry
-pub struct Tool {
-    enabled: AtomicBool,
-    type_id: TypeId,
-    call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
-    render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>,
-    definition: ToolFunctionDefinition,
+pub struct ToolRegistry {
+    registered_tools: HashMap<String, RegisteredTool>,
 }
 
-impl Tool {
-    fn new(
-        type_id: TypeId,
-        call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
-        render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>,
-        definition: ToolFunctionDefinition,
-    ) -> Self {
-        Self {
-            enabled: AtomicBool::new(true),
-            type_id,
-            call,
-            render_running,
-            definition,
+#[derive(Default, Deserialize)]
+pub struct ToolFunctionCall {
+    pub id: String,
+    pub name: String,
+    pub arguments: String,
+    #[serde(skip)]
+    pub result: Option<ToolFunctionCallResult>,
+}
+
+pub enum ToolFunctionCallResult {
+    NoSuchTool,
+    ParsingFailed,
+    Finished {
+        view: AnyView,
+        generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
+    },
+}
+
+#[derive(Clone)]
+pub struct ToolFunctionDefinition {
+    pub name: String,
+    pub description: String,
+    pub parameters: RootSchema,
+}
+
+pub trait LanguageModelTool {
+    /// The input type that will be passed in to `execute` when the tool is called
+    /// by the language model.
+    type Input: for<'de> Deserialize<'de> + JsonSchema;
+
+    /// The output returned by executing the tool.
+    type Output: 'static;
+
+    type View: Render + ToolOutput;
+
+    /// Returns the name of the tool.
+    ///
+    /// This name is exposed to the language model to allow the model to pick
+    /// which tools to use. As this name is used to identify the tool within a
+    /// tool registry, it should be unique.
+    fn name(&self) -> String;
+
+    /// Returns the description of the tool.
+    ///
+    /// This can be used to _prompt_ the model as to what the tool does.
+    fn description(&self) -> String;
+
+    /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
+    fn definition(&self) -> ToolFunctionDefinition {
+        let root_schema = schema_for!(Self::Input);
+
+        ToolFunctionDefinition {
+            name: self.name(),
+            description: self.description(),
+            parameters: root_schema,
         }
     }
+
+    /// Executes the tool with the given input.
+    fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
+
+    fn output_view(
+        input: Self::Input,
+        output: Result<Self::Output>,
+        cx: &mut WindowContext,
+    ) -> View<Self::View>;
+
+    fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
+        div()
+    }
 }
 
-pub struct ToolRegistry {
-    tools: HashMap<String, Tool>,
+pub trait ToolOutput: Sized {
+    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
+}
+
+struct RegisteredTool {
+    enabled: AtomicBool,
+    type_id: TypeId,
+    call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
+    render_running: fn(&mut WindowContext) -> gpui::AnyElement,
+    definition: ToolFunctionDefinition,
 }
 
 impl ToolRegistry {
     pub fn new() -> Self {
         Self {
-            tools: HashMap::new(),
+            registered_tools: HashMap::new(),
         }
     }
 
     pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
-        for tool in self.tools.values() {
+        for tool in self.registered_tools.values() {
             if tool.type_id == TypeId::of::<T>() {
                 tool.enabled.store(is_enabled, SeqCst);
                 return;
@@ -57,7 +118,7 @@ impl ToolRegistry {
     }
 
     pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
-        for tool in self.tools.values() {
+        for tool in self.registered_tools.values() {
             if tool.type_id == TypeId::of::<T>() {
                 return tool.enabled.load(SeqCst);
             }
@@ -66,7 +127,7 @@ impl ToolRegistry {
     }
 
     pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
-        self.tools
+        self.registered_tools
             .values()
             .filter(|tool| tool.enabled.load(SeqCst))
             .map(|tool| tool.definition.clone())
@@ -84,7 +145,7 @@ impl ToolRegistry {
                 .child(result.into_any_element(&tool_call.name))
                 .into_any_element(),
             None => self
-                .tools
+                .registered_tools
                 .get(&tool_call.name)
                 .map(|tool| (tool.render_running)(cx))
                 .unwrap_or_else(|| div().into_any_element()),
@@ -96,13 +157,12 @@ impl ToolRegistry {
         tool: T,
         _cx: &mut WindowContext,
     ) -> Result<()> {
-        let definition = tool.definition();
-
         let name = tool.name();
-
-        let registered_tool = Tool::new(
-            TypeId::of::<T>(),
-            Box::new(
+        let registered_tool = RegisteredTool {
+            type_id: TypeId::of::<T>(),
+            definition: tool.definition(),
+            enabled: AtomicBool::new(true),
+            call: Box::new(
                 move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
                     let name = tool_call.name.clone();
                     let arguments = tool_call.arguments.clone();
@@ -121,8 +181,7 @@ impl ToolRegistry {
 
                     cx.spawn(move |mut cx| async move {
                         let result: Result<T::Output> = result.await;
-                        let for_model = T::format(&input, &result);
-                        let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
+                        let view = cx.update(|cx| T::output_view(input, result, cx))?;
 
                         Ok(ToolFunctionCall {
                             id,
@@ -130,23 +189,35 @@ impl ToolRegistry {
                             arguments,
                             result: Some(ToolFunctionCallResult::Finished {
                                 view: view.into(),
-                                for_model,
+                                generate_fn: generate::<T>,
                             }),
                         })
                     })
                 },
             ),
-            Box::new(|cx| T::render_running(cx).into_any_element()),
-            definition,
-        );
-
-        let previous = self.tools.insert(name.clone(), registered_tool);
+            render_running: render_running::<T>,
+        };
 
+        let previous = self.registered_tools.insert(name.clone(), registered_tool);
         if previous.is_some() {
             return Err(anyhow!("already registered a tool with name {}", name));
         }
 
-        Ok(())
+        return Ok(());
+
+        fn render_running<T: LanguageModelTool>(cx: &mut WindowContext) -> AnyElement {
+            T::render_running(cx).into_any_element()
+        }
+
+        fn generate<T: LanguageModelTool>(
+            view: AnyView,
+            project: &mut ProjectContext,
+            cx: &mut WindowContext,
+        ) -> String {
+            view.downcast::<T::View>()
+                .unwrap()
+                .update(cx, |view, cx| T::View::generate(view, project, cx))
+        }
     }
 
     /// Task yields an error if the window for the given WindowContext is closed before the task completes.
@@ -159,7 +230,7 @@ impl ToolRegistry {
         let arguments = tool_call.arguments.clone();
         let id = tool_call.id.clone();
 
-        let tool = match self.tools.get(&name) {
+        let tool = match self.registered_tools.get(&name) {
             Some(tool) => tool,
             None => {
                 let name = name.clone();
@@ -176,6 +247,47 @@ impl ToolRegistry {
     }
 }
 
+impl ToolFunctionCallResult {
+    pub fn generate(
+        &self,
+        name: &String,
+        project: &mut ProjectContext,
+        cx: &mut WindowContext,
+    ) -> String {
+        match self {
+            ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
+            ToolFunctionCallResult::ParsingFailed => {
+                format!("Unable to parse arguments for {name}")
+            }
+            ToolFunctionCallResult::Finished { generate_fn, view } => {
+                (generate_fn)(view.clone(), project, cx)
+            }
+        }
+    }
+
+    fn into_any_element(&self, name: &String) -> AnyElement {
+        match self {
+            ToolFunctionCallResult::NoSuchTool => {
+                format!("Language Model attempted to call {name}").into_any_element()
+            }
+            ToolFunctionCallResult::ParsingFailed => {
+                format!("Language Model called {name} with bad arguments").into_any_element()
+            }
+            ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
+        }
+    }
+}
+
+impl Display for ToolFunctionDefinition {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let schema = serde_json::to_string(&self.parameters).ok();
+        let schema = schema.unwrap_or("None".to_string());
+        write!(f, "Name: {}:\n", self.name)?;
+        write!(f, "Description: {}\n", self.description)?;
+        write!(f, "Parameters: {}", schema)
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
@@ -213,6 +325,12 @@ mod test {
         }
     }
 
+    impl ToolOutput for WeatherView {
+        fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
+            serde_json::to_string(&self.result).unwrap()
+        }
+    }
+
     impl LanguageModelTool for WeatherTool {
         type Input = WeatherQuery;
         type Output = WeatherResult;
@@ -240,7 +358,6 @@ mod test {
         }
 
         fn output_view(
-            _tool_call_id: String,
             _input: Self::Input,
             result: Result<Self::Output>,
             cx: &mut WindowContext,
@@ -250,10 +367,6 @@ mod test {
                 WeatherView { result }
             })
         }
-
-        fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
-            serde_json::to_string(&output.as_ref().unwrap()).unwrap()
-        }
     }
 
     #[gpui::test]