acp: Only use the resumed cwd if it is in the workspace (#48873)

Ben Brandt created

I think we may want to revisit in the future which history sessions we
show, but given the current setup and behavior, I think it makes sense
to only use the cwd if it matches in the worktrees to avoid issues with
reading and editing files outside of there.

- [x] Tests or screenshots needed?
- [x] Code Reviewed
- [x] Manual QA

Release Notes:

- N/A

Change summary

crates/agent_ui/src/acp/thread_view.rs | 322 +++++++++++++++++++++++++--
1 file changed, 296 insertions(+), 26 deletions(-)

Detailed changes

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -423,19 +423,35 @@ impl AcpServerView {
                 .is_single_file()
                 .cmp(&r.read(cx).is_single_file())
         });
-        let root_dir = worktrees
-            .into_iter()
+        let worktree_roots: Vec<Arc<Path>> = worktrees
+            .iter()
             .filter_map(|worktree| {
-                if worktree.read(cx).is_single_file() {
-                    Some(worktree.read(cx).abs_path().parent()?.into())
+                let worktree = worktree.read(cx);
+                if worktree.is_single_file() {
+                    Some(worktree.abs_path().parent()?.into())
                 } else {
-                    Some(worktree.read(cx).abs_path())
+                    Some(worktree.abs_path())
                 }
             })
-            .next();
-        let fallback_cwd = root_dir
-            .clone()
+            .collect();
+        let root_dir = worktree_roots.first().cloned();
+        let session_cwd = resume_thread
+            .as_ref()
+            .and_then(|resume| {
+                resume
+                    .cwd
+                    .as_ref()
+                    .and_then(|cwd| util::paths::normalize_lexically(cwd).ok())
+                    .filter(|cwd| {
+                        worktree_roots
+                            .iter()
+                            .any(|root| cwd.starts_with(root.as_ref()))
+                    })
+                    .map(|path| path.into())
+            })
+            .or_else(|| root_dir.clone())
             .unwrap_or_else(|| paths::home_dir().as_path().into());
+
         let (status_tx, mut status_rx) = watch::channel("Loading…".into());
         let (new_version_available_tx, mut new_version_available_rx) = watch::channel(None);
         let delegate = AgentServerDelegate::new(
@@ -471,25 +487,15 @@ impl AcpServerView {
             let mut resumed_without_history = false;
             let result = if let Some(resume) = resume_thread.clone() {
                 cx.update(|_, cx| {
-                    let session_cwd = resume
-                        .cwd
-                        .clone()
-                        .unwrap_or_else(|| fallback_cwd.as_ref().to_path_buf());
                     if connection.supports_load_session(cx) {
-                        connection.clone().load_session(
-                            resume,
-                            project.clone(),
-                            session_cwd.as_path(),
-                            cx,
-                        )
+                        connection
+                            .clone()
+                            .load_session(resume, project.clone(), &session_cwd, cx)
                     } else if connection.supports_resume_session(cx) {
                         resumed_without_history = true;
-                        connection.clone().resume_session(
-                            resume,
-                            project.clone(),
-                            session_cwd.as_path(),
-                            cx,
-                        )
+                        connection
+                            .clone()
+                            .resume_session(resume, project.clone(), &session_cwd, cx)
                     } else {
                         Task::ready(Err(anyhow!(LoadError::Other(
                             "Loading or resuming sessions is not supported by this agent.".into()
@@ -501,7 +507,7 @@ impl AcpServerView {
                 cx.update(|_, cx| {
                     connection
                         .clone()
-                        .new_session(project.clone(), fallback_cwd.as_ref(), cx)
+                        .new_session(project.clone(), session_cwd.as_ref(), cx)
                 })
                 .log_err()
             };
@@ -2544,12 +2550,14 @@ pub(crate) mod tests {
     use editor::MultiBufferOffset;
     use fs::FakeFs;
     use gpui::{EventEmitter, TestAppContext, VisualTestContext};
+    use parking_lot::Mutex;
     use project::Project;
     use serde_json::json;
     use settings::SettingsStore;
     use std::any::Any;
-    use std::path::Path;
+    use std::path::{Path, PathBuf};
     use std::rc::Rc;
+    use std::sync::Arc;
     use workspace::Item;
 
     use super::*;
@@ -2726,6 +2734,161 @@ pub(crate) mod tests {
         });
     }
 
+    #[gpui::test]
+    async fn test_resume_thread_uses_session_cwd_when_inside_project(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/project",
+            json!({
+                "subdir": {
+                    "file.txt": "hello"
+                }
+            }),
+        )
+        .await;
+        let project = Project::test(fs, [Path::new("/project")], cx).await;
+        let (workspace, cx) =
+            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+        let connection = CwdCapturingConnection::new();
+        let captured_cwd = connection.captured_cwd.clone();
+
+        let mut session = AgentSessionInfo::new(SessionId::new("session-1"));
+        session.cwd = Some(PathBuf::from("/project/subdir"));
+
+        let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
+        let history = cx.update(|window, cx| cx.new(|cx| AcpThreadHistory::new(None, window, cx)));
+
+        let _thread_view = cx.update(|window, cx| {
+            cx.new(|cx| {
+                AcpServerView::new(
+                    Rc::new(StubAgentServer::new(connection)),
+                    Some(session),
+                    None,
+                    workspace.downgrade(),
+                    project,
+                    Some(thread_store),
+                    None,
+                    history,
+                    window,
+                    cx,
+                )
+            })
+        });
+
+        cx.run_until_parked();
+
+        assert_eq!(
+            captured_cwd.lock().as_deref(),
+            Some(Path::new("/project/subdir")),
+            "Should use session cwd when it's inside the project"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_resume_thread_uses_fallback_cwd_when_outside_project(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/project",
+            json!({
+                "file.txt": "hello"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, [Path::new("/project")], cx).await;
+        let (workspace, cx) =
+            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+        let connection = CwdCapturingConnection::new();
+        let captured_cwd = connection.captured_cwd.clone();
+
+        let mut session = AgentSessionInfo::new(SessionId::new("session-1"));
+        session.cwd = Some(PathBuf::from("/some/other/path"));
+
+        let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
+        let history = cx.update(|window, cx| cx.new(|cx| AcpThreadHistory::new(None, window, cx)));
+
+        let _thread_view = cx.update(|window, cx| {
+            cx.new(|cx| {
+                AcpServerView::new(
+                    Rc::new(StubAgentServer::new(connection)),
+                    Some(session),
+                    None,
+                    workspace.downgrade(),
+                    project,
+                    Some(thread_store),
+                    None,
+                    history,
+                    window,
+                    cx,
+                )
+            })
+        });
+
+        cx.run_until_parked();
+
+        assert_eq!(
+            captured_cwd.lock().as_deref(),
+            Some(Path::new("/project")),
+            "Should use fallback project cwd when session cwd is outside the project"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_resume_thread_rejects_unnormalized_cwd_outside_project(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/project",
+            json!({
+                "file.txt": "hello"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, [Path::new("/project")], cx).await;
+        let (workspace, cx) =
+            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+        let connection = CwdCapturingConnection::new();
+        let captured_cwd = connection.captured_cwd.clone();
+
+        let mut session = AgentSessionInfo::new(SessionId::new("session-1"));
+        session.cwd = Some(PathBuf::from("/project/../outside"));
+
+        let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
+        let history = cx.update(|window, cx| cx.new(|cx| AcpThreadHistory::new(None, window, cx)));
+
+        let _thread_view = cx.update(|window, cx| {
+            cx.new(|cx| {
+                AcpServerView::new(
+                    Rc::new(StubAgentServer::new(connection)),
+                    Some(session),
+                    None,
+                    workspace.downgrade(),
+                    project,
+                    Some(thread_store),
+                    None,
+                    history,
+                    window,
+                    cx,
+                )
+            })
+        });
+
+        cx.run_until_parked();
+
+        assert_eq!(
+            captured_cwd.lock().as_deref(),
+            Some(Path::new("/project")),
+            "Should reject unnormalized cwd that resolves outside the project and use fallback cwd"
+        );
+    }
+
     #[gpui::test]
     async fn test_refusal_handling(cx: &mut TestAppContext) {
         init_test(cx);
@@ -3306,6 +3469,113 @@ pub(crate) mod tests {
         }
     }
 
+    #[derive(Clone)]
+    struct CwdCapturingConnection {
+        captured_cwd: Arc<Mutex<Option<PathBuf>>>,
+    }
+
+    impl CwdCapturingConnection {
+        fn new() -> Self {
+            Self {
+                captured_cwd: Arc::new(Mutex::new(None)),
+            }
+        }
+    }
+
+    impl AgentConnection for CwdCapturingConnection {
+        fn telemetry_id(&self) -> SharedString {
+            "cwd-capturing".into()
+        }
+
+        fn new_session(
+            self: Rc<Self>,
+            project: Entity<Project>,
+            cwd: &Path,
+            cx: &mut gpui::App,
+        ) -> Task<gpui::Result<Entity<AcpThread>>> {
+            *self.captured_cwd.lock() = Some(cwd.to_path_buf());
+            let action_log = cx.new(|_| ActionLog::new(project.clone()));
+            let thread = cx.new(|cx| {
+                AcpThread::new(
+                    None,
+                    "CwdCapturingConnection",
+                    self.clone(),
+                    project,
+                    action_log,
+                    SessionId::new("new-session"),
+                    watch::Receiver::constant(
+                        acp::PromptCapabilities::new()
+                            .image(true)
+                            .audio(true)
+                            .embedded_context(true),
+                    ),
+                    cx,
+                )
+            });
+            Task::ready(Ok(thread))
+        }
+
+        fn supports_load_session(&self, _cx: &App) -> bool {
+            true
+        }
+
+        fn load_session(
+            self: Rc<Self>,
+            session: AgentSessionInfo,
+            project: Entity<Project>,
+            cwd: &Path,
+            cx: &mut App,
+        ) -> Task<gpui::Result<Entity<AcpThread>>> {
+            *self.captured_cwd.lock() = Some(cwd.to_path_buf());
+            let action_log = cx.new(|_| ActionLog::new(project.clone()));
+            let thread = cx.new(|cx| {
+                AcpThread::new(
+                    None,
+                    "CwdCapturingConnection",
+                    self.clone(),
+                    project,
+                    action_log,
+                    session.session_id,
+                    watch::Receiver::constant(
+                        acp::PromptCapabilities::new()
+                            .image(true)
+                            .audio(true)
+                            .embedded_context(true),
+                    ),
+                    cx,
+                )
+            });
+            Task::ready(Ok(thread))
+        }
+
+        fn auth_methods(&self) -> &[acp::AuthMethod] {
+            &[]
+        }
+
+        fn authenticate(
+            &self,
+            _method_id: acp::AuthMethodId,
+            _cx: &mut App,
+        ) -> Task<gpui::Result<()>> {
+            Task::ready(Ok(()))
+        }
+
+        fn prompt(
+            &self,
+            _id: Option<acp_thread::UserMessageId>,
+            _params: acp::PromptRequest,
+            _cx: &mut App,
+        ) -> Task<gpui::Result<acp::PromptResponse>> {
+            Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
+        }
+
+        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
+
+        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+            self
+        }
+    }
+
     pub(crate) fn init_test(cx: &mut TestAppContext) {
         cx.update(|cx| {
             let settings_store = SettingsStore::test(cx);