1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::AgentProfileId;
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use language_model::Speed;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use sqlez::{
15 bindable::{Bind, Column},
16 connection::Connection,
17 statement::Statement,
18};
19use std::sync::Arc;
20use ui::{App, SharedString};
21use util::path_list::PathList;
22use zed_env_vars::ZED_STATELESS;
23
24pub type DbMessage = crate::Message;
25pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
26pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct DbThreadMetadata {
30 pub id: acp::SessionId,
31 pub parent_session_id: Option<acp::SessionId>,
32 #[serde(alias = "summary")]
33 pub title: SharedString,
34 pub updated_at: DateTime<Utc>,
35 /// The workspace folder paths this thread was created against, sorted
36 /// lexicographically. Used for grouping threads by project in the sidebar.
37 pub folder_paths: PathList,
38}
39
40#[derive(Debug, Serialize, Deserialize)]
41pub struct DbThread {
42 pub title: SharedString,
43 pub messages: Vec<DbMessage>,
44 pub updated_at: DateTime<Utc>,
45 #[serde(default)]
46 pub detailed_summary: Option<SharedString>,
47 #[serde(default)]
48 pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
49 #[serde(default)]
50 pub cumulative_token_usage: language_model::TokenUsage,
51 #[serde(default)]
52 pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
53 #[serde(default)]
54 pub model: Option<DbLanguageModel>,
55 #[serde(default)]
56 pub profile: Option<AgentProfileId>,
57 #[serde(default)]
58 pub imported: bool,
59 #[serde(default)]
60 pub subagent_context: Option<crate::SubagentContext>,
61 #[serde(default)]
62 pub speed: Option<Speed>,
63 #[serde(default)]
64 pub thinking_enabled: bool,
65 #[serde(default)]
66 pub thinking_effort: Option<String>,
67 #[serde(default)]
68 pub draft_prompt: Option<Vec<acp::ContentBlock>>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct SharedThread {
73 pub title: SharedString,
74 pub messages: Vec<DbMessage>,
75 pub updated_at: DateTime<Utc>,
76 #[serde(default)]
77 pub model: Option<DbLanguageModel>,
78 pub version: String,
79}
80
81impl SharedThread {
82 pub const VERSION: &'static str = "1.0.0";
83
84 pub fn from_db_thread(thread: &DbThread) -> Self {
85 Self {
86 title: thread.title.clone(),
87 messages: thread.messages.clone(),
88 updated_at: thread.updated_at,
89 model: thread.model.clone(),
90 version: Self::VERSION.to_string(),
91 }
92 }
93
94 pub fn to_db_thread(self) -> DbThread {
95 DbThread {
96 title: format!("🔗 {}", self.title).into(),
97 messages: self.messages,
98 updated_at: self.updated_at,
99 detailed_summary: None,
100 initial_project_snapshot: None,
101 cumulative_token_usage: Default::default(),
102 request_token_usage: Default::default(),
103 model: self.model,
104 profile: None,
105 imported: true,
106 subagent_context: None,
107 speed: None,
108 thinking_enabled: false,
109 thinking_effort: None,
110 draft_prompt: None,
111 }
112 }
113
114 pub fn to_bytes(&self) -> Result<Vec<u8>> {
115 const COMPRESSION_LEVEL: i32 = 3;
116 let json = serde_json::to_vec(self)?;
117 let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
118 Ok(compressed)
119 }
120
121 pub fn from_bytes(data: &[u8]) -> Result<Self> {
122 let decompressed = zstd::decode_all(data)?;
123 Ok(serde_json::from_slice(&decompressed)?)
124 }
125}
126
127impl DbThread {
128 pub const VERSION: &'static str = "0.3.0";
129
130 pub fn from_json(json: &[u8]) -> Result<Self> {
131 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
132 match saved_thread_json.get("version") {
133 Some(serde_json::Value::String(version)) => match version.as_str() {
134 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
135 _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
136 json,
137 )?),
138 },
139 _ => {
140 Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
141 }
142 }
143 }
144
145 fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
146 let mut messages = Vec::new();
147 let mut request_token_usage = HashMap::default();
148
149 let mut last_user_message_id = None;
150 for (ix, msg) in thread.messages.into_iter().enumerate() {
151 let message = match msg.role {
152 language_model::Role::User => {
153 let mut content = Vec::new();
154
155 // Convert segments to content
156 for segment in msg.segments {
157 match segment {
158 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
159 content.push(UserMessageContent::Text(text));
160 }
161 crate::legacy_thread::SerializedMessageSegment::Thinking {
162 text,
163 ..
164 } => {
165 // User messages don't have thinking segments, but handle gracefully
166 content.push(UserMessageContent::Text(text));
167 }
168 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
169 ..
170 } => {
171 // User messages don't have redacted thinking, skip.
172 }
173 }
174 }
175
176 // If no content was added, add context as text if available
177 if content.is_empty() && !msg.context.is_empty() {
178 content.push(UserMessageContent::Text(msg.context));
179 }
180
181 let id = UserMessageId::new();
182 last_user_message_id = Some(id.clone());
183
184 crate::Message::User(UserMessage {
185 // MessageId from old format can't be meaningfully converted, so generate a new one
186 id,
187 content,
188 })
189 }
190 language_model::Role::Assistant => {
191 let mut content = Vec::new();
192
193 // Convert segments to content
194 for segment in msg.segments {
195 match segment {
196 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
197 content.push(AgentMessageContent::Text(text));
198 }
199 crate::legacy_thread::SerializedMessageSegment::Thinking {
200 text,
201 signature,
202 } => {
203 content.push(AgentMessageContent::Thinking { text, signature });
204 }
205 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
206 data,
207 } => {
208 content.push(AgentMessageContent::RedactedThinking(data));
209 }
210 }
211 }
212
213 // Convert tool uses
214 let mut tool_names_by_id = HashMap::default();
215 for tool_use in msg.tool_uses {
216 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
217 content.push(AgentMessageContent::ToolUse(
218 language_model::LanguageModelToolUse {
219 id: tool_use.id,
220 name: tool_use.name.into(),
221 raw_input: serde_json::to_string(&tool_use.input)
222 .unwrap_or_default(),
223 input: tool_use.input,
224 is_input_complete: true,
225 thought_signature: None,
226 },
227 ));
228 }
229
230 // Convert tool results
231 let mut tool_results = IndexMap::default();
232 for tool_result in msg.tool_results {
233 let name = tool_names_by_id
234 .remove(&tool_result.tool_use_id)
235 .unwrap_or_else(|| SharedString::from("unknown"));
236 tool_results.insert(
237 tool_result.tool_use_id.clone(),
238 language_model::LanguageModelToolResult {
239 tool_use_id: tool_result.tool_use_id,
240 tool_name: name.into(),
241 is_error: tool_result.is_error,
242 content: tool_result.content,
243 output: tool_result.output,
244 },
245 );
246 }
247
248 if let Some(last_user_message_id) = &last_user_message_id
249 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
250 {
251 request_token_usage.insert(last_user_message_id.clone(), token_usage);
252 }
253
254 crate::Message::Agent(AgentMessage {
255 content,
256 tool_results,
257 reasoning_details: None,
258 })
259 }
260 language_model::Role::System => {
261 // Skip system messages as they're not supported in the new format
262 continue;
263 }
264 };
265
266 messages.push(message);
267 }
268
269 Ok(Self {
270 title: thread.summary,
271 messages,
272 updated_at: thread.updated_at,
273 detailed_summary: match thread.detailed_summary_state {
274 crate::legacy_thread::DetailedSummaryState::NotGenerated
275 | crate::legacy_thread::DetailedSummaryState::Generating => None,
276 crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
277 },
278 initial_project_snapshot: thread.initial_project_snapshot,
279 cumulative_token_usage: thread.cumulative_token_usage,
280 request_token_usage,
281 model: thread.model,
282 profile: thread.profile,
283 imported: false,
284 subagent_context: None,
285 speed: None,
286 thinking_enabled: false,
287 thinking_effort: None,
288 draft_prompt: None,
289 })
290 }
291}
292
293#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
294pub enum DataType {
295 #[serde(rename = "json")]
296 Json,
297 #[serde(rename = "zstd")]
298 Zstd,
299}
300
301impl Bind for DataType {
302 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
303 let value = match self {
304 DataType::Json => "json",
305 DataType::Zstd => "zstd",
306 };
307 value.bind(statement, start_index)
308 }
309}
310
311impl Column for DataType {
312 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
313 let (value, next_index) = String::column(statement, start_index)?;
314 let data_type = match value.as_str() {
315 "json" => DataType::Json,
316 "zstd" => DataType::Zstd,
317 _ => anyhow::bail!("Unknown data type: {}", value),
318 };
319 Ok((data_type, next_index))
320 }
321}
322
323pub(crate) struct ThreadsDatabase {
324 executor: BackgroundExecutor,
325 connection: Arc<Mutex<Connection>>,
326}
327
328struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
329
330impl Global for GlobalThreadsDatabase {}
331
332impl ThreadsDatabase {
333 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
334 if cx.has_global::<GlobalThreadsDatabase>() {
335 return cx.global::<GlobalThreadsDatabase>().0.clone();
336 }
337 let executor = cx.background_executor().clone();
338 let task = executor
339 .spawn({
340 let executor = executor.clone();
341 async move {
342 match ThreadsDatabase::new(executor) {
343 Ok(db) => Ok(Arc::new(db)),
344 Err(err) => Err(Arc::new(err)),
345 }
346 }
347 })
348 .shared();
349
350 cx.set_global(GlobalThreadsDatabase(task.clone()));
351 task
352 }
353
354 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
355 let connection = if *ZED_STATELESS {
356 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
357 } else if cfg!(any(feature = "test-support", test)) {
358 // rust stores the name of the test on the current thread.
359 // We use this to automatically create a database that will
360 // be shared within the test (for the test_retrieve_old_thread)
361 // but not with concurrent tests.
362 let thread = std::thread::current();
363 let test_name = thread.name();
364 Connection::open_memory(Some(&format!(
365 "THREAD_FALLBACK_{}",
366 test_name.unwrap_or_default()
367 )))
368 } else {
369 let threads_dir = paths::data_dir().join("threads");
370 std::fs::create_dir_all(&threads_dir)?;
371 let sqlite_path = threads_dir.join("threads.db");
372 Connection::open_file(&sqlite_path.to_string_lossy())
373 };
374
375 connection.exec(indoc! {"
376 CREATE TABLE IF NOT EXISTS threads (
377 id TEXT PRIMARY KEY,
378 summary TEXT NOT NULL,
379 updated_at TEXT NOT NULL,
380 data_type TEXT NOT NULL,
381 data BLOB NOT NULL
382 )
383 "})?()
384 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
385
386 if let Ok(mut s) = connection.exec(indoc! {"
387 ALTER TABLE threads ADD COLUMN parent_id TEXT
388 "})
389 {
390 s().ok();
391 }
392
393 if let Ok(mut s) = connection.exec(indoc! {"
394 ALTER TABLE threads ADD COLUMN folder_paths TEXT;
395 ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
396 "})
397 {
398 s().ok();
399 }
400
401 let db = Self {
402 executor,
403 connection: Arc::new(Mutex::new(connection)),
404 };
405
406 Ok(db)
407 }
408
409 fn save_thread_sync(
410 connection: &Arc<Mutex<Connection>>,
411 id: acp::SessionId,
412 thread: DbThread,
413 folder_paths: &PathList,
414 ) -> Result<()> {
415 const COMPRESSION_LEVEL: i32 = 3;
416
417 #[derive(Serialize)]
418 struct SerializedThread {
419 #[serde(flatten)]
420 thread: DbThread,
421 version: &'static str,
422 }
423
424 let title = thread.title.to_string();
425 let updated_at = thread.updated_at.to_rfc3339();
426 let parent_id = thread
427 .subagent_context
428 .as_ref()
429 .map(|ctx| ctx.parent_thread_id.0.clone());
430 let serialized_folder_paths = folder_paths.serialize();
431 let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
432 if folder_paths.is_empty() {
433 (None, None)
434 } else {
435 (
436 Some(serialized_folder_paths.paths),
437 Some(serialized_folder_paths.order),
438 )
439 };
440 let json_data = serde_json::to_string(&SerializedThread {
441 thread,
442 version: DbThread::VERSION,
443 })?;
444
445 let connection = connection.lock();
446
447 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
448 let data_type = DataType::Zstd;
449 let data = compressed;
450
451 let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>)>(indoc! {"
452 INSERT OR REPLACE INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
453 "})?;
454
455 insert((
456 id.0,
457 parent_id,
458 folder_paths_str,
459 folder_paths_order_str,
460 title,
461 updated_at,
462 data_type,
463 data,
464 ))?;
465
466 Ok(())
467 }
468
469 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
470 let connection = self.connection.clone();
471
472 self.executor.spawn(async move {
473 let connection = connection.lock();
474
475 let mut select = connection
476 .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String)>(indoc! {"
477 SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at FROM threads ORDER BY updated_at DESC
478 "})?;
479
480 let rows = select(())?;
481 let mut threads = Vec::new();
482
483 for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at) in rows {
484 let folder_paths = folder_paths
485 .map(|paths| {
486 PathList::deserialize(&util::path_list::SerializedPathList {
487 paths,
488 order: folder_paths_order.unwrap_or_default(),
489 })
490 })
491 .unwrap_or_default();
492 threads.push(DbThreadMetadata {
493 id: acp::SessionId::new(id),
494 parent_session_id: parent_id.map(acp::SessionId::new),
495 title: summary.into(),
496 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
497 folder_paths,
498 });
499 }
500
501 Ok(threads)
502 })
503 }
504
505 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
506 let connection = self.connection.clone();
507
508 self.executor.spawn(async move {
509 let connection = connection.lock();
510 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
511 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
512 "})?;
513
514 let rows = select(id.0)?;
515 if let Some((data_type, data)) = rows.into_iter().next() {
516 let json_data = match data_type {
517 DataType::Zstd => {
518 let decompressed = zstd::decode_all(&data[..])?;
519 String::from_utf8(decompressed)?
520 }
521 DataType::Json => String::from_utf8(data)?,
522 };
523 let thread = DbThread::from_json(json_data.as_bytes())?;
524 Ok(Some(thread))
525 } else {
526 Ok(None)
527 }
528 })
529 }
530
531 pub fn save_thread(
532 &self,
533 id: acp::SessionId,
534 thread: DbThread,
535 folder_paths: PathList,
536 ) -> Task<Result<()>> {
537 let connection = self.connection.clone();
538
539 self.executor
540 .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
541 }
542
543 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
544 let connection = self.connection.clone();
545
546 self.executor.spawn(async move {
547 let connection = connection.lock();
548
549 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
550 DELETE FROM threads WHERE id = ?
551 "})?;
552
553 delete(id.0)?;
554
555 Ok(())
556 })
557 }
558
559 pub fn delete_threads(&self) -> Task<Result<()>> {
560 let connection = self.connection.clone();
561
562 self.executor.spawn(async move {
563 let connection = connection.lock();
564
565 let mut delete = connection.exec_bound::<()>(indoc! {"
566 DELETE FROM threads
567 "})?;
568
569 delete(())?;
570
571 Ok(())
572 })
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Minimal thread JSON that omits every `#[serde(default)]` field of `DbThread`. It
    /// simulates data written before those fields existed.
    const LEGACY_THREAD_JSON: &str = r#"{
        "title": "Old Thread",
        "messages": [],
        "updated_at": "2024-01-01T00:00:00Z"
    }"#;

    fn parse_legacy_thread() -> DbThread {
        serde_json::from_str(LEGACY_THREAD_JSON).expect("Failed to deserialize")
    }

    fn timestamp(year: i32, month: u32, day: u32, hour: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(year, month, day, hour, 0, 0).unwrap()
    }

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
        }
    }

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: timestamp(2024, 1, 1, 0),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_shared_thread_to_db_thread_marks_imported() {
        let shared = SharedThread {
            title: "Shared".into(),
            messages: vec![],
            updated_at: timestamp(2024, 1, 1, 0),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let db_thread = shared.to_db_thread();

        assert_eq!(db_thread.title.as_ref(), "🔗 Shared");
        assert!(db_thread.imported);
        assert_eq!(db_thread.updated_at, timestamp(2024, 1, 1, 0));
    }

    #[test]
    fn test_from_json_reads_current_version() {
        // `imported: true` proves the current-version path was taken, because the legacy
        // upgrade always produces `imported: false`.
        let json = format!(
            r#"{{"version": "{}", "title": "Current", "messages": [], "updated_at": "2024-01-01T00:00:00Z", "imported": true}}"#,
            DbThread::VERSION
        );

        let thread = DbThread::from_json(json.as_bytes()).expect("Failed to parse");

        assert_eq!(thread.title.as_ref(), "Current");
        assert!(thread.imported);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        assert!(
            !parse_legacy_thread().imported,
            "Legacy threads without imported field should default to false"
        );
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        assert!(
            parse_legacy_thread().subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[test]
    fn test_draft_prompt_defaults_to_none() {
        assert!(
            parse_legacy_thread().draft_prompt.is_none(),
            "Legacy threads without draft_prompt field should default to None"
        );
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread("Thread A", timestamp(2024, 1, 1, 0));
        let newer_thread = make_thread("Thread B", timestamp(2024, 1, 2, 0));

        database
            .save_thread(older_id.clone(), older_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread("Thread A", timestamp(2024, 1, 1, 0));
        let updated_thread = make_thread("Thread B", timestamp(2024, 1, 2, 0));

        database
            .save_thread(thread_id.clone(), original_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(entries[0].updated_at, timestamp(2024, 1, 2, 0));
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let mut child_thread = make_thread("Subagent Thread", timestamp(2024, 1, 1, 0));
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 2,
        });

        database
            .save_thread(child_id.clone(), child_thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let context = loaded
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("regular-thread");
        let thread = make_thread("Regular Thread", timestamp(2024, 1, 1, 0));

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }

    #[gpui::test]
    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("folder-thread");
        let thread = make_thread("Folder Thread", timestamp(2024, 6, 15, 12));

        let folder_paths = PathList::new(&[
            std::path::PathBuf::from("/home/user/project-a"),
            std::path::PathBuf::from("/home/user/project-b"),
        ]);

        database
            .save_thread(thread_id.clone(), thread, folder_paths.clone())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert_eq!(threads[0].folder_paths, folder_paths);
    }

    #[gpui::test]
    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("no-folder-thread");
        let thread = make_thread("No Folder Thread", timestamp(2024, 6, 15, 12));

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert!(threads[0].folder_paths.is_empty());
    }

    #[gpui::test]
    async fn test_reopening_database_preserves_threads(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("persisted-thread");
        let folder_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project")]);

        database
            .save_thread(
                thread_id.clone(),
                make_thread("Persisted", timestamp(2024, 1, 1, 0)),
                folder_paths.clone(),
            )
            .await
            .unwrap();

        // Opening the same test-scoped, shared in-memory database again re-runs the schema
        // setup. That must tolerate the already-applied column migrations.
        let reopened = ThreadsDatabase::new(cx.executor()).unwrap();
        let threads = reopened.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert_eq!(threads[0].id, thread_id);
        assert_eq!(threads[0].folder_paths, folder_paths);
    }
}