diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index d6df01ba1e18f98e0091cf6169cf6c7b7ad3cde6..d71ac3c73e8f7e3c298f397e612119c3586f624e 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1282,16 +1282,37 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) -> Entity { - if let Some(draft) = &self.draft_thread { - return draft.clone(); - } - let agent = if self.project.read(cx).is_via_collab() { + let desired_agent = if self.project.read(cx).is_via_collab() { Agent::NativeAgent } else { self.selected_agent.clone() }; - let thread = - self.create_agent_thread(agent, None, None, None, None, "agent_panel", window, cx); + if let Some(draft) = &self.draft_thread { + let agent_matches = *draft.read(cx).agent_key() == desired_agent; + let has_editor_content = draft.read(cx).active_thread().is_some_and(|tv| { + !tv.read(cx) + .message_editor + .read(cx) + .text(cx) + .trim() + .is_empty() + }); + if agent_matches || has_editor_content { + return draft.clone(); + } + self.draft_thread = None; + self._draft_editor_observation = None; + } + let thread = self.create_agent_thread( + desired_agent, + None, + None, + None, + None, + "agent_panel", + window, + cx, + ); self.draft_thread = Some(thread.conversation_view.clone()); self.observe_draft_editor(&thread.conversation_view, cx); thread.conversation_view @@ -7104,6 +7125,95 @@ mod tests { }); } + #[gpui::test] + async fn test_draft_replaced_when_selected_agent_changes(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + cx.update(|cx| { + agent::ThreadStore::init_global(cx); + language_model::LanguageModelRegistry::test(cx); + ::set_global(fs.clone(), cx); + }); + + let project = Project::test(fs.clone(), [], cx).await; + + let multi_workspace = + cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + + let workspace = multi_workspace + .read_with(cx, |multi_workspace, _cx| { + multi_workspace.workspace().clone() + }) + .unwrap(); + + workspace.update(cx, |workspace, _cx| { + workspace.set_random_database_id(); + }); + + let cx = &mut VisualTestContext::from_window(multi_workspace.into(), cx); + + let panel = workspace.update_in(cx, |workspace, window, cx| { + let panel = cx.new(|cx| AgentPanel::new(workspace, None, window, cx)); + workspace.add_panel(panel.clone(), window, cx); + panel + }); + + // Create a draft with the default NativeAgent. + panel.update_in(cx, |panel, window, cx| { + panel.activate_draft(true, window, cx); + }); + + let first_draft_id = panel.read_with(cx, |panel, cx| { + assert!(panel.draft_thread.is_some()); + assert_eq!(panel.selected_agent, Agent::NativeAgent); + let draft = panel.draft_thread.as_ref().unwrap(); + assert_eq!(*draft.read(cx).agent_key(), Agent::NativeAgent); + draft.entity_id() + }); + + // Switch selected_agent to a custom agent, then activate_draft again. + // The stale NativeAgent draft should be replaced. + let custom_agent = Agent::Custom { + id: "my-custom-agent".into(), + }; + panel.update_in(cx, |panel, window, cx| { + panel.selected_agent = custom_agent.clone(); + panel.activate_draft(true, window, cx); + }); + + panel.read_with(cx, |panel, cx| { + let draft = panel.draft_thread.as_ref().expect("draft should exist"); + assert_ne!( + draft.entity_id(), + first_draft_id, + "a new draft should have been created" + ); + assert_eq!( + *draft.read(cx).agent_key(), + custom_agent, + "the new draft should use the custom agent" + ); + }); + + // Calling activate_draft again with the same agent should return the + // cached draft (no replacement). + let second_draft_id = panel.read_with(cx, |panel, _cx| { + panel.draft_thread.as_ref().unwrap().entity_id() + }); + + panel.update_in(cx, |panel, window, cx| { + panel.activate_draft(true, window, cx); + }); + + panel.read_with(cx, |panel, _cx| { + assert_eq!( + panel.draft_thread.as_ref().unwrap().entity_id(), + second_draft_id, + "draft should be reused when the agent has not changed" + ); + }); + } + #[gpui::test] async fn test_rollback_all_succeed_returns_ok(cx: &mut TestAppContext) { init_test(cx);