diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index c9e83554959b5e3281a0094c284b5a45ff121d16..ef2a14b6095f7e8f7740f84915c9372e83997e83 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -589,9 +589,10 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} SCCACHE_BUCKET: sccache-zed - name: run_tests::check_wasm::cargo_check_wasm - run: cargo +nightly -Zbuild-std=std,panic_abort check --target wasm32-unknown-unknown -p gpui_platform + run: cargo -Zbuild-std=std,panic_abort check --target wasm32-unknown-unknown -p gpui_platform env: CARGO_TARGET_WASM32_UNKNOWN_UNKNOWN_RUSTFLAGS: -C target-feature=+atomics,+bulk-memory,+mutable-globals + RUSTC_BOOTSTRAP: '1' - name: steps::show_sccache_stats run: sccache --show-stats || true - name: steps::cleanup_cargo_config diff --git a/Cargo.lock b/Cargo.lock index ec232567bfd282c2f4e36e0e2d032c652bb54c08..1ea7b03eac0beca8449fb9b0951e680302cb0bec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,6 +77,7 @@ dependencies = [ "ctor", "fs", "futures 0.3.32", + "git", "gpui", "language", "log", @@ -274,6 +275,7 @@ dependencies = [ "libc", "log", "nix 0.29.0", + "piper", "project", "release_channel", "remote", diff --git a/crates/action_log/Cargo.toml b/crates/action_log/Cargo.toml index 5227a61651012279e83a3b6e3e68b1484acb0f66..6f103c7b44fc8742a2be7285a6f027eab1531dd2 100644 --- a/crates/action_log/Cargo.toml +++ b/crates/action_log/Cargo.toml @@ -33,6 +33,7 @@ watch.workspace = true [dev-dependencies] buffer_diff = { workspace = true, features = ["test-support"] } +git.workspace = true collections = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } ctor.workspace = true diff --git a/crates/action_log/src/action_log.rs b/crates/action_log/src/action_log.rs index cd17392704e1c6c932a3e4d8716b1c6f37489576..0bb4c0fcaa7ceba1952ebff9803ebbf57b852004 100644 --- a/crates/action_log/src/action_log.rs +++ b/crates/action_log/src/action_log.rs @@ -274,7 +274,6 @@ impl ActionLog { mut buffer_updates: mpsc::UnboundedReceiver<(ChangeAuthor, text::BufferSnapshot)>, cx: &mut AsyncApp, ) -> Result<()> { - let git_store = this.read_with(cx, |this, cx| this.project.read(cx).git_store().clone())?; let git_diff = this .update(cx, |this, cx| { this.project.update(cx, |project, cx| { @@ -283,28 +282,18 @@ impl ActionLog { })? .await .ok(); - let buffer_repo = git_store.read_with(cx, |git_store, cx| { - git_store.repository_and_path_for_buffer_id(buffer.read(cx).remote_id(), cx) - }); - let (mut git_diff_updates_tx, mut git_diff_updates_rx) = watch::channel(()); - let _repo_subscription = - if let Some((git_diff, (buffer_repo, _))) = git_diff.as_ref().zip(buffer_repo) { - cx.update(|cx| { - let mut old_head = buffer_repo.read(cx).head_commit.clone(); - Some(cx.subscribe(git_diff, move |_, event, cx| { - if let buffer_diff::BufferDiffEvent::DiffChanged { .. } = event { - let new_head = buffer_repo.read(cx).head_commit.clone(); - if new_head != old_head { - old_head = new_head; - git_diff_updates_tx.send(()).ok(); - } - } - })) - }) - } else { - None - }; + let _diff_subscription = if let Some(git_diff) = git_diff.as_ref() { + cx.update(|cx| { + Some(cx.subscribe(git_diff, move |_, event, _cx| { + if matches!(event, buffer_diff::BufferDiffEvent::BaseTextChanged) { + git_diff_updates_tx.send(()).ok(); + } + })) + }) + } else { + None + }; loop { futures::select_biased! { @@ -2714,6 +2703,108 @@ mod tests { assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); } + /// Regression test: when head_commit updates before the BufferDiff's base + /// text does, an intermediate DiffChanged (e.g. from a buffer-edit diff + /// recalculation) must NOT consume the commit signal. The subscription + /// should only fire once the base text itself has changed. + #[gpui::test] + async fn test_keep_edits_on_commit_with_stale_diff_changed(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + ".git": {}, + "file.txt": "aaa\nbbb\nccc\nddd\neee", + }), + ) + .await; + fs.set_head_for_repo( + path!("/project/.git").as_ref(), + &[("file.txt", "aaa\nbbb\nccc\nddd\neee".into())], + "0000000", + ); + cx.run_until_parked(); + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| { + project.find_project_path(path!("/project/file.txt"), cx) + }) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + // Agent makes an edit: bbb -> BBB + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| { + buffer.edit([(Point::new(1, 0)..Point::new(1, 3), "BBB")], None, cx); + }); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + cx.run_until_parked(); + + // Verify the edit is tracked + let hunks = unreviewed_hunks(&action_log, cx); + assert_eq!(hunks.len(), 1); + let hunk = &hunks[0].1; + assert_eq!(hunk.len(), 1); + assert_eq!(hunk[0].old_text, "bbb\n"); + + // Simulate the race condition: update only the HEAD SHA first, + // without changing the committed file contents. This is analogous + // to compute_snapshot updating head_commit before + // reload_buffer_diff_bases has loaded the new base text. + fs.with_git_state(path!("/project/.git").as_ref(), true, |state| { + state.refs.insert("HEAD".into(), "0000001".into()); + }) + .unwrap(); + cx.run_until_parked(); + + // Make a user edit (on a different line) to trigger a buffer diff + // recalculation. This fires DiffChanged while the BufferDiff base + // text is still the OLD text. With the old head_commit-based + // subscription this would "consume" the commit detection. + cx.update(|cx| { + buffer.update(cx, |buffer, cx| { + buffer.edit([(Point::new(3, 0)..Point::new(3, 3), "DDD")], None, cx); + }); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + cx.run_until_parked(); + + // Now update the committed file contents to match the buffer + // (the agent edit was committed). Keep the same SHA so head_commit + // does NOT change again — this is the second half of the race. + { + use git::repository::repo_path; + fs.with_git_state(path!("/project/.git").as_ref(), true, |state| { + state + .head_contents + .insert(repo_path("file.txt"), "aaa\nBBB\nccc\nDDD\neee".into()); + }) + .unwrap(); + } + cx.run_until_parked(); + + // The agent's edit (bbb -> BBB) should be accepted because the + // committed content now matches. Only the user edit (ddd -> DDD) + // should remain, but since the user edit is tracked as coming from + // the user (ChangeAuthor::User) it would have been rebased into + // the diff base already. So no unreviewed hunks should remain. + assert_eq!( + unreviewed_hunks(&action_log, cx), + vec![], + "agent edits should have been accepted after the base text update" + ); + } + #[gpui::test] async fn test_undo_last_reject(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 143f74bfa90a1bb686bc88dc9942816ca42510ee..fcb901347a12798aa8e2e40942f88b47beee011d 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -84,6 +84,12 @@ struct Session { project_id: EntityId, pending_save: Task>, _subscriptions: Vec, + ref_count: usize, +} + +struct PendingSession { + task: Shared, Arc>>>, + ref_count: usize, } pub struct LanguageModels { @@ -245,6 +251,7 @@ impl LanguageModels { pub struct NativeAgent { /// Session ID -> Session mapping sessions: HashMap, + pending_sessions: HashMap, thread_store: Entity, /// Project-specific state keyed by project EntityId projects: HashMap, @@ -278,6 +285,7 @@ impl NativeAgent { Self { sessions: HashMap::default(), + pending_sessions: HashMap::default(), thread_store, projects: HashMap::default(), templates, @@ -316,13 +324,14 @@ impl NativeAgent { ) }); - self.register_session(thread, project_id, cx) + self.register_session(thread, project_id, 1, cx) } fn register_session( &mut self, thread_handle: Entity, project_id: EntityId, + ref_count: usize, cx: &mut Context, ) -> Entity { let connection = Rc::new(NativeAgentConnection(cx.entity())); @@ -388,6 +397,7 @@ impl NativeAgent { project_id, _subscriptions: subscriptions, pending_save: Task::ready(Ok(())), + ref_count, }, ); @@ -926,27 +936,68 @@ impl NativeAgent { project: Entity, cx: &mut Context, ) -> Task>> { - if let Some(session) = self.sessions.get(&id) { + if let Some(session) = self.sessions.get_mut(&id) { + session.ref_count += 1; return Task::ready(Ok(session.acp_thread.clone())); } - let task = self.load_thread(id, project.clone(), cx); - cx.spawn(async move |this, cx| { - let thread = task.await?; - let acp_thread = this.update(cx, |this, cx| { - let project_id = this.get_or_create_project_state(&project, cx); - this.register_session(thread.clone(), project_id, cx) - })?; - let events = thread.update(cx, |thread, cx| thread.replay(cx)); - cx.update(|cx| { - NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx) + if let Some(pending) = self.pending_sessions.get_mut(&id) { + pending.ref_count += 1; + let task = pending.task.clone(); + return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) }); + } + + let task = self.load_thread(id.clone(), project.clone(), cx); + let shared_task = cx + .spawn({ + let id = id.clone(); + async move |this, cx| { + let thread = match task.await { + Ok(thread) => thread, + Err(err) => { + this.update(cx, |this, _cx| { + this.pending_sessions.remove(&id); + }) + .ok(); + return Err(Arc::new(err)); + } + }; + let acp_thread = this + .update(cx, |this, cx| { + let project_id = this.get_or_create_project_state(&project, cx); + let ref_count = this + .pending_sessions + .remove(&id) + .map_or(1, |pending| pending.ref_count); + this.register_session(thread.clone(), project_id, ref_count, cx) + }) + .map_err(Arc::new)?; + let events = thread.update(cx, |thread, cx| thread.replay(cx)); + cx.update(|cx| { + NativeAgentConnection::handle_thread_events( + events, + acp_thread.downgrade(), + cx, + ) + }) + .await + .map_err(Arc::new)?; + acp_thread.update(cx, |thread, cx| { + thread.snapshot_completed_plan(cx); + }); + Ok(acp_thread) + } }) - .await?; - acp_thread.update(cx, |thread, cx| { - thread.snapshot_completed_plan(cx); - }); - Ok(acp_thread) - }) + .shared(); + self.pending_sessions.insert( + id, + PendingSession { + task: shared_task.clone(), + ref_count: 1, + }, + ); + + cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) }) } pub fn thread_summary( @@ -968,11 +1019,43 @@ impl NativeAgent { })? .await .context("Failed to generate summary")?; + + this.update(cx, |this, cx| this.close_session(&id, cx))? + .await?; drop(acp_thread); Ok(result) }) } + fn close_session( + &mut self, + session_id: &acp::SessionId, + cx: &mut Context, + ) -> Task> { + let Some(session) = self.sessions.get_mut(session_id) else { + return Task::ready(Ok(())); + }; + + session.ref_count -= 1; + if session.ref_count > 0 { + return Task::ready(Ok(())); + } + + let thread = session.thread.clone(); + self.save_thread(thread, cx); + let Some(session) = self.sessions.remove(session_id) else { + return Task::ready(Ok(())); + }; + let project_id = session.project_id; + + let has_remaining = self.sessions.values().any(|s| s.project_id == project_id); + if !has_remaining { + self.projects.remove(&project_id); + } + + session.pending_save + } + fn save_thread(&mut self, thread: Entity, cx: &mut Context) { if thread.read(cx).is_empty() { return; @@ -1158,6 +1241,7 @@ impl NativeAgentConnection { .get_mut(&session_id) .map(|s| (s.thread.clone(), s.acp_thread.clone())) }) else { + log::error!("Session not found in run_turn: {}", session_id); return Task::ready(Err(anyhow!("Session not found"))); }; log::debug!("Found session for: {}", session_id); @@ -1452,24 +1536,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection { session_id: &acp::SessionId, cx: &mut App, ) -> Task> { - self.0.update(cx, |agent, cx| { - let thread = agent.sessions.get(session_id).map(|s| s.thread.clone()); - if let Some(thread) = thread { - agent.save_thread(thread, cx); - } - - let Some(session) = agent.sessions.remove(session_id) else { - return Task::ready(Ok(())); - }; - let project_id = session.project_id; - - let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id); - if !has_remaining { - agent.projects.remove(&project_id); - } - - session.pending_save - }) + self.0 + .update(cx, |agent, cx| agent.close_session(session_id, cx)) } fn auth_methods(&self) -> &[acp::AuthMethod] { @@ -1498,6 +1566,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::debug!("Prompt blocks count: {}", params.prompt.len()); let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else { + log::error!("Session not found in prompt: {}", session_id); + if self.0.read(cx).sessions.contains_key(&session_id) { + log::error!( + "Session found in sessions map, but not in project state: {}", + session_id + ); + } return Task::ready(Err(anyhow::anyhow!("Session not found"))); }; @@ -1812,7 +1887,7 @@ impl NativeThreadEnvironment { .get(&parent_session_id) .map(|s| s.project_id) .context("parent session not found")?; - Ok(agent.register_session(subagent_thread.clone(), project_id, cx)) + Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx)) })??; let depth = current_depth + 1; @@ -3006,6 +3081,241 @@ mod internal_tests { }); } + #[gpui::test] + async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "file.txt": "hello" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = cx.update(|cx| { + NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) + }); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_session(project.clone(), PathList::new(&[Path::new("")]), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + + let model = Arc::new(FakeLanguageModel::default()); + let summary_model = Arc::new(FakeLanguageModel::default()); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + thread.set_summarization_model(Some(summary_model.clone()), cx); + }); + + let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)); + let send = cx.foreground_executor().spawn(send); + cx.run_until_parked(); + + model.send_last_completion_stream_text_chunk("world"); + model.end_last_completion_stream(); + send.await.unwrap(); + cx.run_until_parked(); + + let summary = agent.update(cx, |agent, cx| { + agent.thread_summary(session_id.clone(), project.clone(), cx) + }); + cx.run_until_parked(); + + summary_model.send_last_completion_stream_text_chunk("summary"); + summary_model.end_last_completion_stream(); + + assert_eq!(summary.await.unwrap(), "summary"); + cx.run_until_parked(); + + agent.read_with(cx, |agent, _| { + let session = agent + .sessions + .get(&session_id) + .expect("thread_summary should not close the active session"); + assert_eq!( + session.ref_count, 1, + "thread_summary should release its temporary session reference" + ); + }); + + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + cx.run_until_parked(); + + agent.read_with(cx, |agent, _| { + assert!( + agent.sessions.is_empty(), + "closing the active session after thread_summary should unload it" + ); + }); + } + + #[gpui::test] + async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "file.txt": "hello" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = cx.update(|cx| { + NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) + }); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_session(project.clone(), PathList::new(&[Path::new("")]), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + + let model = cx.update(|cx| { + LanguageModelRegistry::read_global(cx) + .default_model() + .map(|default_model| default_model.model) + .expect("default test model should be available") + }); + let fake_model = model.as_fake(); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + }); + + let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)); + let send = cx.foreground_executor().spawn(send); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("world"); + fake_model.end_last_completion_stream(); + send.await.unwrap(); + cx.run_until_parked(); + + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + drop(thread); + drop(acp_thread); + agent.read_with(cx, |agent, _| { + assert!(agent.sessions.is_empty()); + }); + + let first_loaded_thread = cx.update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + PathList::new(&[Path::new("")]), + None, + cx, + ) + }); + let second_loaded_thread = cx.update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + PathList::new(&[Path::new("")]), + None, + cx, + ) + }); + + let first_loaded_thread = first_loaded_thread.await.unwrap(); + let second_loaded_thread = second_loaded_thread.await.unwrap(); + + cx.run_until_parked(); + + assert_eq!( + first_loaded_thread.entity_id(), + second_loaded_thread.entity_id(), + "concurrent loads for the same session should share one AcpThread" + ); + + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + + agent.read_with(cx, |agent, _| { + assert!( + agent.sessions.contains_key(&session_id), + "closing one loaded session should not drop shared session state" + ); + }); + + let follow_up = second_loaded_thread.update(cx, |thread, cx| { + thread.send(vec!["still there?".into()], cx) + }); + let follow_up = cx.foreground_executor().spawn(follow_up); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("yes"); + fake_model.end_last_completion_stream(); + follow_up.await.unwrap(); + cx.run_until_parked(); + + second_loaded_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + formatdoc! {" + ## User + + hello + + ## Assistant + + world + + ## User + + still there? + + ## Assistant + + yes + + "} + ); + }); + + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + + cx.run_until_parked(); + + drop(first_loaded_thread); + drop(second_loaded_thread); + agent.read_with(cx, |agent, _| { + assert!(agent.sessions.is_empty()); + }); + } + #[gpui::test] async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) { // Regression test: rapid title changes must not cause a propagation loop diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 5fbf1e821cb4a41f09c433ec05fdde9fbbde1a9f..85b206248c7e4ccd039bc92e911891a8cf830727 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -68,4 +68,7 @@ indoc.workspace = true acp_thread = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } gpui_tokio.workspace = true +piper = "0.2" +project = { workspace = true, features = ["test-support"] } reqwest_client = { workspace = true, features = ["test-support"] } +settings = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 54c24c91c89cde8faa4ab351aa8990b92b578050..dae7888e65a01b09699aff59a758d200c03087e3 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -9,6 +9,8 @@ use anyhow::anyhow; use collections::HashMap; use feature_flags::{AcpBetaFeatureFlag, FeatureFlagAppExt as _}; use futures::AsyncBufReadExt as _; +use futures::FutureExt as _; +use futures::future::Shared; use futures::io::BufReader; use project::agent_server_store::{AgentServerCommand, AgentServerStore}; use project::{AgentId, Project}; @@ -25,6 +27,8 @@ use util::ResultExt as _; use util::path_list::PathList; use util::process::Child; +use std::sync::Arc; + use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity}; @@ -45,19 +49,31 @@ pub struct AcpConnection { telemetry_id: SharedString, connection: Rc, sessions: Rc>>, + pending_sessions: Rc>>, auth_methods: Vec, agent_server_store: WeakEntity, agent_capabilities: acp::AgentCapabilities, default_mode: Option, default_model: Option, default_config_options: HashMap, - child: Child, + child: Option, session_list: Option>, _io_task: Task>, _wait_task: Task>, _stderr_task: Task>, } +struct PendingAcpSession { + task: Shared, Arc>>>, + ref_count: usize, +} + +struct SessionConfigResponse { + modes: Option, + models: Option, + config_options: Option>, +} + struct ConfigOptions { config_options: Rc>>, tx: Rc>>, @@ -81,6 +97,7 @@ pub struct AcpSession { models: Option>>, session_modes: Option>>, config_options: Option, + ref_count: usize, } pub struct AcpSessionList { @@ -393,6 +410,7 @@ impl AcpConnection { connection, telemetry_id, sessions, + pending_sessions: Rc::new(RefCell::new(HashMap::default())), agent_capabilities: response.agent_capabilities, default_mode, default_model, @@ -401,7 +419,7 @@ impl AcpConnection { _io_task: io_task, _wait_task: wait_task, _stderr_task: stderr_task, - child, + child: Some(child), }) } @@ -409,6 +427,143 @@ impl AcpConnection { &self.agent_capabilities.prompt_capabilities } + #[cfg(test)] + fn new_for_test( + connection: Rc, + sessions: Rc>>, + agent_capabilities: acp::AgentCapabilities, + agent_server_store: WeakEntity, + io_task: Task>, + _cx: &mut App, + ) -> Self { + Self { + id: AgentId::new("test"), + telemetry_id: "test".into(), + connection, + sessions, + pending_sessions: Rc::new(RefCell::new(HashMap::default())), + auth_methods: vec![], + agent_server_store, + agent_capabilities, + default_mode: None, + default_model: None, + default_config_options: HashMap::default(), + child: None, + session_list: None, + _io_task: io_task, + _wait_task: Task::ready(Ok(())), + _stderr_task: Task::ready(Ok(())), + } + } + + fn open_or_create_session( + self: Rc, + session_id: acp::SessionId, + project: Entity, + work_dirs: PathList, + title: Option, + rpc_call: impl FnOnce( + Rc, + acp::SessionId, + PathBuf, + ) + -> futures::future::LocalBoxFuture<'static, Result> + + 'static, + cx: &mut App, + ) -> Task>> { + if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) { + session.ref_count += 1; + if let Some(thread) = session.thread.upgrade() { + return Task::ready(Ok(thread)); + } + } + + if let Some(pending) = self.pending_sessions.borrow_mut().get_mut(&session_id) { + pending.ref_count += 1; + let task = pending.task.clone(); + return cx + .foreground_executor() + .spawn(async move { task.await.map_err(|err| anyhow!(err)) }); + } + + // TODO: remove this once ACP supports multiple working directories + let Some(cwd) = work_dirs.ordered_paths().next().cloned() else { + return Task::ready(Err(anyhow!("Working directory cannot be empty"))); + }; + + let shared_task = cx + .spawn({ + let session_id = session_id.clone(); + let this = self.clone(); + async move |cx| { + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let thread: Entity = cx.new(|cx| { + AcpThread::new( + None, + title, + Some(work_dirs), + this.clone(), + project, + action_log, + session_id.clone(), + watch::Receiver::constant( + this.agent_capabilities.prompt_capabilities.clone(), + ), + cx, + ) + }); + + let response = + match rpc_call(this.connection.clone(), session_id.clone(), cwd).await { + Ok(response) => response, + Err(err) => { + this.pending_sessions.borrow_mut().remove(&session_id); + return Err(Arc::new(err)); + } + }; + + let (modes, models, config_options) = + config_state(response.modes, response.models, response.config_options); + + if let Some(config_opts) = config_options.as_ref() { + this.apply_default_config_options(&session_id, config_opts, cx); + } + + let ref_count = this + .pending_sessions + .borrow_mut() + .remove(&session_id) + .map_or(1, |pending| pending.ref_count); + + this.sessions.borrow_mut().insert( + session_id, + AcpSession { + thread: thread.downgrade(), + suppress_abort_err: false, + session_modes: modes, + models, + config_options: config_options.map(ConfigOptions::new), + ref_count, + }, + ); + + Ok(thread) + } + }) + .shared(); + + self.pending_sessions.borrow_mut().insert( + session_id, + PendingAcpSession { + task: shared_task.clone(), + ref_count: 1, + }, + ); + + cx.foreground_executor() + .spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) }) + } + fn apply_default_config_options( &self, session_id: &acp::SessionId, @@ -508,7 +663,9 @@ impl AcpConnection { impl Drop for AcpConnection { fn drop(&mut self) { - self.child.kill().log_err(); + if let Some(ref mut child) = self.child { + child.kill().log_err(); + } } } @@ -700,6 +857,7 @@ impl AgentConnection for AcpConnection { session_modes: modes, models, config_options: config_options.map(ConfigOptions::new), + ref_count: 1, }, ); @@ -731,68 +889,30 @@ impl AgentConnection for AcpConnection { "Loading sessions is not supported by this agent.".into() )))); } - // TODO: remove this once ACP supports multiple working directories - let Some(cwd) = work_dirs.ordered_paths().next().cloned() else { - return Task::ready(Err(anyhow!("Working directory cannot be empty"))); - }; let mcp_servers = mcp_servers_for_project(&project, cx); - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let thread: Entity = cx.new(|cx| { - AcpThread::new( - None, - title, - Some(work_dirs.clone()), - self.clone(), - project, - action_log, - session_id.clone(), - watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()), - cx, - ) - }); - - self.sessions.borrow_mut().insert( - session_id.clone(), - AcpSession { - thread: thread.downgrade(), - suppress_abort_err: false, - session_modes: None, - models: None, - config_options: None, + self.open_or_create_session( + session_id, + project, + work_dirs, + title, + move |connection, session_id, cwd| { + Box::pin(async move { + let response = connection + .load_session( + acp::LoadSessionRequest::new(session_id, cwd).mcp_servers(mcp_servers), + ) + .await + .map_err(map_acp_error)?; + Ok(SessionConfigResponse { + modes: response.modes, + models: response.models, + config_options: response.config_options, + }) + }) }, - ); - - cx.spawn(async move |cx| { - let response = match self - .connection - .load_session( - acp::LoadSessionRequest::new(session_id.clone(), cwd).mcp_servers(mcp_servers), - ) - .await - { - Ok(response) => response, - Err(err) => { - self.sessions.borrow_mut().remove(&session_id); - return Err(map_acp_error(err)); - } - }; - - let (modes, models, config_options) = - config_state(response.modes, response.models, response.config_options); - - if let Some(config_opts) = config_options.as_ref() { - self.apply_default_config_options(&session_id, config_opts, cx); - } - - if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) { - session.session_modes = modes; - session.models = models; - session.config_options = config_options.map(ConfigOptions::new); - } - - Ok(thread) - }) + cx, + ) } fn resume_session( @@ -813,69 +933,31 @@ impl AgentConnection for AcpConnection { "Resuming sessions is not supported by this agent.".into() )))); } - // TODO: remove this once ACP supports multiple working directories - let Some(cwd) = work_dirs.ordered_paths().next().cloned() else { - return Task::ready(Err(anyhow!("Working directory cannot be empty"))); - }; let mcp_servers = mcp_servers_for_project(&project, cx); - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let thread: Entity = cx.new(|cx| { - AcpThread::new( - None, - title, - Some(work_dirs), - self.clone(), - project, - action_log, - session_id.clone(), - watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()), - cx, - ) - }); - - self.sessions.borrow_mut().insert( - session_id.clone(), - AcpSession { - thread: thread.downgrade(), - suppress_abort_err: false, - session_modes: None, - models: None, - config_options: None, + self.open_or_create_session( + session_id, + project, + work_dirs, + title, + move |connection, session_id, cwd| { + Box::pin(async move { + let response = connection + .resume_session( + acp::ResumeSessionRequest::new(session_id, cwd) + .mcp_servers(mcp_servers), + ) + .await + .map_err(map_acp_error)?; + Ok(SessionConfigResponse { + modes: response.modes, + models: response.models, + config_options: response.config_options, + }) + }) }, - ); - - cx.spawn(async move |cx| { - let response = match self - .connection - .resume_session( - acp::ResumeSessionRequest::new(session_id.clone(), cwd) - .mcp_servers(mcp_servers), - ) - .await - { - Ok(response) => response, - Err(err) => { - self.sessions.borrow_mut().remove(&session_id); - return Err(map_acp_error(err)); - } - }; - - let (modes, models, config_options) = - config_state(response.modes, response.models, response.config_options); - - if let Some(config_opts) = config_options.as_ref() { - self.apply_default_config_options(&session_id, config_opts, cx); - } - - if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) { - session.session_modes = modes; - session.models = models; - session.config_options = config_options.map(ConfigOptions::new); - } - - Ok(thread) - }) + cx, + ) } fn supports_close_session(&self) -> bool { @@ -893,12 +975,24 @@ impl AgentConnection for AcpConnection { )))); } + let mut sessions = self.sessions.borrow_mut(); + let Some(session) = sessions.get_mut(session_id) else { + return Task::ready(Ok(())); + }; + + session.ref_count -= 1; + if session.ref_count > 0 { + return Task::ready(Ok(())); + } + + sessions.remove(session_id); + drop(sessions); + let conn = self.connection.clone(); let session_id = session_id.clone(); cx.foreground_executor().spawn(async move { - conn.close_session(acp::CloseSessionRequest::new(session_id.clone())) + conn.close_session(acp::CloseSessionRequest::new(session_id)) .await?; - self.sessions.borrow_mut().remove(&session_id); Ok(()) }) } @@ -1112,6 +1206,8 @@ fn map_acp_error(err: acp::Error) -> anyhow::Error { #[cfg(test)] mod tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + use super::*; #[test] @@ -1240,6 +1336,241 @@ mod tests { ); assert_eq!(task.label, "Login"); } + + struct FakeAcpAgent { + load_session_count: Arc, + close_session_count: Arc, + } + + #[async_trait::async_trait(?Send)] + impl acp::Agent for FakeAcpAgent { + async fn initialize( + &self, + args: acp::InitializeRequest, + ) -> acp::Result { + Ok( + acp::InitializeResponse::new(args.protocol_version).agent_capabilities( + acp::AgentCapabilities::default() + .load_session(true) + .session_capabilities( + acp::SessionCapabilities::default() + .close(acp::SessionCloseCapabilities::new()), + ), + ), + ) + } + + async fn authenticate( + &self, + _: acp::AuthenticateRequest, + ) -> acp::Result { + Ok(Default::default()) + } + + async fn new_session( + &self, + _: acp::NewSessionRequest, + ) -> acp::Result { + Ok(acp::NewSessionResponse::new(acp::SessionId::new("unused"))) + } + + async fn prompt(&self, _: acp::PromptRequest) -> acp::Result { + Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) + } + + async fn cancel(&self, _: acp::CancelNotification) -> acp::Result<()> { + Ok(()) + } + + async fn load_session( + &self, + _: acp::LoadSessionRequest, + ) -> acp::Result { + self.load_session_count.fetch_add(1, Ordering::SeqCst); + Ok(acp::LoadSessionResponse::new()) + } + + async fn close_session( + &self, + _: acp::CloseSessionRequest, + ) -> acp::Result { + self.close_session_count.fetch_add(1, Ordering::SeqCst); + Ok(acp::CloseSessionResponse::new()) + } + } + + async fn connect_fake_agent( + cx: &mut gpui::TestAppContext, + ) -> ( + Rc, + Entity, + Arc, + Arc, + Task>, + ) { + cx.update(|cx| { + let store = settings::SettingsStore::test(cx); + cx.set_global(store); + }); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/", serde_json::json!({ "a": {} })).await; + let project = project::Project::test(fs, [std::path::Path::new("/a")], cx).await; + + let load_count = Arc::new(AtomicUsize::new(0)); + let close_count = Arc::new(AtomicUsize::new(0)); + + let (c2a_reader, c2a_writer) = piper::pipe(4096); + let (a2c_reader, a2c_writer) = piper::pipe(4096); + + let sessions: Rc>> = + Rc::new(RefCell::new(HashMap::default())); + let session_list_container: Rc>>> = + Rc::new(RefCell::new(None)); + + let foreground = cx.foreground_executor().clone(); + + let client_delegate = ClientDelegate { + sessions: sessions.clone(), + session_list: session_list_container, + cx: cx.to_async(), + }; + + let (client_conn, client_io_task) = + acp::ClientSideConnection::new(client_delegate, c2a_writer, a2c_reader, { + let foreground = foreground.clone(); + move |fut| { + foreground.spawn(fut).detach(); + } + }); + + let fake_agent = FakeAcpAgent { + load_session_count: load_count.clone(), + close_session_count: close_count.clone(), + }; + + let (_, agent_io_task) = + acp::AgentSideConnection::new(fake_agent, a2c_writer, c2a_reader, { + let foreground = foreground.clone(); + move |fut| { + foreground.spawn(fut).detach(); + } + }); + + let client_io_task = cx.background_spawn(client_io_task); + let agent_io_task = cx.background_spawn(agent_io_task); + + let response = client_conn + .initialize(acp::InitializeRequest::new(acp::ProtocolVersion::V1)) + .await + .expect("failed to initialize ACP connection"); + + let agent_capabilities = response.agent_capabilities; + + let agent_server_store = + project.read_with(cx, |project, _| project.agent_server_store().downgrade()); + + let connection = cx.update(|cx| { + AcpConnection::new_for_test( + Rc::new(client_conn), + sessions, + agent_capabilities, + agent_server_store, + client_io_task, + cx, + ) + }); + + let keep_agent_alive = cx.background_spawn(async move { + agent_io_task.await.ok(); + anyhow::Ok(()) + }); + + ( + Rc::new(connection), + project, + load_count, + close_count, + keep_agent_alive, + ) + } + + #[gpui::test] + async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut gpui::TestAppContext) { + let (connection, project, load_count, close_count, _keep_agent_alive) = + connect_fake_agent(cx).await; + + let session_id = acp::SessionId::new("session-1"); + let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]); + + // Load the same session twice concurrently — the second call should join + // the pending task rather than issuing a second ACP load_session RPC. + let first_load = cx.update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + work_dirs.clone(), + None, + cx, + ) + }); + let second_load = cx.update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + work_dirs.clone(), + None, + cx, + ) + }); + + let first_thread = first_load.await.expect("first load failed"); + let second_thread = second_load.await.expect("second load failed"); + cx.run_until_parked(); + + assert_eq!( + first_thread.entity_id(), + second_thread.entity_id(), + "concurrent loads for the same session should share one AcpThread" + ); + assert_eq!( + load_count.load(Ordering::SeqCst), + 1, + "underlying ACP load_session should be called exactly once for concurrent loads" + ); + + // The session has ref_count 2. The first close should not send the ACP + // close_session RPC — the session is still referenced. + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .expect("first close failed"); + + assert_eq!( + close_count.load(Ordering::SeqCst), + 0, + "ACP close_session should not be sent while ref_count > 0" + ); + assert!( + connection.sessions.borrow().contains_key(&session_id), + "session should still be tracked after first close" + ); + + // The second close drops ref_count to 0 — now the ACP RPC must be sent. + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .expect("second close failed"); + cx.run_until_parked(); + + assert_eq!( + close_count.load(Ordering::SeqCst), + 1, + "ACP close_session should be sent exactly once when ref_count reaches 0" + ); + assert!( + !connection.sessions.borrow().contains_key(&session_id), + "session should be removed after final close" + ); + } } fn mcp_servers_for_project(project: &Entity, cx: &App) -> Vec { diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 413bad667825f05ad6b399677877fb2ec99cb7c9..005476834089ad095aab5784fca6881f6124d9ba 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -246,7 +246,7 @@ pub fn init(cx: &mut App) { .and_then(|conversation| { conversation .read(cx) - .active_thread() + .root_thread_view() .map(|r| r.read(cx).thread.clone()) }); @@ -763,7 +763,6 @@ pub struct AgentPanel { agent_layout_onboarding: Entity, agent_layout_onboarding_dismissed: AtomicBool, selected_agent: Agent, - pending_thread_loads: usize, worktree_creation_status: Option<(EntityId, WorktreeCreationStatus)>, _thread_view_subscription: Option, _active_thread_focus_subscription: Option, @@ -800,7 +799,7 @@ impl AgentPanel { Some( conversation .read(cx) - .root_thread(cx)? + .root_thread_view()? .read(cx) .thread .read(cx) @@ -1172,7 +1171,6 @@ impl AgentPanel { agent_layout_onboarding, thread_store, selected_agent: Agent::default(), - pending_thread_loads: 0, worktree_creation_status: None, _thread_view_subscription: None, _active_thread_focus_subscription: None, @@ -1328,16 +1326,37 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) -> Entity { - if let Some(draft) = &self.draft_thread { - return draft.clone(); - } - let agent = if self.project.read(cx).is_via_collab() { + let desired_agent = if self.project.read(cx).is_via_collab() { Agent::NativeAgent } else { self.selected_agent.clone() }; - let thread = - self.create_agent_thread(agent, None, None, None, None, "agent_panel", window, cx); + if let Some(draft) = &self.draft_thread { + let agent_matches = *draft.read(cx).agent_key() == desired_agent; + let has_editor_content = draft.read(cx).root_thread_view().is_some_and(|tv| { + !tv.read(cx) + .message_editor + .read(cx) + .text(cx) + .trim() + .is_empty() + }); + if agent_matches || has_editor_content { + return draft.clone(); + } + self.draft_thread = None; + self._draft_editor_observation = None; + } + let thread = self.create_agent_thread( + desired_agent, + None, + None, + None, + None, + "agent_panel", + window, + cx, + ); self.draft_thread = Some(thread.conversation_view.clone()); self.observe_draft_editor(&thread.conversation_view, cx); thread.conversation_view @@ -1348,7 +1367,7 @@ impl AgentPanel { conversation_view: &Entity, cx: &mut Context, ) { - if let Some(acp_thread) = conversation_view.read(cx).root_acp_thread(cx) { + if let Some(acp_thread) = conversation_view.read(cx).root_thread(cx) { self._draft_editor_observation = Some(cx.subscribe( &acp_thread, |this, _, e: &AcpThreadEvent, cx| { @@ -1360,7 +1379,7 @@ impl AgentPanel { } else { let cv = conversation_view.clone(); self._draft_editor_observation = Some(cx.observe(&cv, |this, cv, cx| { - if cv.read(cx).root_acp_thread(cx).is_some() { + if cv.read(cx).root_thread(cx).is_some() { this.observe_draft_editor(&cv, cx); } })); @@ -1448,7 +1467,7 @@ impl AgentPanel { } _ => None, })?; - let tv = cv.read(cx).active_thread()?; + let tv = cv.read(cx).root_thread_view()?; let text = tv.read(cx).message_editor.read(cx).text(cx); if text.trim().is_empty() { None @@ -1470,7 +1489,7 @@ impl AgentPanel { _ => None, }); let Some(cv) = cv else { return }; - let Some(tv) = cv.read(cx).active_thread() else { + let Some(tv) = cv.read(cx).root_thread_view() else { return; }; let editor = tv.read(cx).message_editor.clone(); @@ -1608,7 +1627,7 @@ impl AgentPanel { return; }; - let Some(active_thread) = conversation_view.read(cx).active_thread().cloned() else { + let Some(active_thread) = conversation_view.read(cx).root_thread_view() else { return; }; @@ -2078,15 +2097,14 @@ impl AgentPanel { pub fn active_thread_view(&self, cx: &App) -> Option> { let server_view = self.active_conversation_view()?; - server_view.read(cx).active_thread().cloned() + server_view.read(cx).root_thread_view() } pub fn active_agent_thread(&self, cx: &App) -> Option> { match &self.base_view { - BaseView::AgentThread { conversation_view } => conversation_view - .read(cx) - .active_thread() - .map(|r| r.read(cx).thread.clone()), + BaseView::AgentThread { conversation_view } => { + conversation_view.read(cx).root_thread(cx) + } _ => None, } } @@ -2103,7 +2121,7 @@ impl AgentPanel { for conversation_view in conversation_views { if *thread_id == conversation_view.read(cx).thread_id { - if let Some(thread_view) = conversation_view.read(cx).root_thread_view(cx) { + if let Some(thread_view) = conversation_view.read(cx).root_thread_view() { thread_view.update(cx, |view, cx| view.cancel_generation(cx)); return true; } @@ -2118,13 +2136,13 @@ impl AgentPanel { let mut views = Vec::new(); if let Some(server_view) = self.active_conversation_view() { - if let Some(thread_view) = server_view.read(cx).root_thread(cx) { + if let Some(thread_view) = server_view.read(cx).root_thread_view() { views.push(thread_view); } } for server_view in self.retained_threads.values() { - if let Some(thread_view) = server_view.read(cx).root_thread(cx) { + if let Some(thread_view) = server_view.read(cx).root_thread_view() { views.push(thread_view); } } @@ -2148,6 +2166,10 @@ impl AgentPanel { }); } + if self.project.read(cx).is_via_collab() { + return; + } + // Update metadata store so threads' path lists stay in sync with // the project's current worktrees. Without this, threads saved // before a worktree was added would have stale paths and not @@ -2191,7 +2213,7 @@ impl AgentPanel { .retained_threads .iter() .filter(|(_id, view)| { - let Some(thread_view) = view.read(cx).root_thread(cx) else { + let Some(thread_view) = view.read(cx).root_thread_view() else { return true; }; let thread = thread_view.read(cx).thread.read(cx); @@ -2418,7 +2440,7 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) -> Option { - server_view.read(cx).active_thread().cloned().map(|tv| { + server_view.read(cx).root_thread_view().map(|tv| { cx.subscribe_in( &tv, window, @@ -2515,10 +2537,6 @@ impl AgentPanel { ); } - pub fn begin_loading_thread(&mut self) { - self.pending_thread_loads += 1; - } - pub fn load_agent_thread( &mut self, agent: Agent, @@ -2530,7 +2548,6 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) { - self.pending_thread_loads = self.pending_thread_loads.saturating_sub(1); if let Some(store) = ThreadMetadataStore::try_global(cx) { let thread_id = store .read(cx) @@ -2545,8 +2562,9 @@ impl AgentPanel { let has_session = |cv: &Entity| -> bool { cv.read(cx) - .active_thread() - .is_some_and(|tv| tv.read(cx).thread.read(cx).session_id() == &session_id) + .root_session_id + .as_ref() + .is_some_and(|id| id == &session_id) }; // Check if the active view already has this session. @@ -3852,12 +3870,12 @@ impl AgentPanel { VisibleSurface::AgentThread(conversation_view) => { let server_view_ref = conversation_view.read(cx); let is_generating_title = server_view_ref.as_native_thread(cx).is_some() - && server_view_ref.root_thread(cx).map_or(false, |tv| { + && server_view_ref.root_thread_view().map_or(false, |tv| { tv.read(cx).thread.read(cx).has_provisional_title() }); if let Some(title_editor) = server_view_ref - .root_thread(cx) + .root_thread_view() .map(|r| r.read(cx).title_editor.clone()) { if is_generating_title { @@ -6921,7 +6939,7 @@ mod tests { // Verify thread A's (background) work_dirs are also updated. let updated_a_paths = panel.read_with(&cx, |panel, cx| { let bg_view = panel.retained_threads.get(&thread_id_a).unwrap(); - let root_thread = bg_view.read(cx).root_thread(cx).unwrap(); + let root_thread = bg_view.read(cx).root_thread_view().unwrap(); root_thread .read(cx) .thread @@ -6941,7 +6959,7 @@ mod tests { // Verify thread idle C was also updated. let updated_c_paths = panel.read_with(&cx, |panel, cx| { let bg_view = panel.retained_threads.get(&thread_id_c).unwrap(); - let root_thread = bg_view.read(cx).root_thread(cx).unwrap(); + let root_thread = bg_view.read(cx).root_thread_view().unwrap(); root_thread .read(cx) .thread @@ -6995,7 +7013,7 @@ mod tests { let after_remove_a = panel.read_with(&cx, |panel, cx| { let bg_view = panel.retained_threads.get(&thread_id_a).unwrap(); - let root_thread = bg_view.read(cx).root_thread(cx).unwrap(); + let root_thread = bg_view.read(cx).root_thread_view().unwrap(); root_thread .read(cx) .thread @@ -7215,6 +7233,95 @@ mod tests { }); } + #[gpui::test] + async fn test_draft_replaced_when_selected_agent_changes(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + cx.update(|cx| { + agent::ThreadStore::init_global(cx); + language_model::LanguageModelRegistry::test(cx); + ::set_global(fs.clone(), cx); + }); + + let project = Project::test(fs.clone(), [], cx).await; + + let multi_workspace = + cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + + let workspace = multi_workspace + .read_with(cx, |multi_workspace, _cx| { + multi_workspace.workspace().clone() + }) + .unwrap(); + + workspace.update(cx, |workspace, _cx| { + workspace.set_random_database_id(); + }); + + let cx = &mut VisualTestContext::from_window(multi_workspace.into(), cx); + + let panel = workspace.update_in(cx, |workspace, window, cx| { + let panel = cx.new(|cx| AgentPanel::new(workspace, None, window, cx)); + workspace.add_panel(panel.clone(), window, cx); + panel + }); + + // Create a draft with the default NativeAgent. + panel.update_in(cx, |panel, window, cx| { + panel.activate_draft(true, window, cx); + }); + + let first_draft_id = panel.read_with(cx, |panel, cx| { + assert!(panel.draft_thread.is_some()); + assert_eq!(panel.selected_agent, Agent::NativeAgent); + let draft = panel.draft_thread.as_ref().unwrap(); + assert_eq!(*draft.read(cx).agent_key(), Agent::NativeAgent); + draft.entity_id() + }); + + // Switch selected_agent to a custom agent, then activate_draft again. + // The stale NativeAgent draft should be replaced. + let custom_agent = Agent::Custom { + id: "my-custom-agent".into(), + }; + panel.update_in(cx, |panel, window, cx| { + panel.selected_agent = custom_agent.clone(); + panel.activate_draft(true, window, cx); + }); + + panel.read_with(cx, |panel, cx| { + let draft = panel.draft_thread.as_ref().expect("draft should exist"); + assert_ne!( + draft.entity_id(), + first_draft_id, + "a new draft should have been created" + ); + assert_eq!( + *draft.read(cx).agent_key(), + custom_agent, + "the new draft should use the custom agent" + ); + }); + + // Calling activate_draft again with the same agent should return the + // cached draft (no replacement). + let second_draft_id = panel.read_with(cx, |panel, _cx| { + panel.draft_thread.as_ref().unwrap().entity_id() + }); + + panel.update_in(cx, |panel, window, cx| { + panel.activate_draft(true, window, cx); + }); + + panel.read_with(cx, |panel, _cx| { + assert_eq!( + panel.draft_thread.as_ref().unwrap().entity_id(), + second_draft_id, + "draft should be reused when the agent has not changed" + ); + }); + } + #[gpui::test] async fn test_rollback_all_succeed_returns_ok(cx: &mut TestAppContext) { init_test(cx); @@ -8059,4 +8166,382 @@ mod tests { ); }); } + + /// Connection that tracks closed sessions and detects prompts against + /// sessions that no longer exist, used to reproduce session disassociation. + #[derive(Clone, Default)] + struct DisassociationTrackingConnection { + next_session_number: Arc>, + sessions: Arc>>, + closed_sessions: Arc>>, + missing_prompt_sessions: Arc>>, + } + + impl DisassociationTrackingConnection { + fn new() -> Self { + Self::default() + } + + fn create_session( + self: Rc, + session_id: acp::SessionId, + project: Entity, + work_dirs: PathList, + title: Option, + cx: &mut App, + ) -> Entity { + self.sessions.lock().insert(session_id.clone()); + + let action_log = cx.new(|_| ActionLog::new(project.clone())); + cx.new(|cx| { + AcpThread::new( + None, + title, + Some(work_dirs), + self, + project, + action_log, + session_id, + watch::Receiver::constant( + acp::PromptCapabilities::new() + .image(true) + .audio(true) + .embedded_context(true), + ), + cx, + ) + }) + } + } + + impl AgentConnection for DisassociationTrackingConnection { + fn agent_id(&self) -> AgentId { + agent::ZED_AGENT_ID.clone() + } + + fn telemetry_id(&self) -> SharedString { + "disassociation-tracking-test".into() + } + + fn new_session( + self: Rc, + project: Entity, + work_dirs: PathList, + cx: &mut App, + ) -> Task>> { + let session_id = { + let mut next_session_number = self.next_session_number.lock(); + let session_id = acp::SessionId::new(format!( + "disassociation-tracking-session-{}", + *next_session_number + )); + *next_session_number += 1; + session_id + }; + let thread = self.create_session(session_id, project, work_dirs, None, cx); + Task::ready(Ok(thread)) + } + + fn supports_load_session(&self) -> bool { + true + } + + fn load_session( + self: Rc, + session_id: acp::SessionId, + project: Entity, + work_dirs: PathList, + title: Option, + cx: &mut App, + ) -> Task>> { + let thread = self.create_session(session_id, project, work_dirs, title, cx); + thread.update(cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::UserMessageChunk(acp::ContentChunk::new( + "Restored user message".into(), + )), + cx, + ) + .expect("restored user message should be applied"); + thread + .handle_session_update( + acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new( + "Restored assistant message".into(), + )), + cx, + ) + .expect("restored assistant message should be applied"); + }); + Task::ready(Ok(thread)) + } + + fn supports_close_session(&self) -> bool { + true + } + + fn close_session( + self: Rc, + session_id: &acp::SessionId, + _cx: &mut App, + ) -> Task> { + self.sessions.lock().remove(session_id); + self.closed_sessions.lock().push(session_id.clone()); + Task::ready(Ok(())) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate(&self, _method_id: acp::AuthMethodId, _cx: &mut App) -> Task> { + Task::ready(Ok(())) + } + + fn prompt( + &self, + _id: UserMessageId, + params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { + if !self.sessions.lock().contains(¶ms.session_id) { + self.missing_prompt_sessions.lock().push(params.session_id); + return Task::ready(Err(anyhow!("Session not found"))); + } + + Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {} + + fn into_any(self: Rc) -> Rc { + self + } + } + + async fn setup_workspace_panel( + cx: &mut TestAppContext, + ) -> (Entity, Entity, VisualTestContext) { + init_test(cx); + cx.update(|cx| { + agent::ThreadStore::init_global(cx); + language_model::LanguageModelRegistry::test(cx); + }); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + + let multi_workspace = + cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + + let workspace = multi_workspace + .read_with(cx, |mw, _cx| mw.workspace().clone()) + .unwrap(); + + let mut cx = VisualTestContext::from_window(multi_workspace.into(), cx); + + let panel = workspace.update_in(&mut cx, |workspace, window, cx| { + let panel = cx.new(|cx| AgentPanel::new(workspace, None, window, cx)); + workspace.add_panel(panel.clone(), window, cx); + panel + }); + + (workspace, panel, cx) + } + + /// Reproduces the retained-thread reset race: + /// + /// 1. Thread A is active and Connected. + /// 2. User switches to thread B → A goes to retained_threads. + /// 3. A thread_error is set on retained A's thread view. + /// 4. AgentServersUpdated fires → retained A's handle_agent_servers_updated + /// sees has_thread_error=true → calls reset() → close_all_sessions → + /// session X removed, state = Loading. + /// 5. User reopens thread X via open_thread → load_agent_thread checks + /// retained A's has_session → returns false (state is Loading) → + /// creates new ConversationView C. + /// 6. Both A's reload task and C's load task complete → both call + /// load_session(X) → both get Connected with session X. + /// 7. A is eventually cleaned up → on_release → close_all_sessions → + /// removes session X. + /// 8. C sends → "Session not found". + #[gpui::test] + async fn test_retained_thread_reset_race_disassociates_session(cx: &mut TestAppContext) { + let (_workspace, panel, mut cx) = setup_workspace_panel(cx).await; + cx.run_until_parked(); + + let connection = DisassociationTrackingConnection::new(); + panel.update(&mut cx, |panel, cx| { + panel.connection_store.update(cx, |store, cx| { + store.restart_connection( + Agent::Stub, + Rc::new(StubAgentServer::new(connection.clone())), + cx, + ); + }); + }); + cx.run_until_parked(); + + // Step 1: Open thread A and send a message. + panel.update_in(&mut cx, |panel, window, cx| { + panel.external_thread( + Some(Agent::Stub), + None, + None, + None, + None, + true, + "agent_panel", + window, + cx, + ); + }); + cx.run_until_parked(); + send_message(&panel, &mut cx); + + let session_id_a = active_session_id(&panel, &cx); + let _thread_id_a = active_thread_id(&panel, &cx); + + // Step 2: Open thread B → A goes to retained_threads. + panel.update_in(&mut cx, |panel, window, cx| { + panel.external_thread( + Some(Agent::Stub), + None, + None, + None, + None, + true, + "agent_panel", + window, + cx, + ); + }); + cx.run_until_parked(); + send_message(&panel, &mut cx); + + // Confirm A is retained. + panel.read_with(&cx, |panel, _cx| { + assert!( + panel.retained_threads.contains_key(&_thread_id_a), + "thread A should be in retained_threads after switching to B" + ); + }); + + // Step 3: Set a thread_error on retained A's active thread view. + // This simulates an API error that occurred before the user switched + // away, or a transient failure. + let retained_conversation_a = panel.read_with(&cx, |panel, _cx| { + panel + .retained_threads + .get(&_thread_id_a) + .expect("thread A should be retained") + .clone() + }); + retained_conversation_a.update(&mut cx, |conversation, cx| { + if let Some(thread_view) = conversation.active_thread() { + thread_view.update(cx, |view, cx| { + view.handle_thread_error( + crate::conversation_view::ThreadError::Other { + message: "simulated error".into(), + acp_error_code: None, + }, + cx, + ); + }); + } + }); + + // Confirm the thread error is set. + retained_conversation_a.read_with(&cx, |conversation, cx| { + let connected = conversation.as_connected().expect("should be connected"); + assert!( + connected.has_thread_error(cx), + "retained A should have a thread error" + ); + }); + + // Step 4: Emit AgentServersUpdated → retained A's + // handle_agent_servers_updated sees has_thread_error=true, + // calls reset(), which closes session X and sets state=Loading. + // + // Critically, we do NOT call run_until_parked between the emit + // and open_thread. The emit's synchronous effects (event delivery + // → reset() → close_all_sessions → state=Loading) happen during + // the update's flush_effects. But the async reload task spawned + // by initial_state has NOT been polled yet. + panel.update(&mut cx, |panel, cx| { + panel.project.update(cx, |project, cx| { + project + .agent_server_store() + .update(cx, |_store, cx| cx.emit(project::AgentServersUpdated)); + }); + }); + // After this update returns, the retained ConversationView is in + // Loading state (reset ran synchronously), but its async reload + // task hasn't executed yet. + + // Step 5: Immediately open thread X via open_thread, BEFORE + // the retained view's async reload completes. load_agent_thread + // checks retained A's has_session → returns false (state is + // Loading) → creates a NEW ConversationView C for session X. + panel.update_in(&mut cx, |panel, window, cx| { + panel.open_thread(session_id_a.clone(), None, None, window, cx); + }); + + // NOW settle everything: both async tasks (A's reload and C's load) + // complete, both register session X. + cx.run_until_parked(); + + // Verify session A is the active session via C. + panel.read_with(&cx, |panel, cx| { + let active_session = panel + .active_agent_thread(cx) + .map(|t| t.read(cx).session_id().clone()); + assert_eq!( + active_session, + Some(session_id_a.clone()), + "session A should be the active session after open_thread" + ); + }); + + // Step 6: Force the retained ConversationView A to be dropped + // while the active view (C) still has the same session. + // We can't use remove_thread because C shares the same ThreadId + // and remove_thread would kill the active view too. Instead, + // directly remove from retained_threads and drop the handle + // so on_release → close_all_sessions fires only on A. + drop(retained_conversation_a); + panel.update(&mut cx, |panel, _cx| { + panel.retained_threads.remove(&_thread_id_a); + }); + cx.run_until_parked(); + + // The key assertion: sending messages on the ACTIVE view (C) + // must succeed. If the session was disassociated by A's cleanup, + // this will fail with "Session not found". + send_message(&panel, &mut cx); + send_message(&panel, &mut cx); + + let missing = connection.missing_prompt_sessions.lock().clone(); + assert!( + missing.is_empty(), + "session should not be disassociated after retained thread reset race, \ + got missing prompt sessions: {:?}", + missing + ); + + panel.read_with(&cx, |panel, cx| { + let active_view = panel + .active_conversation_view() + .expect("conversation should remain open"); + let connected = active_view + .read(cx) + .as_connected() + .expect("conversation should be connected"); + assert!( + !connected.has_thread_error(cx), + "conversation should not have a thread error" + ); + }); + } } diff --git a/crates/agent_ui/src/conversation_view.rs b/crates/agent_ui/src/conversation_view.rs index 4607f49190f1517180a08f4816df88ebd6d05662..bb19274711b5e654cab775c32bd6766b5d84b1f5 100644 --- a/crates/agent_ui/src/conversation_view.rs +++ b/crates/agent_ui/src/conversation_view.rs @@ -433,7 +433,7 @@ pub struct ConversationView { thread_store: Option>, prompt_store: Option>, pub(crate) thread_id: ThreadId, - root_session_id: Option, + pub(crate) root_session_id: Option, server_state: ServerState, focus_handle: FocusHandle, notifications: Vec>, @@ -460,13 +460,7 @@ impl ConversationView { &'a self, cx: &'a App, ) -> Option<(acp::SessionId, acp::ToolCallId, &'a PermissionOptions)> { - let session_id = self - .active_thread()? - .read(cx) - .thread - .read(cx) - .session_id() - .clone(); + let session_id = self.active_thread()?.read(cx).session_id.clone(); self.as_connected()? .conversation .read(cx) @@ -474,7 +468,7 @@ impl ConversationView { } pub fn root_thread_has_pending_tool_call(&self, cx: &App) -> bool { - let Some(root_thread) = self.root_thread(cx) else { + let Some(root_thread) = self.root_thread_view() else { return false; }; let root_session_id = root_thread.read(cx).thread.read(cx).session_id().clone(); @@ -487,47 +481,18 @@ impl ConversationView { }) } - pub fn root_thread(&self, cx: &App) -> Option> { - match &self.server_state { - ServerState::Connected(connected) => { - let mut current = connected.active_view()?; - while let Some(parent_session_id) = - current.read(cx).thread.read(cx).parent_session_id() - { - if let Some(parent) = connected.threads.get(parent_session_id) { - current = parent; - } else { - break; - } - } - Some(current.clone()) - } - _ => None, - } - } - - pub(crate) fn root_acp_thread(&self, cx: &App) -> Option> { - let connected = self.as_connected()?; - let root_session_id = self.root_session_id.as_ref()?; - connected - .conversation - .read(cx) - .threads - .get(root_session_id) - .cloned() + pub(crate) fn root_thread(&self, cx: &App) -> Option> { + self.root_thread_view() + .map(|view| view.read(cx).thread.clone()) } - pub fn root_thread_view(&self, cx: &App) -> Option> { + pub fn root_thread_view(&self) -> Option> { self.root_session_id .as_ref() - .and_then(|sid| self.thread_view(sid, cx)) + .and_then(|id| self.thread_view(id)) } - pub fn thread_view( - &self, - session_id: &acp::SessionId, - _cx: &App, - ) -> Option> { + pub fn thread_view(&self, session_id: &acp::SessionId) -> Option> { let connected = self.as_connected()?; connected.threads.get(session_id).cloned() } @@ -703,7 +668,7 @@ impl ConversationView { thread_store, prompt_store, thread_id, - root_session_id: None, + root_session_id: resume_session_id.clone(), server_state: Self::initial_state( agent.clone(), connection_store, @@ -737,7 +702,7 @@ impl ConversationView { fn reset(&mut self, window: &mut Window, cx: &mut Context) { let (resume_session_id, cwd, title) = self - .active_thread() + .root_thread_view() .map(|thread_view| { let tv = thread_view.read(cx); let thread = tv.thread.read(cx); @@ -764,7 +729,7 @@ impl ConversationView { ); self.set_server_state(state, cx); - if let Some(view) = self.active_thread() { + if let Some(view) = self.root_thread_view() { view.update(cx, |this, cx| { this.message_editor.update(cx, |editor, cx| { editor.set_session_capabilities(this.session_capabilities.clone(), cx); @@ -805,7 +770,7 @@ impl ConversationView { let connection_entry_subscription = cx.subscribe(&connection_entry, |this, _entry, event, cx| match event { AgentConnectionEntryEvent::NewVersionAvailable(version) => { - if let Some(thread) = this.active_thread() { + if let Some(thread) = this.root_thread_view() { thread.update(cx, |thread, cx| { thread.new_server_version_available = Some(version.clone()); cx.notify(); @@ -1259,7 +1224,7 @@ impl ConversationView { } fn handle_load_error(&mut self, err: LoadError, window: &mut Window, cx: &mut Context) { - if let Some(view) = self.active_thread() { + if let Some(view) = self.root_thread_view() { if view .read(cx) .message_editor @@ -1292,7 +1257,7 @@ impl ConversationView { }; if should_retry { - if let Some(active) = self.active_thread() { + if let Some(active) = self.root_thread_view() { active.update(cx, |active, cx| { active.clear_thread_error(cx); }); @@ -1345,14 +1310,6 @@ impl ConversationView { matches!(self.server_state, ServerState::Loading { .. }) } - fn update_turn_tokens(&mut self, cx: &mut Context) { - if let Some(active) = self.active_thread() { - active.update(cx, |active, cx| { - active.update_turn_tokens(cx); - }); - } - } - fn send_queued_message_at_index( &mut self, index: usize, @@ -1360,7 +1317,7 @@ impl ConversationView { window: &mut Window, cx: &mut Context, ) { - if let Some(active) = self.active_thread() { + if let Some(active) = self.root_thread_view() { active.update(cx, |active, cx| { active.send_queued_message_at_index(index, is_send_now, window, cx); }); @@ -1375,7 +1332,7 @@ impl ConversationView { window: &mut Window, cx: &mut Context, ) { - if let Some(active) = self.active_thread() { + if let Some(active) = self.root_thread_view() { active.update(cx, |active, cx| { active.move_queued_message_to_main_editor( index, @@ -1410,7 +1367,7 @@ impl ConversationView { AcpThreadEvent::NewEntry => { let len = thread.read(cx).entries().len(); let index = len - 1; - if let Some(active) = self.thread_view(&session_id, cx) { + if let Some(active) = self.thread_view(&session_id) { let entry_view_state = active.read(cx).entry_view_state.clone(); let list_state = active.read(cx).list_state.clone(); entry_view_state.update(cx, |view_state, cx| { @@ -1428,7 +1385,7 @@ impl ConversationView { } } AcpThreadEvent::EntryUpdated(index) => { - if let Some(active) = self.thread_view(&session_id, cx) { + if let Some(active) = self.thread_view(&session_id) { let entry_view_state = active.read(cx).entry_view_state.clone(); let list_state = active.read(cx).list_state.clone(); entry_view_state.update(cx, |view_state, cx| { @@ -1441,7 +1398,7 @@ impl ConversationView { } } AcpThreadEvent::EntriesRemoved(range) => { - if let Some(active) = self.thread_view(&session_id, cx) { + if let Some(active) = self.thread_view(&session_id) { let entry_view_state = active.read(cx).entry_view_state.clone(); let list_state = active.read(cx).list_state.clone(); entry_view_state.update(cx, |view_state, _cx| view_state.remove(range.clone())); @@ -1459,14 +1416,14 @@ impl ConversationView { } AcpThreadEvent::ToolAuthorizationReceived(_) => {} AcpThreadEvent::Retry(retry) => { - if let Some(active) = self.thread_view(&session_id, cx) { + if let Some(active) = self.thread_view(&session_id) { active.update(cx, |active, _cx| { active.thread_retry_status = Some(retry.clone()); }); } } AcpThreadEvent::Stopped(stop_reason) => { - if let Some(active) = self.thread_view(&session_id, cx) { + if let Some(active) = self.thread_view(&session_id) { let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating); active.update(cx, |active, cx| { @@ -1501,7 +1458,7 @@ impl ConversationView { cx, ); - let should_send_queued = if let Some(active) = self.active_thread() { + let should_send_queued = if let Some(active) = self.root_thread_view() { active.update(cx, |active, cx| { if active.skip_queue_processing_count > 0 { active.skip_queue_processing_count -= 1; @@ -1530,7 +1487,7 @@ impl ConversationView { } AcpThreadEvent::Refusal => { let error = ThreadError::Refusal; - if let Some(active) = self.thread_view(&session_id, cx) { + if let Some(active) = self.thread_view(&session_id) { active.update(cx, |active, cx| { active.handle_thread_error(error, cx); active.thread_retry_status.take(); @@ -1544,7 +1501,7 @@ impl ConversationView { } } AcpThreadEvent::Error => { - if let Some(active) = self.thread_view(&session_id, cx) { + if let Some(active) = self.thread_view(&session_id) { let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating); active.update(cx, |active, cx| { @@ -1567,7 +1524,7 @@ impl ConversationView { } } AcpThreadEvent::LoadError(error) => { - if let Some(view) = self.active_thread() { + if let Some(view) = self.root_thread_view() { if view .read(cx) .message_editor @@ -1586,7 +1543,7 @@ impl ConversationView { } AcpThreadEvent::TitleUpdated => { if let Some(title) = thread.read(cx).title() - && let Some(active_thread) = self.thread_view(&session_id, cx) + && let Some(active_thread) = self.thread_view(&session_id) { let title_editor = active_thread.read(cx).title_editor.clone(); title_editor.update(cx, |editor, cx| { @@ -1598,7 +1555,7 @@ impl ConversationView { cx.notify(); } AcpThreadEvent::PromptCapabilitiesUpdated => { - if let Some(active) = self.thread_view(&session_id, cx) { + if let Some(active) = self.thread_view(&session_id) { active.update(cx, |active, _cx| { active .session_capabilities @@ -1608,11 +1565,14 @@ impl ConversationView { } } AcpThreadEvent::TokenUsageUpdated => { - self.update_turn_tokens(cx); - self.emit_token_limit_telemetry_if_needed(thread, cx); + if let Some(active) = self.thread_view(&session_id) { + active.update(cx, |active, cx| { + active.update_turn_tokens(cx); + }); + } } AcpThreadEvent::AvailableCommandsUpdated(available_commands) => { - if let Some(thread_view) = self.thread_view(&session_id, cx) { + if let Some(thread_view) = self.thread_view(&session_id) { let has_commands = !available_commands.is_empty(); let agent_display_name = self @@ -1729,7 +1689,7 @@ impl ConversationView { { pending_auth_method.take(); } - if let Some(active) = this.active_thread() { + if let Some(active) = this.root_thread_view() { active.update(cx, |active, cx| { active.handle_thread_error(err, cx); }) @@ -1777,7 +1737,7 @@ impl ConversationView { { pending_auth_method.take(); } - if let Some(active) = this.active_thread() { + if let Some(active) = this.root_thread_view() { active.update(cx, |active, cx| active.handle_thread_error(err, cx)); } } else { @@ -1983,7 +1943,7 @@ impl ConversationView { } pub fn has_user_submitted_prompt(&self, cx: &App) -> bool { - self.active_thread().is_some_and(|active| { + self.root_thread_view().is_some_and(|active| { active .read(cx) .thread @@ -2109,59 +2069,6 @@ impl ConversationView { .into_any_element() } - fn emit_token_limit_telemetry_if_needed( - &mut self, - thread: &Entity, - cx: &mut Context, - ) { - let Some(active_thread) = self.active_thread() else { - return; - }; - - let (ratio, agent_telemetry_id, session_id) = { - let thread_data = thread.read(cx); - let Some(token_usage) = thread_data.token_usage() else { - return; - }; - ( - token_usage.ratio(), - thread_data.connection().telemetry_id(), - thread_data.session_id().clone(), - ) - }; - - let kind = match ratio { - acp_thread::TokenUsageRatio::Normal => { - active_thread.update(cx, |active, _cx| { - active.last_token_limit_telemetry = None; - }); - return; - } - acp_thread::TokenUsageRatio::Warning => "warning", - acp_thread::TokenUsageRatio::Exceeded => "exceeded", - }; - - let should_skip = active_thread - .read(cx) - .last_token_limit_telemetry - .as_ref() - .is_some_and(|last| *last >= ratio); - if should_skip { - return; - } - - active_thread.update(cx, |active, _cx| { - active.last_token_limit_telemetry = Some(ratio); - }); - - telemetry::event!( - "Agent Token Limit Warning", - agent = agent_telemetry_id, - session_id = session_id, - kind = kind, - ); - } - fn emit_load_error_telemetry(&self, error: &LoadError) { let error_kind = match error { LoadError::Unsupported { .. } => "unsupported", @@ -2268,18 +2175,20 @@ impl ConversationView { &self, cx: &App, ) -> Option> { - let acp_thread = self.active_thread()?.read(cx).thread.read(cx); - acp_thread.connection().clone().downcast() + self.root_thread(cx)? + .read(cx) + .connection() + .clone() + .downcast() } pub fn as_native_thread(&self, cx: &App) -> Option> { - let acp_thread = self.active_thread()?.read(cx).thread.read(cx); self.as_native_connection(cx)? - .thread(acp_thread.session_id(), cx) + .thread(self.root_session_id.as_ref()?, cx) } fn queued_messages_len(&self, cx: &App) -> usize { - self.active_thread() + self.root_thread_view() .map(|thread| thread.read(cx).local_queued_messages.len()) .unwrap_or_default() } @@ -2291,7 +2200,7 @@ impl ConversationView { tracked_buffers: Vec>, cx: &mut Context, ) -> bool { - match self.active_thread() { + match self.root_thread_view() { Some(thread) => thread.update(cx, |thread, _cx| { if index < thread.local_queued_messages.len() { thread.local_queued_messages[index] = QueuedMessage { @@ -2308,7 +2217,7 @@ impl ConversationView { } fn queued_message_contents(&self, cx: &App) -> Vec> { - match self.active_thread() { + match self.root_thread_view() { None => Vec::new(), Some(thread) => thread .read(cx) @@ -2320,7 +2229,7 @@ impl ConversationView { } fn save_queued_message_at_index(&mut self, index: usize, cx: &mut Context) { - let editor = match self.active_thread() { + let editor = match self.root_thread_view() { Some(thread) => thread.read(cx).queued_message_editors.get(index).cloned(), None => None, }; @@ -2451,7 +2360,7 @@ impl ConversationView { }); } - if let Some(active) = self.active_thread() { + if let Some(active) = self.root_thread_view() { active.update(cx, |active, _cx| { active.last_synced_queue_length = needed_count; }); @@ -2545,7 +2454,7 @@ impl ConversationView { return; } - let Some(root_thread) = self.root_thread(cx) else { + let Some(root_thread) = self.root_thread_view() else { return; }; let root_thread = root_thread.read(cx).thread.read(cx); @@ -2764,7 +2673,7 @@ impl ConversationView { // For ACP agents, use the agent name (e.g., "Claude Agent", "Gemini CLI") // This provides better clarity about what refused the request if self.as_native_connection(cx).is_some() { - self.active_thread() + self.root_thread_view() .and_then(|active| active.read(cx).model_selector.clone()) .and_then(|selector| selector.read(cx).active_model(cx)) .map(|model| model.name.clone()) @@ -2783,7 +2692,7 @@ impl ConversationView { pub(crate) fn reauthenticate(&mut self, window: &mut Window, cx: &mut Context) { let agent_id = self.agent.agent_id(); - if let Some(active) = self.active_thread() { + if let Some(active) = self.root_thread_view() { active.update(cx, |active, cx| active.clear_thread_error(cx)); } let this = cx.weak_entity(); @@ -3927,7 +3836,7 @@ pub(crate) mod tests { let root_session_id = conversation_view .read_with(cx, |view, cx| { - view.root_thread(cx) + view.root_thread_view() .map(|thread| thread.read(cx).thread.read(cx).session_id().clone()) }) .expect("Conversation view should have a root thread"); diff --git a/crates/agent_ui/src/conversation_view/thread_view.rs b/crates/agent_ui/src/conversation_view/thread_view.rs index 95d3f3599b41a5028a6bed6fa51c179eb51e767f..86f920c157a7ea0a5895f03342df73e0403b473f 100644 --- a/crates/agent_ui/src/conversation_view/thread_view.rs +++ b/crates/agent_ui/src/conversation_view/thread_view.rs @@ -885,10 +885,51 @@ impl ThreadView { if let Some(usage) = self.thread.read(cx).token_usage() { if let Some(tokens) = &mut self.turn_fields.turn_tokens { *tokens += usage.output_tokens; + self.emit_token_limit_telemetry_if_needed(cx); } } } + fn emit_token_limit_telemetry_if_needed(&mut self, cx: &App) { + let (ratio, agent_telemetry_id, session_id) = { + let thread_data = self.thread.read(cx); + let Some(token_usage) = thread_data.token_usage() else { + return; + }; + ( + token_usage.ratio(), + thread_data.connection().telemetry_id(), + thread_data.session_id().clone(), + ) + }; + + let kind = match ratio { + acp_thread::TokenUsageRatio::Normal => { + self.last_token_limit_telemetry = None; + return; + } + acp_thread::TokenUsageRatio::Warning => "warning", + acp_thread::TokenUsageRatio::Exceeded => "exceeded", + }; + + let should_skip = self + .last_token_limit_telemetry + .as_ref() + .is_some_and(|last| *last >= ratio); + if should_skip { + return; + } + + self.last_token_limit_telemetry = Some(ratio); + + telemetry::event!( + "Agent Token Limit Warning", + agent = agent_telemetry_id, + session_id = session_id, + kind = kind, + ); + } + // sending fn clear_external_source_prompt_warning(&mut self, cx: &mut Context) { @@ -3096,7 +3137,7 @@ impl ThreadView { self.server_view .upgrade() - .and_then(|sv| sv.read(cx).thread_view(parent_session_id, cx)) + .and_then(|sv| sv.read(cx).thread_view(parent_session_id)) .is_some_and(|parent_view| { parent_view .read(cx) diff --git a/crates/agent_ui/src/thread_metadata_store.rs b/crates/agent_ui/src/thread_metadata_store.rs index 154e71eb0dff5880c30302de7aa46a665e54ff1f..c36290fe32bc3905973f66aa8ff8f3f1ded3a5b9 100644 --- a/crates/agent_ui/src/thread_metadata_store.rs +++ b/crates/agent_ui/src/thread_metadata_store.rs @@ -1135,12 +1135,12 @@ impl ThreadMetadataStore { ) { let view = conversation_view.read(cx); let thread_id = view.thread_id; - let Some(thread) = view.root_acp_thread(cx) else { + let Some(thread) = view.root_thread(cx) else { return; }; let thread_ref = thread.read(cx); - if thread_ref.is_draft_thread() { + if thread_ref.is_draft_thread() || thread_ref.project().read(cx).is_via_collab() { return; } @@ -3728,4 +3728,144 @@ mod tests { ); }); } + + #[gpui::test] + async fn test_collab_guest_threads_not_saved_to_metadata_store(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [Path::new("/project-a")], cx).await; + + let (panel, mut vcx) = setup_panel_with_project(project.clone(), cx); + crate::test_support::open_thread_with_connection( + &panel, + StubAgentConnection::new(), + &mut vcx, + ); + let thread = panel.read_with(&vcx, |panel, cx| panel.active_agent_thread(cx).unwrap()); + let thread_id = crate::test_support::active_thread_id(&panel, &vcx); + thread.update_in(&mut vcx, |thread, _window, cx| { + thread.push_user_content_block(None, "hello".into(), cx); + thread.set_title("Thread".into(), cx).detach(); + }); + vcx.run_until_parked(); + + // Confirm the thread is in the store while the project is local. + cx.update(|cx| { + let store = ThreadMetadataStore::global(cx); + assert!( + store.read(cx).entry(thread_id).is_some(), + "thread must be in the store while the project is local" + ); + }); + + cx.update(|cx| { + let store = ThreadMetadataStore::global(cx); + store.update(cx, |store, cx| { + store.delete(thread_id, cx); + }); + }); + project.update(cx, |project, _cx| { + project.mark_as_collab_for_testing(); + }); + + thread.update_in(&mut vcx, |thread, _window, cx| { + thread.push_user_content_block(None, "more content".into(), cx); + }); + vcx.run_until_parked(); + + cx.update(|cx| { + let store = ThreadMetadataStore::global(cx); + assert!( + store.read(cx).entry(thread_id).is_none(), + "threads must not be persisted while the project is a collab guest session" + ); + }); + } + + // When a worktree is added to a collab project, update_thread_work_dirs + // fires with the new worktree paths. Without an is_via_collab() guard it + // overwrites the stored paths of any retained or active local threads with + // the new (expanded) path set, corrupting metadata that belonged to the + // guest's own local project. + #[gpui::test] + async fn test_collab_guest_retained_thread_paths_not_overwritten_on_worktree_change( + cx: &mut TestAppContext, + ) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/project-a", serde_json::json!({})).await; + fs.insert_tree("/project-b", serde_json::json!({})).await; + let project = Project::test(fs, [Path::new("/project-a")], cx).await; + + let (panel, mut vcx) = setup_panel_with_project(project.clone(), cx); + + // Open thread A and give it content so its metadata is saved with /project-a. + crate::test_support::open_thread_with_connection( + &panel, + StubAgentConnection::new(), + &mut vcx, + ); + let thread_a_id = crate::test_support::active_thread_id(&panel, &vcx); + let thread_a = panel.read_with(&vcx, |panel, cx| panel.active_agent_thread(cx).unwrap()); + thread_a.update_in(&mut vcx, |thread, _window, cx| { + thread.push_user_content_block(None, "hello".into(), cx); + thread.set_title("Thread A".into(), cx).detach(); + }); + vcx.run_until_parked(); + + cx.update(|cx| { + let store = ThreadMetadataStore::global(cx); + let entry = store.read(cx).entry(thread_a_id).unwrap(); + assert_eq!( + entry.folder_paths().paths(), + &[std::path::PathBuf::from("/project-a")], + "thread A must be saved with /project-a before collab" + ); + }); + + // Open thread B, making thread A a retained thread in the panel. + crate::test_support::open_thread_with_connection( + &panel, + StubAgentConnection::new(), + &mut vcx, + ); + vcx.run_until_parked(); + + // Transition the project into collab mode (simulates joining as a guest). + project.update(cx, |project, _cx| { + project.mark_as_collab_for_testing(); + }); + + // Add a second worktree. For a real collab guest this would be one of + // the host's worktrees arriving via the collab protocol, but here we + // use a local path because the test infrastructure cannot easily produce + // a remote worktree with a fully-scanned root entry. + // + // This fires WorktreeAdded → update_thread_work_dirs. Without an + // is_via_collab() guard that call overwrites the stored paths of + // retained thread A from {/project-a} to {/project-a, /project-b}, + // polluting its metadata with a path it never belonged to. + project + .update(cx, |project, cx| { + project.find_or_create_worktree(Path::new("/project-b"), true, cx) + }) + .await + .unwrap(); + vcx.run_until_parked(); + + cx.update(|cx| { + let store = ThreadMetadataStore::global(cx); + let entry = store + .read(cx) + .entry(thread_a_id) + .expect("thread A must still exist in the store"); + assert_eq!( + entry.folder_paths().paths(), + &[std::path::PathBuf::from("/project-a")], + "retained thread A's stored path must not be updated while the project is via collab" + ); + }); + } } diff --git a/crates/buffer_diff/src/buffer_diff.rs b/crates/buffer_diff/src/buffer_diff.rs index c168bd2956e0687eca5e5adeb16edbe70e9edd54..56c3fe9f51b1ea5a21f3b636ea9f8ae59e20ed3b 100644 --- a/crates/buffer_diff/src/buffer_diff.rs +++ b/crates/buffer_diff/src/buffer_diff.rs @@ -1515,11 +1515,17 @@ pub struct DiffChanged { #[derive(Clone, Debug)] pub enum BufferDiffEvent { + BaseTextChanged, DiffChanged(DiffChanged), LanguageChanged, HunksStagedOrUnstaged(Option), } +struct SetSnapshotResult { + change: DiffChanged, + base_text_changed: bool, +} + impl EventEmitter for BufferDiff {} impl BufferDiff { @@ -1784,7 +1790,7 @@ impl BufferDiff { secondary_diff_change: Option>, clear_pending_hunks: bool, cx: &mut Context, - ) -> impl Future + use<> { + ) -> impl Future + use<> { log::debug!("set snapshot with secondary {secondary_diff_change:?}"); let old_snapshot = self.snapshot(cx); @@ -1904,10 +1910,13 @@ impl BufferDiff { if let Some(parsing_idle) = parsing_idle { parsing_idle.await; } - DiffChanged { - changed_range, - base_text_changed_range, - extended_range, + SetSnapshotResult { + change: DiffChanged { + changed_range, + base_text_changed_range, + extended_range, + }, + base_text_changed, } } } @@ -1938,12 +1947,15 @@ impl BufferDiff { ); cx.spawn(async move |this, cx| { - let change = fut.await; + let result = fut.await; this.update(cx, |_, cx| { - cx.emit(BufferDiffEvent::DiffChanged(change.clone())); + if result.base_text_changed { + cx.emit(BufferDiffEvent::BaseTextChanged); + } + cx.emit(BufferDiffEvent::DiffChanged(result.change.clone())); }) .ok(); - change.changed_range + result.change.changed_range }) } @@ -2019,8 +2031,11 @@ impl BufferDiff { let fg_executor = cx.foreground_executor().clone(); let snapshot = fg_executor.block_on(fut); let fut = self.set_snapshot_with_secondary_inner(snapshot, buffer, None, false, cx); - let change = fg_executor.block_on(fut); - cx.emit(BufferDiffEvent::DiffChanged(change)); + let result = fg_executor.block_on(fut); + if result.base_text_changed { + cx.emit(BufferDiffEvent::BaseTextChanged); + } + cx.emit(BufferDiffEvent::DiffChanged(result.change)); } pub fn base_text_buffer(&self) -> &Entity { diff --git a/crates/editor/src/movement.rs b/crates/editor/src/movement.rs index 67869f770b81f315680388165111bbc1a2e0f111..5742c9d20ce2dbdb9b1effeb31945a8df24f914f 100644 --- a/crates/editor/src/movement.rs +++ b/crates/editor/src/movement.rs @@ -738,7 +738,8 @@ pub fn find_boundary_point( && is_boundary(prev_ch, ch) { if return_point_before_boundary { - return map.clip_point(prev_offset.to_display_point(map), Bias::Right); + let point = prev_offset.to_point(map.buffer_snapshot()); + return map.clip_point(map.point_to_display_point(point, Bias::Right), Bias::Right); } else { break; } @@ -747,7 +748,8 @@ pub fn find_boundary_point( offset += ch.len_utf8(); prev_ch = Some(ch); } - map.clip_point(offset.to_display_point(map), Bias::Right) + let point = offset.to_point(map.buffer_snapshot()); + map.clip_point(map.point_to_display_point(point, Bias::Right), Bias::Right) } pub fn find_preceding_boundary_trail( @@ -836,13 +838,15 @@ pub fn find_boundary_trail( prev_ch = Some(ch); } - let trail = trail_offset - .map(|trail_offset| map.clip_point(trail_offset.to_display_point(map), Bias::Right)); + let trail = trail_offset.map(|trail_offset| { + let point = trail_offset.to_point(map.buffer_snapshot()); + map.clip_point(map.point_to_display_point(point, Bias::Right), Bias::Right) + }); - ( - trail, - map.clip_point(offset.to_display_point(map), Bias::Right), - ) + (trail, { + let point = offset.to_point(map.buffer_snapshot()); + map.clip_point(map.point_to_display_point(point, Bias::Right), Bias::Right) + }) } pub fn find_boundary( @@ -1406,6 +1410,96 @@ mod tests { }); } + #[gpui::test] + fn test_word_movement_over_folds(cx: &mut gpui::App) { + use crate::display_map::Crease; + + init_test(cx); + + // Simulate a mention: `hello [@file.txt](file:///path) world` + // The fold covers `[@file.txt](file:///path)` and is replaced by "⋯". + // Display text: `hello ⋯ world` + let buffer_text = "hello [@file.txt](file:///path) world"; + let buffer = MultiBuffer::build_simple(buffer_text, cx); + let font = font("Helvetica"); + let display_map = cx.new(|cx| { + DisplayMap::new( + buffer, + font, + px(14.0), + None, + 0, + 1, + FoldPlaceholder::test(), + DiagnosticSeverity::Warning, + cx, + ) + }); + display_map.update(cx, |map, cx| { + // Fold the `[@file.txt](file:///path)` range (bytes 6..31) + map.fold( + vec![Crease::simple( + Point::new(0, 6)..Point::new(0, 31), + FoldPlaceholder::test(), + )], + cx, + ); + }); + let snapshot = display_map.update(cx, |map, cx| map.snapshot(cx)); + + // "hello " (6 bytes) + "⋯" (3 bytes) + " world" (6 bytes) = "hello ⋯ world" + assert_eq!(snapshot.text(), "hello ⋯ world"); + + // Ctrl+Right from before fold ("hello |⋯ world") should skip past the fold. + // Cursor at column 6 = start of fold. + let before_fold = DisplayPoint::new(DisplayRow(0), 6); + let after_fold = next_word_end(&snapshot, before_fold); + // Should land past the fold, not get stuck at fold start. + assert!( + after_fold > before_fold, + "next_word_end should move past the fold: got {:?}, started at {:?}", + after_fold, + before_fold + ); + + // Ctrl+Right from "hello" should jump past "hello" to the fold or past it. + let at_start = DisplayPoint::new(DisplayRow(0), 0); + let after_hello = next_word_end(&snapshot, at_start); + assert_eq!( + after_hello, + DisplayPoint::new(DisplayRow(0), 5), + "next_word_end from start should land at end of 'hello'" + ); + + // Ctrl+Left from after fold should move to before the fold. + // "⋯" ends at column 9. " world" starts at 9. Column 15 = end of "world". + let after_world = DisplayPoint::new(DisplayRow(0), 15); + let before_world = previous_word_start(&snapshot, after_world); + assert_eq!( + before_world, + DisplayPoint::new(DisplayRow(0), 10), + "previous_word_start from end should land at start of 'world'" + ); + + // Ctrl+Left from start of "world" should land before fold. + let start_of_world = DisplayPoint::new(DisplayRow(0), 10); + let landed = previous_word_start(&snapshot, start_of_world); + // The fold acts as a word, so we should land at the fold start (column 6). + assert_eq!( + landed, + DisplayPoint::new(DisplayRow(0), 6), + "previous_word_start from 'world' should land at fold start" + ); + + // End key from start should go to end of line (column 15), not fold start. + let end_pos = line_end(&snapshot, at_start, false); + assert_eq!( + end_pos, + DisplayPoint::new(DisplayRow(0), 15), + "line_end should go to actual end of line, not fold start" + ); + } + fn init_test(cx: &mut gpui::App) { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 515a518e530e79eb7bdf2a3074e6bf12a5824027..cdc134a2a956b8724b66e2d567b45be252385067 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -2075,6 +2075,18 @@ impl Project { project } + /// Transitions a local test project into the `Collab` client state so that + /// `is_via_collab()` returns `true`. Use only in tests. + #[cfg(any(test, feature = "test-support"))] + pub fn mark_as_collab_for_testing(&mut self) { + self.client_state = ProjectClientState::Collab { + sharing_has_stopped: false, + capability: Capability::ReadWrite, + remote_id: 0, + replica_id: clock::ReplicaId::new(1), + }; + } + #[cfg(any(test, feature = "test-support"))] pub fn add_test_remote_worktree( &mut self, diff --git a/crates/sidebar/src/sidebar.rs b/crates/sidebar/src/sidebar.rs index 2418e2bc1d6554434bacf7d7004143593a7645f9..381da123c80ff30987ab9bd2f01207b06184d8f2 100644 --- a/crates/sidebar/src/sidebar.rs +++ b/crates/sidebar/src/sidebar.rs @@ -541,6 +541,9 @@ impl Sidebar { cx: &mut Context, ) { let project = workspace.read(cx).project().clone(); + if project.read(cx).is_via_collab() { + return; + } cx.subscribe_in( &project, @@ -607,6 +610,10 @@ impl Sidebar { old_paths: &WorktreePaths, cx: &mut Context, ) { + if project.read(cx).is_via_collab() { + return; + } + let new_paths = project.read(cx).worktree_paths(cx); let old_folder_paths = old_paths.folder_path_list().clone(); @@ -2234,7 +2241,6 @@ impl Sidebar { let mut existing_panel = None; workspace.update(cx, |workspace, cx| { if let Some(panel) = workspace.panel::(cx) { - panel.update(cx, |panel, _cx| panel.begin_loading_thread()); existing_panel = Some(panel); } }); @@ -2262,7 +2268,6 @@ impl Sidebar { workspace.add_panel(panel.clone(), window, cx); panel.clone() }); - panel.update(cx, |panel, _cx| panel.begin_loading_thread()); load_thread(panel, &metadata, focus, window, cx); if focus { workspace.focus_panel::(window, cx); @@ -4895,7 +4900,7 @@ fn all_thread_infos_for_workspace( .read(cx) .root_thread_has_pending_tool_call(cx); let conversation_thread_id = conversation_view.read(cx).parent_id(); - let thread_view = conversation_view.read(cx).root_thread(cx)?; + let thread_view = conversation_view.read(cx).root_thread_view()?; let thread_view_ref = thread_view.read(cx); let thread = thread_view_ref.thread.read(cx); @@ -5148,7 +5153,7 @@ fn dump_single_workspace(workspace: &Workspace, output: &mut String, cx: &gpui:: ) .ok(); for (session_id, conversation_view) in background_threads { - if let Some(thread_view) = conversation_view.read(cx).root_thread(cx) { + if let Some(thread_view) = conversation_view.read(cx).root_thread_view() { let thread = thread_view.read(cx).thread.read(cx); let title = thread.title().unwrap_or_else(|| "(untitled)".into()); let status = match thread.status() { diff --git a/crates/sidebar/src/sidebar_tests.rs b/crates/sidebar/src/sidebar_tests.rs index 2fa51f90fe6b3602e72d9fdc3efc6588cd7087e8..4ee5cad3bf9df2a14e45ff2a5ac227709e396422 100644 --- a/crates/sidebar/src/sidebar_tests.rs +++ b/crates/sidebar/src/sidebar_tests.rs @@ -1535,7 +1535,7 @@ async fn test_subagent_permission_request_marks_parent_sidebar_thread_waiting( let subagent_thread = panel.read_with(cx, |panel, cx| { panel .active_conversation_view() - .and_then(|conversation| conversation.read(cx).thread_view(&subagent_session_id, cx)) + .and_then(|conversation| conversation.read(cx).thread_view(&subagent_session_id)) .map(|thread_view| thread_view.read(cx).thread.clone()) .expect("Expected subagent thread to be loaded into the conversation") }); @@ -10490,3 +10490,74 @@ fn test_worktree_info_missing_branch_returns_none() { assert_eq!(infos[0].branch_name, None); assert_eq!(infos[0].name, SharedString::from("myapp")); } + +#[gpui::test] +async fn test_collab_guest_move_thread_paths_is_noop(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/project-a", serde_json::json!({ "src": {} })) + .await; + fs.insert_tree("/project-b", serde_json::json!({ "src": {} })) + .await; + cx.update(|cx| ::set_global(fs.clone(), cx)); + let project = project::Project::test(fs, ["/project-a".as_ref()], cx).await; + + let (multi_workspace, cx) = + cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + + // Set up the sidebar while the project is local. This registers the + // WorktreePathsChanged subscription for the project. + let _sidebar = setup_sidebar(&multi_workspace, cx); + + let session_id = acp::SessionId::new(Arc::from("test-thread")); + save_named_thread_metadata("test-thread", "My Thread", &project, cx).await; + + let thread_id = cx.update(|_window, cx| { + ThreadMetadataStore::global(cx) + .read(cx) + .entry_by_session(&session_id) + .map(|e| e.thread_id) + .expect("thread must be in the store") + }); + + cx.update(|_window, cx| { + let store = ThreadMetadataStore::global(cx); + let entry = store.read(cx).entry(thread_id).unwrap(); + assert_eq!( + entry.folder_paths().paths(), + &[PathBuf::from("/project-a")], + "thread must be saved with /project-a before collab" + ); + }); + + // Transition the project into collab mode. The sidebar's subscription is + // still active from when the project was local. + project.update(cx, |project, _cx| { + project.mark_as_collab_for_testing(); + }); + + // Adding a worktree fires WorktreePathsChanged with old_paths = {/project-a}. + // The sidebar's subscription is still active, so move_thread_paths is called. + // Without the is_via_collab() guard inside move_thread_paths, this would + // update the stored thread paths from {/project-a} to {/project-a, /project-b}. + project + .update(cx, |project, cx| { + project.find_or_create_worktree("/project-b", true, cx) + }) + .await + .expect("should add worktree"); + cx.run_until_parked(); + + cx.update(|_window, cx| { + let store = ThreadMetadataStore::global(cx); + let entry = store + .read(cx) + .entry(thread_id) + .expect("thread must still exist"); + assert_eq!( + entry.folder_paths().paths(), + &[PathBuf::from("/project-a")], + "thread path must not change when project is via collab" + ); + }); +} diff --git a/crates/ui/src/components/ai/thread_item.rs b/crates/ui/src/components/ai/thread_item.rs index 159e5bb48422a5102ee5ad3e9edde54007be62ec..e579a9bdb8713007736ede496444edb2d9125649 100644 --- a/crates/ui/src/components/ai/thread_item.rs +++ b/crates/ui/src/components/ai/thread_item.rs @@ -392,13 +392,13 @@ impl RenderOnce for ThreadItem { let has_timestamp = !self.timestamp.is_empty(); let timestamp = self.timestamp; - let visible_worktree_count = self + let linked_worktree_count = self .worktrees .iter() - .filter(|wt| !(wt.kind == WorktreeKind::Main && wt.branch_name.is_none())) + .filter(|wt| wt.kind == WorktreeKind::Linked) .count(); - let worktree_tooltip_title = match (self.is_remote, visible_worktree_count > 1) { + let worktree_tooltip_title = match (self.is_remote, linked_worktree_count > 1) { (true, true) => "Thread Running in Remote Git Worktrees", (true, false) => "Thread Running in a Remote Git Worktree", (false, true) => "Thread Running in Local Git Worktrees", @@ -410,44 +410,9 @@ impl RenderOnce for ThreadItem { let slash_color = Color::Custom(cx.theme().colors().text_muted.opacity(0.4)); for wt in self.worktrees { - match (wt.kind, wt.branch_name) { - (WorktreeKind::Main, None) => continue, - (WorktreeKind::Main, Some(branch)) => { - let chip_index = worktree_labels.len(); - let tooltip_title = worktree_tooltip_title; - let full_path = wt.full_path.clone(); - - worktree_labels.push( - h_flex() - .id(format!("{}-worktree-{chip_index}", self.id.clone())) - .min_w_0() - .when(visible_worktree_count > 1, |this| { - this.child( - Label::new(wt.name) - .size(LabelSize::Small) - .color(Color::Muted) - .truncate(), - ) - .child( - Label::new("/") - .size(LabelSize::Small) - .color(slash_color) - .flex_shrink_0(), - ) - }) - .child( - Label::new(branch) - .size(LabelSize::Small) - .color(Color::Muted) - .truncate(), - ) - .tooltip(move |_, cx| { - Tooltip::with_meta(tooltip_title, None, full_path.clone(), cx) - }) - .into_any_element(), - ); - } - (WorktreeKind::Linked, branch) => { + match wt.kind { + WorktreeKind::Main => continue, + WorktreeKind::Linked => { let chip_index = worktree_labels.len(); let tooltip_title = worktree_tooltip_title; let full_path = wt.full_path.clone(); @@ -477,7 +442,7 @@ impl RenderOnce for ThreadItem { .color(Color::Muted), ) .child(label) - .when_some(branch, |this, branch| { + .when_some(wt.branch_name, |this, branch| { this.child( Label::new("/") .size(LabelSize::Small) @@ -789,7 +754,7 @@ impl Component for ThreadItem { .into_any_element(), ), single_example( - "Main Branch + Changes + Timestamp", + "Main Worktree (hidden) + Changes + Timestamp", container() .child( ThreadItem::new("ti-5e", "Main worktree branch with diff stats") diff --git a/crates/ui/src/components/notification/announcement_toast.rs b/crates/ui/src/components/notification/announcement_toast.rs index 920f97f9959d1f6c2427ccd7645d15e859893aa4..215d8b9aedfa4584d97ab72b8d816c7a0e516fbc 100644 --- a/crates/ui/src/components/notification/announcement_toast.rs +++ b/crates/ui/src/components/notification/announcement_toast.rs @@ -101,6 +101,8 @@ impl RenderOnce for AnnouncementToast { let illustration = self.illustration; v_flex() + .id("announcement-toast") + .occlude() .relative() .w_full() .elevation_3(cx) diff --git a/crates/workspace/src/welcome.rs b/crates/workspace/src/welcome.rs index 5bfed0ceed93a4b1216b5117486d6ef5f1d5571b..a13ec56b2e07a81667fe096a8780393d93bf6f48 100644 --- a/crates/workspace/src/welcome.rs +++ b/crates/workspace/src/welcome.rs @@ -399,18 +399,11 @@ impl WelcomePage { location: &SerializedWorkspaceLocation, paths: &PathList, ) -> impl IntoElement { + let name = project_name(paths); + let (icon, title) = match location { - SerializedWorkspaceLocation::Local => { - let path = paths.paths().first().map(|p| p.as_path()); - let name = path - .and_then(|p| p.file_name()) - .map(|n| n.to_string_lossy().to_string()) - .unwrap_or_else(|| "Untitled".to_string()); - (IconName::Folder, name) - } - SerializedWorkspaceLocation::Remote(_) => { - (IconName::Server, "Remote Project".to_string()) - } + SerializedWorkspaceLocation::Local => (IconName::Folder, name), + SerializedWorkspaceLocation::Remote(_) => (IconName::Server, name), }; SectionButton::new( @@ -661,3 +654,48 @@ mod persistence { } } } + +fn project_name(paths: &PathList) -> String { + let joined = paths + .paths() + .iter() + .filter_map(|p| p.file_name().map(|n| n.to_string_lossy().to_string())) + .collect::>() + .join(", "); + if joined.is_empty() { + "Untitled".to_string() + } else { + joined + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_project_name_empty() { + let paths = PathList::new::<&str>(&[]); + assert_eq!(project_name(&paths), "Untitled"); + } + + #[test] + fn test_project_name_single() { + let paths = PathList::new(&["/home/user/my-project"]); + assert_eq!(project_name(&paths), "my-project"); + } + + #[test] + fn test_project_name_multiple() { + // PathList sorts lexicographically, so filenames appear in alpha order + let paths = PathList::new(&["/home/user/zed", "/home/user/api"]); + assert_eq!(project_name(&paths), "api, zed"); + } + + #[test] + fn test_project_name_root_path_filtered() { + // A bare root "/" has no file_name(), falls back to "Untitled" + let paths = PathList::new(&["/"]); + assert_eq!(project_name(&paths), "Untitled"); + } +} diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index feaf717a1fc8e5d2224fa161f9caca972297065d..f0ca1eff5daa10c6513714a042799e7bf337f04c 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -231,7 +231,7 @@ pub struct Open { impl Open { pub const DEFAULT: Self = Self { - create_new_window: true, + create_new_window: false, }; /// Used by `#[serde(default)]` on the `create_new_window` field so that diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index ec93fe636620da65c42e637f3760c7779f1a045a..d4c0d29ade5c4bd6496509675f9ccb3fc188eb8f 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -2939,9 +2939,7 @@ impl gpui::Render for ThreadItemBranchNameTestView { }]), ), ) - .child(section_label( - "Main worktree with branch (branch only, no icon)", - )) + .child(section_label("Main worktree with branch (nothing shown)")) .child( container().child( ThreadItem::new("ti-main-branch", "Request for Long Classic Poem") @@ -3043,7 +3041,9 @@ impl gpui::Render for ThreadItemBranchNameTestView { }]), ), ) - .child(section_label("Main branch + diff stats + timestamp")) + .child(section_label( + "Main worktree with branch + diff stats + timestamp (branch hidden)", + )) .child( container().child( ThreadItem::new("ti-main-full", "Main worktree with everything") diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 48d6501f737155ad96bfc2f0d865c5a19b186c7e..c3c02b83cbc4bbdd6d10b59af4c3fac652a9599c 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -1088,11 +1088,12 @@ fn register_actions( }) .register_action({ let app_state = app_state.clone(); - move |_workspace, _: &CloseProject, window, cx| { + move |workspace, _: &CloseProject, window, cx| { let Some(window_handle) = window.window_handle().downcast::() else { return; }; let app_state = app_state.clone(); + let old_group_key = workspace.project_group_key(cx); cx.spawn_in(window, async move |this, cx| { let should_continue = this .update_in(cx, |workspace, window, cx| { @@ -1131,7 +1132,11 @@ fn register_actions( }, ) })?; - task.await + task.await?; + window_handle.update(cx, |mw, window, cx| { + mw.remove_project_group(&old_group_key, window, cx) + })?.await.log_err(); + Ok::<(), anyhow::Error>(()) } else { Ok(()) } @@ -6446,4 +6451,55 @@ mod tests { }) .unwrap(); } + + #[gpui::test] + async fn test_close_project_removes_project_group(cx: &mut TestAppContext) { + use util::path_list::PathList; + use workspace::{OpenMode, ProjectGroupKey}; + + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree(path!("/my-project"), json!({})) + .await; + + let workspace::OpenResult { window, .. } = cx + .update(|cx| { + workspace::Workspace::new_local( + vec![path!("/my-project").into()], + app_state.clone(), + None, + None, + None, + OpenMode::Activate, + cx, + ) + }) + .await + .unwrap(); + + window.update(cx, |mw, _, cx| mw.open_sidebar(cx)).unwrap(); + cx.background_executor.run_until_parked(); + + let project_key = ProjectGroupKey::new(None, PathList::new(&[path!("/my-project")])); + let keys = window + .read_with(cx, |mw, _| mw.project_group_keys()) + .unwrap(); + assert_eq!( + keys, + vec![project_key], + "project group should exist before CloseProject: {keys:?}" + ); + + cx.dispatch_action(window.into(), CloseProject); + + let keys = window + .read_with(cx, |mw, _| mw.project_group_keys()) + .unwrap(); + assert!( + keys.is_empty(), + "project group should be removed after CloseProject: {keys:?}" + ); + } } diff --git a/tooling/xtask/src/tasks/workflows/run_tests.rs b/tooling/xtask/src/tasks/workflows/run_tests.rs index f51b21b961ddbeabf30c5e757bdf6815833ab3ca..18df32d6ef55e1af5486e4a734cd1278aca7abe1 100644 --- a/tooling/xtask/src/tasks/workflows/run_tests.rs +++ b/tooling/xtask/src/tasks/workflows/run_tests.rs @@ -454,13 +454,14 @@ fn check_wasm() -> NamedJob { fn cargo_check_wasm() -> Step { named::bash(concat!( - "cargo +nightly -Zbuild-std=std,panic_abort ", + "cargo -Zbuild-std=std,panic_abort ", "check --target wasm32-unknown-unknown -p gpui_platform", )) .add_env(( "CARGO_TARGET_WASM32_UNKNOWN_UNKNOWN_RUSTFLAGS", "-C target-feature=+atomics,+bulk-memory,+mutable-globals", )) + .add_env(("RUSTC_BOOTSTRAP", "1")) } named::job(