agent2: Port rules UI (#36429)

Bennet Bo Fenner created

Release Notes:

- N/A

Change summary

crates/agent2/src/agent.rs                |  19 +-
crates/agent2/src/tests/mod.rs            |  10 
crates/agent2/src/thread.rs               |  20 +-
crates/agent2/src/tools/edit_file_tool.rs |  20 +-
crates/agent_ui/src/acp/thread_view.rs    | 160 ++++++++++++++++++++++++
5 files changed, 197 insertions(+), 32 deletions(-)

Detailed changes

crates/agent2/src/agent.rs 🔗

@@ -22,7 +22,6 @@ use prompt_store::{
 };
 use settings::update_settings_file;
 use std::any::Any;
-use std::cell::RefCell;
 use std::collections::HashMap;
 use std::path::Path;
 use std::rc::Rc;
@@ -156,7 +155,7 @@ pub struct NativeAgent {
     /// Session ID -> Session mapping
     sessions: HashMap<acp::SessionId, Session>,
     /// Shared project context for all threads
-    project_context: Rc<RefCell<ProjectContext>>,
+    project_context: Entity<ProjectContext>,
     project_context_needs_refresh: watch::Sender<()>,
     _maintain_project_context: Task<Result<()>>,
     context_server_registry: Entity<ContextServerRegistry>,
@@ -200,7 +199,7 @@ impl NativeAgent {
                 watch::channel(());
             Self {
                 sessions: HashMap::new(),
-                project_context: Rc::new(RefCell::new(project_context)),
+                project_context: cx.new(|_| project_context),
                 project_context_needs_refresh: project_context_needs_refresh_tx,
                 _maintain_project_context: cx.spawn(async move |this, cx| {
                     Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
@@ -233,7 +232,9 @@ impl NativeAgent {
                     Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
                 })?
                 .await;
-            this.update(cx, |this, _| this.project_context.replace(project_context))?;
+            this.update(cx, |this, cx| {
+                this.project_context = cx.new(|_| project_context);
+            })?;
         }
 
         Ok(())
@@ -872,8 +873,8 @@ mod tests {
         )
         .await
         .unwrap();
-        agent.read_with(cx, |agent, _| {
-            assert_eq!(agent.project_context.borrow().worktrees, vec![])
+        agent.read_with(cx, |agent, cx| {
+            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
         });
 
         let worktree = project
@@ -881,9 +882,9 @@ mod tests {
             .await
             .unwrap();
         cx.run_until_parked();
-        agent.read_with(cx, |agent, _| {
+        agent.read_with(cx, |agent, cx| {
             assert_eq!(
-                agent.project_context.borrow().worktrees,
+                agent.project_context.read(cx).worktrees,
                 vec![WorktreeContext {
                     root_name: "a".into(),
                     abs_path: Path::new("/a").into(),
@@ -898,7 +899,7 @@ mod tests {
         agent.read_with(cx, |agent, cx| {
             let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
             assert_eq!(
-                agent.project_context.borrow().worktrees,
+                agent.project_context.read(cx).worktrees,
                 vec![WorktreeContext {
                     root_name: "a".into(),
                     abs_path: Path::new("/a").into(),

crates/agent2/src/tests/mod.rs 🔗

@@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize};
 use serde_json::json;
 use settings::SettingsStore;
 use smol::stream::StreamExt;
-use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
+use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
 use util::path;
 
 mod test_tools;
@@ -101,7 +101,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
     } = setup(cx, TestModel::Fake).await;
     let fake_model = model.as_fake();
 
-    project_context.borrow_mut().shell = "test-shell".into();
+    project_context.update(cx, |project_context, _cx| {
+        project_context.shell = "test-shell".into()
+    });
     thread.update(cx, |thread, _| thread.add_tool(EchoTool));
     thread
         .update(cx, |thread, cx| {
@@ -1447,7 +1449,7 @@ fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopR
 struct ThreadTest {
     model: Arc<dyn LanguageModel>,
     thread: Entity<Thread>,
-    project_context: Rc<RefCell<ProjectContext>>,
+    project_context: Entity<ProjectContext>,
     fs: Arc<FakeFs>,
 }
 
@@ -1543,7 +1545,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
         })
         .await;
 
-    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
+    let project_context = cx.new(|_cx| ProjectContext::default());
     let context_server_registry =
         cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
     let action_log = cx.new(|_| ActionLog::new(project.clone()));

crates/agent2/src/thread.rs 🔗

@@ -25,7 +25,7 @@ use schemars::{JsonSchema, Schema};
 use serde::{Deserialize, Serialize};
 use settings::{Settings, update_settings_file};
 use smol::stream::StreamExt;
-use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
+use std::{collections::BTreeMap, path::Path, sync::Arc};
 use std::{fmt::Write, ops::Range};
 use util::{ResultExt, markdown::MarkdownCodeBlock};
 use uuid::Uuid;
@@ -479,7 +479,7 @@ pub struct Thread {
     tool_use_limit_reached: bool,
     context_server_registry: Entity<ContextServerRegistry>,
     profile_id: AgentProfileId,
-    project_context: Rc<RefCell<ProjectContext>>,
+    project_context: Entity<ProjectContext>,
     templates: Arc<Templates>,
     model: Option<Arc<dyn LanguageModel>>,
     project: Entity<Project>,
@@ -489,7 +489,7 @@ pub struct Thread {
 impl Thread {
     pub fn new(
         project: Entity<Project>,
-        project_context: Rc<RefCell<ProjectContext>>,
+        project_context: Entity<ProjectContext>,
         context_server_registry: Entity<ContextServerRegistry>,
         action_log: Entity<ActionLog>,
         templates: Arc<Templates>,
@@ -520,6 +520,10 @@ impl Thread {
         &self.project
     }
 
+    pub fn project_context(&self) -> &Entity<ProjectContext> {
+        &self.project_context
+    }
+
     pub fn action_log(&self) -> &Entity<ActionLog> {
         &self.action_log
     }
@@ -750,10 +754,10 @@ impl Thread {
         Ok(events_rx)
     }
 
-    pub fn build_system_message(&self) -> LanguageModelRequestMessage {
+    pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
         log::debug!("Building system message");
         let prompt = SystemPromptTemplate {
-            project: &self.project_context.borrow(),
+            project: &self.project_context.read(cx),
             available_tools: self.tools.keys().cloned().collect(),
         }
         .render(&self.templates)
@@ -1030,7 +1034,7 @@ impl Thread {
         log::debug!("Completion intent: {:?}", completion_intent);
         log::debug!("Completion mode: {:?}", self.completion_mode);
 
-        let messages = self.build_request_messages();
+        let messages = self.build_request_messages(cx);
         log::info!("Request will include {} messages", messages.len());
 
         let tools = if let Some(tools) = self.tools(cx).log_err() {
@@ -1101,12 +1105,12 @@ impl Thread {
             )))
     }
 
-    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
+    fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
         log::trace!(
             "Building request messages from {} thread messages",
             self.messages.len()
         );
-        let mut messages = vec![self.build_system_message()];
+        let mut messages = vec![self.build_system_message(cx)];
         for message in &self.messages {
             match message {
                 Message::User(message) => messages.push(message.to_request()),

crates/agent2/src/tools/edit_file_tool.rs 🔗

@@ -503,9 +503,9 @@ mod tests {
     use fs::Fs;
     use gpui::{TestAppContext, UpdateGlobal};
     use language_model::fake_provider::FakeLanguageModel;
+    use prompt_store::ProjectContext;
     use serde_json::json;
     use settings::SettingsStore;
-    use std::rc::Rc;
     use util::path;
 
     #[gpui::test]
@@ -522,7 +522,7 @@ mod tests {
         let thread = cx.new(|cx| {
             Thread::new(
                 project,
-                Rc::default(),
+                cx.new(|_cx| ProjectContext::default()),
                 context_server_registry,
                 action_log,
                 Templates::new(),
@@ -719,7 +719,7 @@ mod tests {
         let thread = cx.new(|cx| {
             Thread::new(
                 project,
-                Rc::default(),
+                cx.new(|_cx| ProjectContext::default()),
                 context_server_registry,
                 action_log.clone(),
                 Templates::new(),
@@ -855,7 +855,7 @@ mod tests {
         let thread = cx.new(|cx| {
             Thread::new(
                 project,
-                Rc::default(),
+                cx.new(|_cx| ProjectContext::default()),
                 context_server_registry,
                 action_log.clone(),
                 Templates::new(),
@@ -981,7 +981,7 @@ mod tests {
         let thread = cx.new(|cx| {
             Thread::new(
                 project,
-                Rc::default(),
+                cx.new(|_cx| ProjectContext::default()),
                 context_server_registry,
                 action_log.clone(),
                 Templates::new(),
@@ -1118,7 +1118,7 @@ mod tests {
         let thread = cx.new(|cx| {
             Thread::new(
                 project,
-                Rc::default(),
+                cx.new(|_cx| ProjectContext::default()),
                 context_server_registry,
                 action_log.clone(),
                 Templates::new(),
@@ -1228,7 +1228,7 @@ mod tests {
         let thread = cx.new(|cx| {
             Thread::new(
                 project.clone(),
-                Rc::default(),
+                cx.new(|_cx| ProjectContext::default()),
                 context_server_registry.clone(),
                 action_log.clone(),
                 Templates::new(),
@@ -1309,7 +1309,7 @@ mod tests {
         let thread = cx.new(|cx| {
             Thread::new(
                 project.clone(),
-                Rc::default(),
+                cx.new(|_cx| ProjectContext::default()),
                 context_server_registry.clone(),
                 action_log.clone(),
                 Templates::new(),
@@ -1393,7 +1393,7 @@ mod tests {
         let thread = cx.new(|cx| {
             Thread::new(
                 project.clone(),
-                Rc::default(),
+                cx.new(|_cx| ProjectContext::default()),
                 context_server_registry.clone(),
                 action_log.clone(),
                 Templates::new(),
@@ -1474,7 +1474,7 @@ mod tests {
         let thread = cx.new(|cx| {
             Thread::new(
                 project.clone(),
-                Rc::default(),
+                cx.new(|_cx| ProjectContext::default()),
                 context_server_registry,
                 action_log.clone(),
                 Templates::new(),

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -30,7 +30,7 @@ use language::Buffer;
 
 use language_model::LanguageModelRegistry;
 use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
-use project::Project;
+use project::{Project, ProjectEntryId};
 use prompt_store::PromptId;
 use rope::Point;
 use settings::{Settings as _, SettingsStore};
@@ -703,6 +703,38 @@ impl AcpThreadView {
         })
     }
 
+    fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
+        let Some(thread) = self.as_native_thread(cx) else {
+            return;
+        };
+        let project_context = thread.read(cx).project_context().read(cx);
+
+        let project_entry_ids = project_context
+            .worktrees
+            .iter()
+            .flat_map(|worktree| worktree.rules_file.as_ref())
+            .map(|rules_file| ProjectEntryId::from_usize(rules_file.project_entry_id))
+            .collect::<Vec<_>>();
+
+        self.workspace
+            .update(cx, move |workspace, cx| {
+                // TODO: Open a multibuffer instead? In some cases this doesn't make the set of rules
+                // files clear. For example, if rules file 1 is already open but rules file 2 is not,
+                // this would open and focus rules file 2 in a tab that is not next to rules file 1.
+                let project = workspace.project().read(cx);
+                let project_paths = project_entry_ids
+                    .into_iter()
+                    .flat_map(|entry_id| project.path_for_entry(entry_id, cx))
+                    .collect::<Vec<_>>();
+                for project_path in project_paths {
+                    workspace
+                        .open_path(project_path, None, true, window, cx)
+                        .detach_and_log_err(cx);
+                }
+            })
+            .ok();
+    }
+
     fn handle_thread_error(&mut self, error: anyhow::Error, cx: &mut Context<Self>) {
         self.thread_error = Some(ThreadError::from_err(error));
         cx.notify();
@@ -858,6 +890,12 @@ impl AcpThreadView {
                 let editor_focus = editor.focus_handle(cx).is_focused(window);
                 let focus_border = cx.theme().colors().border_focused;
 
+                let rules_item = if entry_ix == 0 {
+                    self.render_rules_item(cx)
+                } else {
+                    None
+                };
+
                 div()
                     .id(("user_message", entry_ix))
                     .py_4()
@@ -874,6 +912,7 @@ impl AcpThreadView {
                                 }))
                         })
                     }))
+                    .children(rules_item)
                     .child(
                         div()
                             .relative()
@@ -1862,6 +1901,125 @@ impl AcpThreadView {
             .into_any_element()
     }
 
+    fn render_rules_item(&self, cx: &Context<Self>) -> Option<AnyElement> {
+        let project_context = self
+            .as_native_thread(cx)?
+            .read(cx)
+            .project_context()
+            .read(cx);
+
+        let user_rules_text = if project_context.user_rules.is_empty() {
+            None
+        } else if project_context.user_rules.len() == 1 {
+            let user_rules = &project_context.user_rules[0];
+
+            match user_rules.title.as_ref() {
+                Some(title) => Some(format!("Using \"{title}\" user rule")),
+                None => Some("Using user rule".into()),
+            }
+        } else {
+            Some(format!(
+                "Using {} user rules",
+                project_context.user_rules.len()
+            ))
+        };
+
+        let first_user_rules_id = project_context
+            .user_rules
+            .first()
+            .map(|user_rules| user_rules.uuid.0);
+
+        let rules_files = project_context
+            .worktrees
+            .iter()
+            .filter_map(|worktree| worktree.rules_file.as_ref())
+            .collect::<Vec<_>>();
+
+        let rules_file_text = match rules_files.as_slice() {
+            &[] => None,
+            &[rules_file] => Some(format!(
+                "Using project {:?} file",
+                rules_file.path_in_worktree
+            )),
+            rules_files => Some(format!("Using {} project rules files", rules_files.len())),
+        };
+
+        if user_rules_text.is_none() && rules_file_text.is_none() {
+            return None;
+        }
+
+        Some(
+            v_flex()
+                .pt_2()
+                .px_2p5()
+                .gap_1()
+                .when_some(user_rules_text, |parent, user_rules_text| {
+                    parent.child(
+                        h_flex()
+                            .w_full()
+                            .child(
+                                Icon::new(IconName::Reader)
+                                    .size(IconSize::XSmall)
+                                    .color(Color::Disabled),
+                            )
+                            .child(
+                                Label::new(user_rules_text)
+                                    .size(LabelSize::XSmall)
+                                    .color(Color::Muted)
+                                    .truncate()
+                                    .buffer_font(cx)
+                                    .ml_1p5()
+                                    .mr_0p5(),
+                            )
+                            .child(
+                                IconButton::new("open-prompt-library", IconName::ArrowUpRight)
+                                    .shape(ui::IconButtonShape::Square)
+                                    .icon_size(IconSize::XSmall)
+                                    .icon_color(Color::Ignored)
+                                    // TODO: Figure out a way to pass focus handle here so we can display the `OpenRulesLibrary`  keybinding
+                                    .tooltip(Tooltip::text("View User Rules"))
+                                    .on_click(move |_event, window, cx| {
+                                        window.dispatch_action(
+                                            Box::new(OpenRulesLibrary {
+                                                prompt_to_select: first_user_rules_id,
+                                            }),
+                                            cx,
+                                        )
+                                    }),
+                            ),
+                    )
+                })
+                .when_some(rules_file_text, |parent, rules_file_text| {
+                    parent.child(
+                        h_flex()
+                            .w_full()
+                            .child(
+                                Icon::new(IconName::File)
+                                    .size(IconSize::XSmall)
+                                    .color(Color::Disabled),
+                            )
+                            .child(
+                                Label::new(rules_file_text)
+                                    .size(LabelSize::XSmall)
+                                    .color(Color::Muted)
+                                    .buffer_font(cx)
+                                    .ml_1p5()
+                                    .mr_0p5(),
+                            )
+                            .child(
+                                IconButton::new("open-rule", IconName::ArrowUpRight)
+                                    .shape(ui::IconButtonShape::Square)
+                                    .icon_size(IconSize::XSmall)
+                                    .icon_color(Color::Ignored)
+                                    .on_click(cx.listener(Self::handle_open_rules))
+                                    .tooltip(Tooltip::text("View Rules")),
+                            ),
+                    )
+                })
+                .into_any(),
+        )
+    }
+
     fn render_empty_state(&self, cx: &App) -> AnyElement {
         let loading = matches!(&self.thread_state, ThreadState::Loading { .. });