1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::AgentProfileId;
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use sqlez::{
14 bindable::{Bind, Column},
15 connection::Connection,
16 statement::Statement,
17};
18use std::sync::Arc;
19use ui::{App, SharedString};
20use zed_env_vars::ZED_STATELESS;
21
/// Message type stored inside persisted and shared threads; alias of `crate::Message`.
pub type DbMessage = crate::Message;
/// Alias of `crate::legacy_thread::DetailedSummaryState`, the summary state carried by
/// legacy threads.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Alias of `crate::legacy_thread::SerializedLanguageModel`, used for the `model` field
/// of `DbThread` and `SharedThread`.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
25
/// Metadata about the git worktree associated with an agent thread.
///
/// Stored inside the serialized `DbThread` blob; the `branch` value is additionally
/// copied into the `worktree_branch` column when a thread is saved.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentGitWorktreeInfo {
    /// The branch name in the git worktree.
    pub branch: String,
    /// Absolute path to the git worktree on disk.
    pub worktree_path: std::path::PathBuf,
    /// The base branch/commit the worktree was created from.
    pub base_ref: String,
}
36
/// Lightweight listing entry for a stored thread.
///
/// Built by `ThreadsDatabase::list_threads` from the plain columns of the `threads`
/// table, without touching the serialized thread blob.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Primary key of the thread row.
    pub id: acp::SessionId,
    /// Value of the `parent_id` column, if any.
    pub parent_session_id: Option<acp::SessionId>,
    /// Thread title; stored in the `summary` column, and `"summary"` is also accepted
    /// as the field name when deserializing.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Last-update timestamp, parsed from the RFC 3339 text in the `updated_at` column.
    pub updated_at: DateTime<Utc>,
    /// Denormalized from `DbThread::git_worktree_info.branch` for efficient
    /// listing without decompressing thread data. The blob is the source of
    /// truth; this column is populated on save for query convenience.
    pub worktree_branch: Option<String>,
}
49
/// Full persisted form of an agent thread.
///
/// Serialized to JSON (together with a `version` tag) when saved. Every field
/// marked `#[serde(default)]` may be absent from stored data and then takes its
/// default value on load.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also written to the `summary` column on save.
    pub title: SharedString,
    /// Messages of the thread, in order.
    pub messages: Vec<DbMessage>,
    /// Last-update timestamp; also written to the `updated_at` column on save.
    pub updated_at: DateTime<Utc>,
    /// Detailed summary text, if one exists.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Project snapshot carried with the thread, if any.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage accumulated for the thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by user message id.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model recorded for the thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Agent profile recorded for the thread, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// `true` for threads produced by `SharedThread::to_db_thread`.
    #[serde(default)]
    pub imported: bool,
    /// Present for subagent threads; its `parent_thread_id` is written to the
    /// `parent_id` column on save.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
    /// Git worktree associated with the thread, if any; its `branch` is written to
    /// the `worktree_branch` column on save.
    #[serde(default)]
    pub git_worktree_info: Option<AgentGitWorktreeInfo>,
}
74
/// Portable subset of a thread used for sharing.
///
/// Carries only the title, messages, timestamp and model of a `DbThread`, plus a
/// format version string. Encoded as zstd-compressed JSON by `to_bytes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    /// Title of the shared thread.
    pub title: SharedString,
    /// Messages of the shared thread, in order.
    pub messages: Vec<DbMessage>,
    /// Last-update timestamp copied from the source thread.
    pub updated_at: DateTime<Utc>,
    /// Language model copied from the source thread; may be absent in the encoded form.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Format version; `from_db_thread` sets it to `SharedThread::VERSION`.
    pub version: String,
}
84
85impl SharedThread {
86 pub const VERSION: &'static str = "1.0.0";
87
88 pub fn from_db_thread(thread: &DbThread) -> Self {
89 Self {
90 title: thread.title.clone(),
91 messages: thread.messages.clone(),
92 updated_at: thread.updated_at,
93 model: thread.model.clone(),
94 version: Self::VERSION.to_string(),
95 }
96 }
97
98 pub fn to_db_thread(self) -> DbThread {
99 DbThread {
100 title: format!("🔗 {}", self.title).into(),
101 messages: self.messages,
102 updated_at: self.updated_at,
103 detailed_summary: None,
104 initial_project_snapshot: None,
105 cumulative_token_usage: Default::default(),
106 request_token_usage: Default::default(),
107 model: self.model,
108 profile: None,
109 imported: true,
110 subagent_context: None,
111 git_worktree_info: None,
112 }
113 }
114
115 pub fn to_bytes(&self) -> Result<Vec<u8>> {
116 const COMPRESSION_LEVEL: i32 = 3;
117 let json = serde_json::to_vec(self)?;
118 let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
119 Ok(compressed)
120 }
121
122 pub fn from_bytes(data: &[u8]) -> Result<Self> {
123 let decompressed = zstd::decode_all(data)?;
124 Ok(serde_json::from_slice(&decompressed)?)
125 }
126}
127
128impl DbThread {
129 pub const VERSION: &'static str = "0.3.0";
130
131 pub fn from_json(json: &[u8]) -> Result<Self> {
132 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
133 match saved_thread_json.get("version") {
134 Some(serde_json::Value::String(version)) => match version.as_str() {
135 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
136 _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
137 json,
138 )?),
139 },
140 _ => {
141 Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
142 }
143 }
144 }
145
146 fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
147 let mut messages = Vec::new();
148 let mut request_token_usage = HashMap::default();
149
150 let mut last_user_message_id = None;
151 for (ix, msg) in thread.messages.into_iter().enumerate() {
152 let message = match msg.role {
153 language_model::Role::User => {
154 let mut content = Vec::new();
155
156 // Convert segments to content
157 for segment in msg.segments {
158 match segment {
159 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
160 content.push(UserMessageContent::Text(text));
161 }
162 crate::legacy_thread::SerializedMessageSegment::Thinking {
163 text,
164 ..
165 } => {
166 // User messages don't have thinking segments, but handle gracefully
167 content.push(UserMessageContent::Text(text));
168 }
169 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
170 ..
171 } => {
172 // User messages don't have redacted thinking, skip.
173 }
174 }
175 }
176
177 // If no content was added, add context as text if available
178 if content.is_empty() && !msg.context.is_empty() {
179 content.push(UserMessageContent::Text(msg.context));
180 }
181
182 let id = UserMessageId::new();
183 last_user_message_id = Some(id.clone());
184
185 crate::Message::User(UserMessage {
186 // MessageId from old format can't be meaningfully converted, so generate a new one
187 id,
188 content,
189 })
190 }
191 language_model::Role::Assistant => {
192 let mut content = Vec::new();
193
194 // Convert segments to content
195 for segment in msg.segments {
196 match segment {
197 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
198 content.push(AgentMessageContent::Text(text));
199 }
200 crate::legacy_thread::SerializedMessageSegment::Thinking {
201 text,
202 signature,
203 } => {
204 content.push(AgentMessageContent::Thinking { text, signature });
205 }
206 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
207 data,
208 } => {
209 content.push(AgentMessageContent::RedactedThinking(data));
210 }
211 }
212 }
213
214 // Convert tool uses
215 let mut tool_names_by_id = HashMap::default();
216 for tool_use in msg.tool_uses {
217 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
218 content.push(AgentMessageContent::ToolUse(
219 language_model::LanguageModelToolUse {
220 id: tool_use.id,
221 name: tool_use.name.into(),
222 raw_input: serde_json::to_string(&tool_use.input)
223 .unwrap_or_default(),
224 input: tool_use.input,
225 is_input_complete: true,
226 thought_signature: None,
227 },
228 ));
229 }
230
231 // Convert tool results
232 let mut tool_results = IndexMap::default();
233 for tool_result in msg.tool_results {
234 let name = tool_names_by_id
235 .remove(&tool_result.tool_use_id)
236 .unwrap_or_else(|| SharedString::from("unknown"));
237 tool_results.insert(
238 tool_result.tool_use_id.clone(),
239 language_model::LanguageModelToolResult {
240 tool_use_id: tool_result.tool_use_id,
241 tool_name: name.into(),
242 is_error: tool_result.is_error,
243 content: tool_result.content,
244 output: tool_result.output,
245 },
246 );
247 }
248
249 if let Some(last_user_message_id) = &last_user_message_id
250 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
251 {
252 request_token_usage.insert(last_user_message_id.clone(), token_usage);
253 }
254
255 crate::Message::Agent(AgentMessage {
256 content,
257 tool_results,
258 reasoning_details: None,
259 })
260 }
261 language_model::Role::System => {
262 // Skip system messages as they're not supported in the new format
263 continue;
264 }
265 };
266
267 messages.push(message);
268 }
269
270 Ok(Self {
271 title: thread.summary,
272 messages,
273 updated_at: thread.updated_at,
274 detailed_summary: match thread.detailed_summary_state {
275 crate::legacy_thread::DetailedSummaryState::NotGenerated
276 | crate::legacy_thread::DetailedSummaryState::Generating => None,
277 crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
278 },
279 initial_project_snapshot: thread.initial_project_snapshot,
280 cumulative_token_usage: thread.cumulative_token_usage,
281 request_token_usage,
282 model: thread.model,
283 profile: thread.profile,
284 imported: false,
285 subagent_context: None,
286 git_worktree_info: None,
287 })
288 }
289}
290
291#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
292pub enum DataType {
293 #[serde(rename = "json")]
294 Json,
295 #[serde(rename = "zstd")]
296 Zstd,
297}
298
299impl Bind for DataType {
300 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
301 let value = match self {
302 DataType::Json => "json",
303 DataType::Zstd => "zstd",
304 };
305 value.bind(statement, start_index)
306 }
307}
308
309impl Column for DataType {
310 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
311 let (value, next_index) = String::column(statement, start_index)?;
312 let data_type = match value.as_str() {
313 "json" => DataType::Json,
314 "zstd" => DataType::Zstd,
315 _ => anyhow::bail!("Unknown data type: {}", value),
316 };
317 Ok((data_type, next_index))
318 }
319}
320
/// Handle to the SQLite-backed store of agent threads.
///
/// Each operation is spawned on `executor` and locks `connection` while it talks
/// to the database.
pub(crate) struct ThreadsDatabase {
    /// Executor on which database operations are spawned.
    executor: BackgroundExecutor,
    /// The single database connection, shared behind a mutex.
    connection: Arc<Mutex<Connection>>,
}
325
/// App-global wrapper around the shared task that opens the threads database, so
/// that `ThreadsDatabase::connect` can hand out the same task on every call.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
329
330impl ThreadsDatabase {
331 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
332 if cx.has_global::<GlobalThreadsDatabase>() {
333 return cx.global::<GlobalThreadsDatabase>().0.clone();
334 }
335 let executor = cx.background_executor().clone();
336 let task = executor
337 .spawn({
338 let executor = executor.clone();
339 async move {
340 match ThreadsDatabase::new(executor) {
341 Ok(db) => Ok(Arc::new(db)),
342 Err(err) => Err(Arc::new(err)),
343 }
344 }
345 })
346 .shared();
347
348 cx.set_global(GlobalThreadsDatabase(task.clone()));
349 task
350 }
351
352 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
353 let connection = if *ZED_STATELESS {
354 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
355 } else if cfg!(any(feature = "test-support", test)) {
356 // rust stores the name of the test on the current thread.
357 // We use this to automatically create a database that will
358 // be shared within the test (for the test_retrieve_old_thread)
359 // but not with concurrent tests.
360 let thread = std::thread::current();
361 let test_name = thread.name();
362 Connection::open_memory(Some(&format!(
363 "THREAD_FALLBACK_{}",
364 test_name.unwrap_or_default()
365 )))
366 } else {
367 let threads_dir = paths::data_dir().join("threads");
368 std::fs::create_dir_all(&threads_dir)?;
369 let sqlite_path = threads_dir.join("threads.db");
370 Connection::open_file(&sqlite_path.to_string_lossy())
371 };
372
373 connection.exec(indoc! {"
374 CREATE TABLE IF NOT EXISTS threads (
375 id TEXT PRIMARY KEY,
376 summary TEXT NOT NULL,
377 updated_at TEXT NOT NULL,
378 data_type TEXT NOT NULL,
379 data BLOB NOT NULL
380 )
381 "})?()
382 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
383
384 if let Ok(mut s) = connection.exec(indoc! {"
385 ALTER TABLE threads ADD COLUMN parent_id TEXT
386 "})
387 {
388 s().ok();
389 }
390
391 if let Ok(mut s) = connection.exec(indoc! {"
392 ALTER TABLE threads ADD COLUMN worktree_branch TEXT
393 "})
394 {
395 s().ok();
396 }
397
398 let db = Self {
399 executor,
400 connection: Arc::new(Mutex::new(connection)),
401 };
402
403 Ok(db)
404 }
405
406 fn save_thread_sync(
407 connection: &Arc<Mutex<Connection>>,
408 id: acp::SessionId,
409 thread: DbThread,
410 ) -> Result<()> {
411 const COMPRESSION_LEVEL: i32 = 3;
412
413 #[derive(Serialize)]
414 struct SerializedThread {
415 #[serde(flatten)]
416 thread: DbThread,
417 version: &'static str,
418 }
419
420 let title = thread.title.to_string();
421 let updated_at = thread.updated_at.to_rfc3339();
422 let parent_id = thread
423 .subagent_context
424 .as_ref()
425 .map(|ctx| ctx.parent_thread_id.0.clone());
426 let worktree_branch = thread
427 .git_worktree_info
428 .as_ref()
429 .map(|info| info.branch.clone());
430 let json_data = serde_json::to_string(&SerializedThread {
431 thread,
432 version: DbThread::VERSION,
433 })?;
434
435 let connection = connection.lock();
436
437 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
438 let data_type = DataType::Zstd;
439 let data = compressed;
440
441 let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, String, String, DataType, Vec<u8>)>(indoc! {"
442 INSERT OR REPLACE INTO threads (id, parent_id, worktree_branch, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?, ?)
443 "})?;
444
445 insert((
446 id.0,
447 parent_id,
448 worktree_branch,
449 title,
450 updated_at,
451 data_type,
452 data,
453 ))?;
454
455 Ok(())
456 }
457
458 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
459 let connection = self.connection.clone();
460
461 self.executor.spawn(async move {
462 let connection = connection.lock();
463
464 let mut select = connection
465 .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, String, String)>(indoc! {"
466 SELECT id, parent_id, worktree_branch, summary, updated_at FROM threads ORDER BY updated_at DESC
467 "})?;
468
469 let rows = select(())?;
470 let mut threads = Vec::new();
471
472 for (id, parent_id, worktree_branch, summary, updated_at) in rows {
473 threads.push(DbThreadMetadata {
474 id: acp::SessionId::new(id),
475 parent_session_id: parent_id.map(acp::SessionId::new),
476 title: summary.into(),
477 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
478 worktree_branch,
479 });
480 }
481
482 Ok(threads)
483 })
484 }
485
486 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
487 let connection = self.connection.clone();
488
489 self.executor.spawn(async move {
490 let connection = connection.lock();
491 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
492 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
493 "})?;
494
495 let rows = select(id.0)?;
496 if let Some((data_type, data)) = rows.into_iter().next() {
497 let json_data = match data_type {
498 DataType::Zstd => {
499 let decompressed = zstd::decode_all(&data[..])?;
500 String::from_utf8(decompressed)?
501 }
502 DataType::Json => String::from_utf8(data)?,
503 };
504 let thread = DbThread::from_json(json_data.as_bytes())?;
505 Ok(Some(thread))
506 } else {
507 Ok(None)
508 }
509 })
510 }
511
512 pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
513 let connection = self.connection.clone();
514
515 self.executor
516 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
517 }
518
519 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
520 let connection = self.connection.clone();
521
522 self.executor.spawn(async move {
523 let connection = connection.lock();
524
525 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
526 DELETE FROM threads WHERE id = ?
527 "})?;
528
529 delete(id.0)?;
530
531 Ok(())
532 })
533 }
534
535 pub fn delete_threads(&self) -> Task<Result<()>> {
536 let connection = self.connection.clone();
537
538 self.executor.spawn(async move {
539 let connection = connection.lock();
540
541 let mut delete = connection.exec_bound::<()>(indoc! {"
542 DELETE FROM threads
543 "})?;
544
545 delete(())?;
546
547 Ok(())
548 })
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds a UTC timestamp at the given date and hour (minutes and seconds zero).
    fn utc(year: i32, month: u32, day: u32, hour: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(year, month, day, hour, 0, 0).unwrap()
    }

    /// Wraps a string literal in a session id.
    fn session_id(value: &str) -> acp::SessionId {
        let raw: Arc<str> = Arc::from(value);
        acp::SessionId::new(raw)
    }

    /// Creates an otherwise empty thread with the given title and timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: SharedString::from(title.to_string()),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            git_worktree_info: None,
        }
    }

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: utc(2024, 1, 1, 0),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let encoded = original.to_bytes().expect("Failed to serialize");
        let decoded = SharedThread::from_bytes(&encoded).expect("Failed to deserialize");

        assert_eq!(decoded.title, original.title);
        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let parsed: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !parsed.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let to_save = [
            (older_id.clone(), make_thread("Thread A", utc(2024, 1, 1, 0))),
            (newer_id.clone(), make_thread("Thread B", utc(2024, 1, 2, 0))),
        ];
        for (id, thread) in to_save {
            database.save_thread(id, thread).await.unwrap();
        }

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let first_version = make_thread("Thread A", utc(2024, 1, 1, 0));
        let second_version = make_thread("Thread B", utc(2024, 1, 2, 0));

        for version in [first_version, second_version] {
            database
                .save_thread(thread_id.clone(), version)
                .await
                .unwrap();
        }

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);

        let entry = &entries[0];
        assert_eq!(entry.id, thread_id);
        assert_eq!(entry.title.as_ref(), "Thread B");
        assert_eq!(entry.updated_at, utc(2024, 1, 2, 0));
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let parsed: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            parsed.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let mut child_thread = make_thread("Subagent Thread", utc(2024, 1, 1, 0));
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 2,
        });

        database
            .save_thread(child_id.clone(), child_thread)
            .await
            .unwrap();

        let loaded = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let restored = loaded
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(restored.parent_thread_id, parent_id);
        assert_eq!(restored.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("regular-thread");
        let regular_thread = make_thread("Regular Thread", utc(2024, 1, 1, 0));

        database
            .save_thread(thread_id.clone(), regular_thread)
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }

    #[gpui::test]
    async fn test_git_worktree_info_roundtrip(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("worktree-thread");
        let mut thread = make_thread("Worktree Thread", utc(2024, 6, 15, 12));
        thread.git_worktree_info = Some(AgentGitWorktreeInfo {
            branch: "zed/agent/a4Xiu".to_string(),
            worktree_path: std::path::PathBuf::from("/repo/worktrees/zed/agent/a4Xiu"),
            base_ref: "main".to_string(),
        });

        database
            .save_thread(thread_id.clone(), thread)
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let restored = loaded
            .git_worktree_info
            .expect("git_worktree_info should be restored");
        assert_eq!(restored.branch, "zed/agent/a4Xiu");
        assert_eq!(
            restored.worktree_path,
            std::path::PathBuf::from("/repo/worktrees/zed/agent/a4Xiu")
        );
        assert_eq!(restored.base_ref, "main");
    }

    #[gpui::test]
    async fn test_session_list_includes_worktree_meta(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        // Save a thread with worktree info
        let worktree_id = session_id("wt-thread");
        let mut worktree_thread = make_thread("With Worktree", utc(2024, 6, 15, 12));
        worktree_thread.git_worktree_info = Some(AgentGitWorktreeInfo {
            branch: "zed/agent/bR9kz".to_string(),
            worktree_path: std::path::PathBuf::from("/repo/worktrees/zed/agent/bR9kz"),
            base_ref: "develop".to_string(),
        });

        database
            .save_thread(worktree_id.clone(), worktree_thread)
            .await
            .unwrap();

        // Save a thread without worktree info
        let plain_id = session_id("plain-thread");
        let plain_thread = make_thread("Without Worktree", utc(2024, 6, 15, 11));

        database
            .save_thread(plain_id.clone(), plain_thread)
            .await
            .unwrap();

        // List threads and verify worktree_branch is populated correctly
        let listed = database.list_threads().await.unwrap();
        assert_eq!(listed.len(), 2);

        let wt_entry = listed
            .iter()
            .find(|entry| entry.id == worktree_id)
            .expect("should find worktree thread");
        assert_eq!(wt_entry.worktree_branch.as_deref(), Some("zed/agent/bR9kz"));

        let plain_entry = listed
            .iter()
            .find(|entry| entry.id == plain_id)
            .expect("should find plain thread");
        assert!(
            plain_entry.worktree_branch.is_none(),
            "plain thread should have no worktree_branch"
        );
    }
}