From 42ba961075b16aaf35d48631f0ce4e1a4196d983 Mon Sep 17 00:00:00 2001 From: Eric Holk Date: Mon, 2 Mar 2026 16:57:15 -0800 Subject: [PATCH] Persist unsent draft prompt across Zed restarts (#49541) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Store the user's unsent message editor text in DbThread so it survives quitting and reloading Zed. The draft flows through Thread → AcpThread → AcpThreadView on load, and back via a debounced observer on the message editor for saves. Currently works for native Zed agents only; external ACP agents will pick this up once general ACP history persistence lands. ## Changes - **`DbThread`** / **`Thread`**: New `draft_prompt: Option` field, included in `to_db()`/`from_db()` - **`AcpThread`**: Bridge field with getter/setter, populated during `register_session()` - **`NativeAgent::save_thread()`**: Copies draft from `AcpThread` → `Thread` before persisting - **`AcpThreadView`**: Restores draft into `MessageEditor` on load; syncs editor text → `AcpThread` via observer; debounced (500ms) Thread notify triggers DB save Co-authored-by: Anthony Eid Co-authored-by: Mikayla Maki --- crates/acp_thread/src/acp_thread.rs | 11 ++++++ crates/agent/src/agent.rs | 34 ++++++++++++++--- crates/agent/src/db.rs | 21 +++++++++++ crates/agent/src/thread.rs | 13 +++++++ crates/agent/src/thread_store.rs | 1 + .../src/connection_view/thread_view.rs | 37 +++++++++++++++++++ crates/agent_ui/src/message_editor.rs | 24 ++++++++++-- 7 files changed, 133 insertions(+), 8 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index e6da8f3f901b41c0a59d73920c3036fc72d1b906..f57ce1f4d188e260624bd90187a21890379fe6b6 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -970,6 +970,8 @@ pub struct AcpThread { pending_terminal_output: HashMap>>, pending_terminal_exit: HashMap, had_error: bool, + /// The user's unsent prompt text, persisted so it can be restored when reloading the thread. + draft_prompt: Option>, } impl From<&AcpThread> for ActionLogTelemetry { @@ -1207,6 +1209,7 @@ impl AcpThread { pending_terminal_output: HashMap::default(), pending_terminal_exit: HashMap::default(), had_error: false, + draft_prompt: None, } } @@ -1218,6 +1221,14 @@ impl AcpThread { self.prompt_capabilities.clone() } + pub fn draft_prompt(&self) -> Option<&[acp::ContentBlock]> { + self.draft_prompt.as_deref() + } + + pub fn set_draft_prompt(&mut self, prompt: Option>) { + self.draft_prompt = prompt; + } + pub fn connection(&self) -> &Rc { &self.connection } diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 0bb0f2c8790a5e07b97976ba391105554ad03307..7cf9416840a6bd2870327c9c68135857c01f7c9b 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -351,11 +351,12 @@ impl NativeAgent { let session_id = thread.id().clone(); let parent_session_id = thread.parent_thread_id(); let title = thread.title(); + let draft_prompt = thread.draft_prompt().map(Vec::from); let project = thread.project.clone(); let action_log = thread.action_log.clone(); let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone(); let acp_thread = cx.new(|cx| { - acp_thread::AcpThread::new( + let mut acp_thread = acp_thread::AcpThread::new( parent_session_id, title, connection, @@ -364,7 +365,9 @@ impl NativeAgent { session_id.clone(), prompt_capabilities_rx, cx, - ) + ); + acp_thread.set_draft_prompt(draft_prompt); + acp_thread }); let registry = LanguageModelRegistry::read_global(cx); @@ -844,9 +847,7 @@ impl NativeAgent { return; } - let database_future = ThreadsDatabase::connect(cx); - let (id, db_thread) = - thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx))); + let id = thread.read(cx).id().clone(); let Some(session) = self.sessions.get_mut(&id) else { return; }; @@ -860,6 +861,12 @@ impl NativeAgent { .collect::>(), ); + let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from); + let database_future = ThreadsDatabase::connect(cx); + let db_thread = thread.update(cx, |thread, cx| { + thread.set_draft_prompt(draft_prompt); + thread.to_db(cx) + }); let thread_store = self.thread_store.clone(); session.pending_save = cx.spawn(async move |_, cx| { let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else { @@ -2571,6 +2578,18 @@ mod internal_tests { cx.run_until_parked(); + // Set a draft prompt with rich content blocks before saving. + let draft_blocks = vec![ + acp::ContentBlock::Text(acp::TextContent::new("Check out ")), + acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())), + acp::ContentBlock::Text(acp::TextContent::new(" please")), + ]; + acp_thread.update(cx, |thread, _cx| { + thread.set_draft_prompt(Some(draft_blocks.clone())); + }); + thread.update(cx, |_thread, cx| cx.notify()); + cx.run_until_parked(); + // Close the session so it can be reloaded from disk. cx.update(|cx| connection.clone().close_session(&session_id, cx)) .await @@ -2608,6 +2627,11 @@ mod internal_tests { "} ) }); + + // Ensure the draft prompt with rich content blocks survived the round-trip. + acp_thread.read_with(cx, |thread, _| { + assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice())); + }); } fn thread_entries( diff --git a/crates/agent/src/db.rs b/crates/agent/src/db.rs index 5a14e920e52c18fb6341e09fa9f747b3c5019f1d..3a7af37cac85065d8853fbb5332093ef3fd20592 100644 --- a/crates/agent/src/db.rs +++ b/crates/agent/src/db.rs @@ -64,6 +64,8 @@ pub struct DbThread { pub thinking_enabled: bool, #[serde(default)] pub thinking_effort: Option, + #[serde(default)] + pub draft_prompt: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -105,6 +107,7 @@ impl SharedThread { speed: None, thinking_enabled: false, thinking_effort: None, + draft_prompt: None, } } @@ -282,6 +285,7 @@ impl DbThread { speed: None, thinking_enabled: false, thinking_effort: None, + draft_prompt: None, }) } } @@ -632,6 +636,7 @@ mod tests { speed: None, thinking_enabled: false, thinking_effort: None, + draft_prompt: None, } } @@ -715,6 +720,22 @@ mod tests { ); } + #[test] + fn test_draft_prompt_defaults_to_none() { + let json = r#"{ + "title": "Old Thread", + "messages": [], + "updated_at": "2024-01-01T00:00:00Z" + }"#; + + let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize"); + + assert!( + db_thread.draft_prompt.is_none(), + "Legacy threads without draft_prompt field should default to None" + ); + } + #[gpui::test] async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) { let database = ThreadsDatabase::new(cx.executor()).unwrap(); diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 4c43a66fe5bb67c11fe5f0438d54cc86a498c55c..c5ca1118ace28b66d555d67aa40c718da292f644 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -899,6 +899,8 @@ pub struct Thread { imported: bool, /// If this is a subagent thread, contains context about the parent subagent_context: Option, + /// The user's unsent prompt text, persisted so it can be restored when reloading the thread. + draft_prompt: Option>, /// Weak references to running subagent threads for cancellation propagation running_subagents: Vec>, } @@ -1014,6 +1016,7 @@ impl Thread { file_read_times: HashMap::default(), imported: false, subagent_context: None, + draft_prompt: None, running_subagents: Vec::new(), } } @@ -1229,6 +1232,7 @@ impl Thread { file_read_times: HashMap::default(), imported: db_thread.imported, subagent_context: db_thread.subagent_context, + draft_prompt: db_thread.draft_prompt, running_subagents: Vec::new(), } } @@ -1253,6 +1257,7 @@ impl Thread { speed: self.speed, thinking_enabled: self.thinking_enabled, thinking_effort: self.thinking_effort.clone(), + draft_prompt: self.draft_prompt.clone(), }; cx.background_spawn(async move { @@ -1294,6 +1299,14 @@ impl Thread { self.messages.is_empty() && self.title.is_none() } + pub fn draft_prompt(&self) -> Option<&[acp::ContentBlock]> { + self.draft_prompt.as_deref() + } + + pub fn set_draft_prompt(&mut self, prompt: Option>) { + self.draft_prompt = prompt; + } + pub fn model(&self) -> Option<&Arc> { self.model.as_ref() } diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 5cdce12125da8f7d26677388169e899f94b7e7f1..f944377e489a88ac0fa6dbb802edf9702e86f5f2 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -145,6 +145,7 @@ mod tests { speed: None, thinking_enabled: false, thinking_effort: None, + draft_prompt: None, } } diff --git a/crates/agent_ui/src/connection_view/thread_view.rs b/crates/agent_ui/src/connection_view/thread_view.rs index b8403f8052e32fbeeceb4594438eecf32aa4e2e7..2544305bc8f8666b897d11285ffa7711f3af8794 100644 --- a/crates/agent_ui/src/connection_view/thread_view.rs +++ b/crates/agent_ui/src/connection_view/thread_view.rs @@ -5,6 +5,7 @@ use gpui::{Corner, List}; use language_model::{LanguageModelEffortLevel, Speed}; use settings::update_settings_file; use ui::{ButtonLike, SplitButton, SplitButtonStyle, Tab}; +use workspace::SERIALIZATION_THROTTLE_TIME; use super::*; @@ -239,6 +240,7 @@ pub struct ThreadView { pub resumed_without_history: bool, pub resume_thread_metadata: Option, pub _cancel_task: Option>, + _draft_save_task: Option>, pub skip_queue_processing_count: usize, pub user_interrupted_generation: bool, pub can_fast_track_queue: bool, @@ -345,6 +347,8 @@ impl ThreadView { editor.set_message(blocks, window, cx); } } + } else if let Some(draft) = thread.read(cx).draft_prompt() { + editor.set_message(draft.to_vec(), window, cx); } editor }); @@ -377,6 +381,38 @@ impl ThreadView { Self::handle_message_editor_event, )); + subscriptions.push(cx.observe(&message_editor, |this, editor, cx| { + let is_empty = editor.read(cx).text(cx).is_empty(); + let draft_contents_task = if is_empty { + None + } else { + Some(editor.update(cx, |editor, cx| editor.draft_contents(cx))) + }; + this._draft_save_task = Some(cx.spawn(async move |this, cx| { + let draft = if let Some(task) = draft_contents_task { + let blocks = task.await.ok().filter(|b| !b.is_empty()); + blocks + } else { + None + }; + this.update(cx, |this, cx| { + this.thread.update(cx, |thread, _cx| { + thread.set_draft_prompt(draft); + }); + }) + .ok(); + cx.background_executor() + .timer(SERIALIZATION_THROTTLE_TIME) + .await; + this.update(cx, |this, cx| { + if let Some(thread) = this.as_native_thread(cx) { + thread.update(cx, |_thread, cx| cx.notify()); + } + }) + .ok(); + })); + })); + let recent_history_entries = history.read(cx).get_recent_sessions(3); let mut this = Self { @@ -427,6 +463,7 @@ impl ThreadView { is_loading_contents: false, new_server_version_available: None, _cancel_task: None, + _draft_save_task: None, skip_queue_processing_count: 0, user_interrupted_generation: false, can_fast_track_queue: false, diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 274b076eafbcfab4620c66c027c374025242f821..50b297847b43e4d147978fbcf14dce492fc572d0 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -416,7 +416,27 @@ impl MessageEditor { let text = self.editor.read(cx).text(cx); let available_commands = self.available_commands.borrow().clone(); let agent_name = self.agent_name.clone(); + let build_task = self.build_content_blocks(full_mention_content, cx); + cx.spawn(async move |_, _cx| { + Self::validate_slash_commands(&text, &available_commands, &agent_name)?; + build_task.await + }) + } + + pub fn draft_contents(&self, cx: &mut Context) -> Task>> { + let build_task = self.build_content_blocks(false, cx); + cx.spawn(async move |_, _cx| { + let (blocks, _tracked_buffers) = build_task.await?; + Ok(blocks) + }) + } + + fn build_content_blocks( + &self, + full_mention_content: bool, + cx: &mut Context, + ) -> Task, Vec>)>> { let contents = self .mention_set .update(cx, |store, cx| store.contents(full_mention_content, cx)); @@ -424,18 +444,16 @@ impl MessageEditor { let supports_embedded_context = self.prompt_capabilities.borrow().embedded_context; cx.spawn(async move |_, cx| { - Self::validate_slash_commands(&text, &available_commands, &agent_name)?; - let contents = contents.await?; let mut all_tracked_buffers = Vec::new(); let result = editor.update(cx, |editor, cx| { + let text = editor.text(cx); let (mut ix, _) = text .char_indices() .find(|(_, c)| !c.is_whitespace()) .unwrap_or((0, '\0')); let mut chunks: Vec = Vec::new(); - let text = editor.text(cx); editor.display_map.update(cx, |map, cx| { let snapshot = map.snapshot(cx); for (crease_id, crease) in snapshot.crease_snapshot.creases() {