1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::AgentProfileId;
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use language_model::Speed;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use sqlez::{
15 bindable::{Bind, Column},
16 connection::Connection,
17 statement::Statement,
18};
19use std::sync::Arc;
20use ui::{App, SharedString};
21use util::path_list::PathList;
22use zed_env_vars::ZED_STATELESS;
23
/// Message type stored inside a persisted thread (alias of the crate-level message enum).
pub type DbMessage = crate::Message;
/// Detailed-summary state as recorded by the legacy thread format.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Language model selection as serialized by the legacy thread format.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
27
/// Lightweight description of a stored thread, as returned by
/// `ThreadsDatabase::list_threads` without decoding the thread body.
#[derive(Debug, Clone)]
pub struct DbThreadMetadata {
    /// Session id the thread is stored under (the primary key of the `threads` table).
    pub id: acp::SessionId,
    /// Session id of the parent thread, read from the `parent_id` column; `None` when unset.
    pub parent_session_id: Option<acp::SessionId>,
    /// Thread title, read from the `summary` column.
    pub title: SharedString,
    /// Time of the last update, parsed from the RFC 3339 text in the `updated_at` column.
    pub updated_at: DateTime<Utc>,
    /// Creation time, or `None` when the row has no `created_at` value.
    pub created_at: Option<DateTime<Utc>>,
    /// The workspace folder paths this thread was created against, sorted
    /// lexicographically. Used for grouping threads by project in the sidebar.
    pub folder_paths: PathList,
}
39
40impl From<&DbThreadMetadata> for acp_thread::AgentSessionInfo {
41 fn from(meta: &DbThreadMetadata) -> Self {
42 Self {
43 session_id: meta.id.clone(),
44 work_dirs: Some(meta.folder_paths.clone()),
45 title: Some(meta.title.clone()),
46 updated_at: Some(meta.updated_at),
47 created_at: meta.created_at,
48 meta: None,
49 }
50 }
51}
52
/// Full contents of a thread as persisted in the threads database.
///
/// Only `title`, `messages` and `updated_at` are required when deserializing;
/// every field marked `#[serde(default)]` falls back to its default value when
/// it is missing from the stored JSON, so payloads written before a field was
/// introduced still load.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also written to the `summary` column when the thread is saved.
    pub title: SharedString,
    /// Messages of the conversation, in order.
    pub messages: Vec<DbMessage>,
    /// Time of the last update; also written to the `updated_at` column when saved.
    pub updated_at: DateTime<Utc>,
    /// Generated detailed summary text, if any.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Project snapshot captured for the thread, if any.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage accumulated over the whole thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by the id of the user message that triggered the request.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model selected for the thread, if recorded.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Agent profile selected for the thread, if recorded.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// `true` for threads created from a `SharedThread` (see `SharedThread::to_db_thread`).
    #[serde(default)]
    pub imported: bool,
    /// Present when the thread is a subagent thread; its parent thread id is
    /// written to the `parent_id` column when the thread is saved.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
    /// Model speed setting, if recorded.
    #[serde(default)]
    pub speed: Option<Speed>,
    /// Whether thinking was enabled for the thread.
    #[serde(default)]
    pub thinking_enabled: bool,
    /// Thinking effort setting, if recorded.
    #[serde(default)]
    pub thinking_effort: Option<String>,
    /// Unsent prompt content saved with the thread, if any.
    #[serde(default)]
    pub draft_prompt: Option<Vec<acp::ContentBlock>>,
    /// Saved UI scroll position, if any.
    #[serde(default)]
    pub ui_scroll_position: Option<SerializedScrollPosition>,
}
85
/// Scroll position persisted with a thread (`DbThread::ui_scroll_position`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SerializedScrollPosition {
    /// Index of the item the position refers to.
    // NOTE(review): presumably an index into the rendered thread entries — confirm against the UI code.
    pub item_ix: usize,
    /// Offset within that item.
    pub offset_in_item: f32,
}
91
/// Portable subset of a thread used for sharing: title, messages, timestamp
/// and model, tagged with a format version.
///
/// Converted to and from [`DbThread`] and encoded as zstd-compressed JSON by
/// the methods on this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    /// Title of the shared thread.
    pub title: SharedString,
    /// Messages of the conversation, in order.
    pub messages: Vec<DbMessage>,
    /// Time of the last update of the source thread.
    pub updated_at: DateTime<Utc>,
    /// Language model of the source thread, if recorded; defaults to `None` when absent.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Format version string; `from_db_thread` fills in `SharedThread::VERSION`.
    pub version: String,
}
101
102impl SharedThread {
103 pub const VERSION: &'static str = "1.0.0";
104
105 pub fn from_db_thread(thread: &DbThread) -> Self {
106 Self {
107 title: thread.title.clone(),
108 messages: thread.messages.clone(),
109 updated_at: thread.updated_at,
110 model: thread.model.clone(),
111 version: Self::VERSION.to_string(),
112 }
113 }
114
115 pub fn to_db_thread(self) -> DbThread {
116 DbThread {
117 title: format!("🔗 {}", self.title).into(),
118 messages: self.messages,
119 updated_at: self.updated_at,
120 detailed_summary: None,
121 initial_project_snapshot: None,
122 cumulative_token_usage: Default::default(),
123 request_token_usage: Default::default(),
124 model: self.model,
125 profile: None,
126 imported: true,
127 subagent_context: None,
128 speed: None,
129 thinking_enabled: false,
130 thinking_effort: None,
131 draft_prompt: None,
132 ui_scroll_position: None,
133 }
134 }
135
136 pub fn to_bytes(&self) -> Result<Vec<u8>> {
137 const COMPRESSION_LEVEL: i32 = 3;
138 let json = serde_json::to_vec(self)?;
139 let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
140 Ok(compressed)
141 }
142
143 pub fn from_bytes(data: &[u8]) -> Result<Self> {
144 let decompressed = zstd::decode_all(data)?;
145 Ok(serde_json::from_slice(&decompressed)?)
146 }
147}
148
149impl DbThread {
    /// Format version written alongside every saved thread. `from_json` deserializes a
    /// payload directly only when its `version` equals this value; anything else is
    /// treated as the legacy format and upgraded.
    pub const VERSION: &'static str = "0.3.0";
151
152 pub fn from_json(json: &[u8]) -> Result<Self> {
153 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
154 match saved_thread_json.get("version") {
155 Some(serde_json::Value::String(version)) => match version.as_str() {
156 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
157 _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
158 json,
159 )?),
160 },
161 _ => {
162 Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
163 }
164 }
165 }
166
167 fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
168 let mut messages = Vec::new();
169 let mut request_token_usage = HashMap::default();
170
171 let mut last_user_message_id = None;
172 for (ix, msg) in thread.messages.into_iter().enumerate() {
173 let message = match msg.role {
174 language_model::Role::User => {
175 let mut content = Vec::new();
176
177 // Convert segments to content
178 for segment in msg.segments {
179 match segment {
180 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
181 content.push(UserMessageContent::Text(text));
182 }
183 crate::legacy_thread::SerializedMessageSegment::Thinking {
184 text,
185 ..
186 } => {
187 // User messages don't have thinking segments, but handle gracefully
188 content.push(UserMessageContent::Text(text));
189 }
190 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
191 ..
192 } => {
193 // User messages don't have redacted thinking, skip.
194 }
195 }
196 }
197
198 // If no content was added, add context as text if available
199 if content.is_empty() && !msg.context.is_empty() {
200 content.push(UserMessageContent::Text(msg.context));
201 }
202
203 let id = UserMessageId::new();
204 last_user_message_id = Some(id.clone());
205
206 crate::Message::User(UserMessage {
207 // MessageId from old format can't be meaningfully converted, so generate a new one
208 id,
209 content,
210 })
211 }
212 language_model::Role::Assistant => {
213 let mut content = Vec::new();
214
215 // Convert segments to content
216 for segment in msg.segments {
217 match segment {
218 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
219 content.push(AgentMessageContent::Text(text));
220 }
221 crate::legacy_thread::SerializedMessageSegment::Thinking {
222 text,
223 signature,
224 } => {
225 content.push(AgentMessageContent::Thinking { text, signature });
226 }
227 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
228 data,
229 } => {
230 content.push(AgentMessageContent::RedactedThinking(data));
231 }
232 }
233 }
234
235 // Convert tool uses
236 let mut tool_names_by_id = HashMap::default();
237 for tool_use in msg.tool_uses {
238 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
239 content.push(AgentMessageContent::ToolUse(
240 language_model::LanguageModelToolUse {
241 id: tool_use.id,
242 name: tool_use.name.into(),
243 raw_input: serde_json::to_string(&tool_use.input)
244 .unwrap_or_default(),
245 input: tool_use.input,
246 is_input_complete: true,
247 thought_signature: None,
248 },
249 ));
250 }
251
252 // Convert tool results
253 let mut tool_results = IndexMap::default();
254 for tool_result in msg.tool_results {
255 let name = tool_names_by_id
256 .remove(&tool_result.tool_use_id)
257 .unwrap_or_else(|| SharedString::from("unknown"));
258 tool_results.insert(
259 tool_result.tool_use_id.clone(),
260 language_model::LanguageModelToolResult {
261 tool_use_id: tool_result.tool_use_id,
262 tool_name: name.into(),
263 is_error: tool_result.is_error,
264 content: tool_result.content,
265 output: tool_result.output,
266 },
267 );
268 }
269
270 if let Some(last_user_message_id) = &last_user_message_id
271 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
272 {
273 request_token_usage.insert(last_user_message_id.clone(), token_usage);
274 }
275
276 crate::Message::Agent(AgentMessage {
277 content,
278 tool_results,
279 reasoning_details: None,
280 })
281 }
282 language_model::Role::System => {
283 // Skip system messages as they're not supported in the new format
284 continue;
285 }
286 };
287
288 messages.push(message);
289 }
290
291 Ok(Self {
292 title: thread.summary,
293 messages,
294 updated_at: thread.updated_at,
295 detailed_summary: match thread.detailed_summary_state {
296 crate::legacy_thread::DetailedSummaryState::NotGenerated
297 | crate::legacy_thread::DetailedSummaryState::Generating => None,
298 crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
299 },
300 initial_project_snapshot: thread.initial_project_snapshot,
301 cumulative_token_usage: thread.cumulative_token_usage,
302 request_token_usage,
303 model: thread.model,
304 profile: thread.profile,
305 imported: false,
306 subagent_context: None,
307 speed: None,
308 thinking_enabled: false,
309 thinking_effort: None,
310 draft_prompt: None,
311 ui_scroll_position: None,
312 })
313 }
314}
315
/// Encoding of the `data` blob stored for a thread row.
///
/// Stored in the `data_type` column as the text `"json"` or `"zstd"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    /// The blob is plain UTF-8 JSON.
    #[serde(rename = "json")]
    Json,
    /// The blob is zstd-compressed JSON.
    #[serde(rename = "zstd")]
    Zstd,
}
323
324impl Bind for DataType {
325 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
326 let value = match self {
327 DataType::Json => "json",
328 DataType::Zstd => "zstd",
329 };
330 value.bind(statement, start_index)
331 }
332}
333
334impl Column for DataType {
335 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
336 let (value, next_index) = String::column(statement, start_index)?;
337 let data_type = match value.as_str() {
338 "json" => DataType::Json,
339 "zstd" => DataType::Zstd,
340 _ => anyhow::bail!("Unknown data type: {}", value),
341 };
342 Ok((data_type, next_index))
343 }
344}
345
/// Store for agent threads backed by a single database connection.
///
/// The connection sits behind a mutex, and the list/load/save/delete
/// operations run as tasks on the background executor.
pub(crate) struct ThreadsDatabase {
    /// Executor on which database operations are spawned.
    executor: BackgroundExecutor,
    /// The one connection shared by all operations; locked for each query.
    connection: Arc<Mutex<Connection>>,
}
350
/// App-wide global holding the shared task that opens the threads database, so
/// `ThreadsDatabase::connect` can hand out clones of the same task instead of
/// opening the database again.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
354
355impl ThreadsDatabase {
356 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
357 if cx.has_global::<GlobalThreadsDatabase>() {
358 return cx.global::<GlobalThreadsDatabase>().0.clone();
359 }
360 let executor = cx.background_executor().clone();
361 let task = executor
362 .spawn({
363 let executor = executor.clone();
364 async move {
365 match ThreadsDatabase::new(executor) {
366 Ok(db) => Ok(Arc::new(db)),
367 Err(err) => Err(Arc::new(err)),
368 }
369 }
370 })
371 .shared();
372
373 cx.set_global(GlobalThreadsDatabase(task.clone()));
374 task
375 }
376
377 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
378 let connection = if *ZED_STATELESS {
379 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
380 } else if cfg!(any(feature = "test-support", test)) {
381 // rust stores the name of the test on the current thread.
382 // We use this to automatically create a database that will
383 // be shared within the test (for the test_retrieve_old_thread)
384 // but not with concurrent tests.
385 let thread = std::thread::current();
386 let test_name = thread.name();
387 Connection::open_memory(Some(&format!(
388 "THREAD_FALLBACK_{}",
389 test_name.unwrap_or_default()
390 )))
391 } else {
392 let threads_dir = paths::data_dir().join("threads");
393 std::fs::create_dir_all(&threads_dir)?;
394 let sqlite_path = threads_dir.join("threads.db");
395 Connection::open_file(&sqlite_path.to_string_lossy())
396 };
397
398 connection.exec(indoc! {"
399 CREATE TABLE IF NOT EXISTS threads (
400 id TEXT PRIMARY KEY,
401 summary TEXT NOT NULL,
402 updated_at TEXT NOT NULL,
403 data_type TEXT NOT NULL,
404 data BLOB NOT NULL
405 )
406 "})?()
407 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
408
409 if let Ok(mut s) = connection.exec(indoc! {"
410 ALTER TABLE threads ADD COLUMN parent_id TEXT
411 "})
412 {
413 s().ok();
414 }
415
416 if let Ok(mut s) = connection.exec(indoc! {"
417 ALTER TABLE threads ADD COLUMN folder_paths TEXT;
418 ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
419 "})
420 {
421 s().ok();
422 }
423
424 if let Ok(mut s) = connection.exec(indoc! {"
425 ALTER TABLE threads ADD COLUMN created_at TEXT;
426 "})
427 {
428 if s().is_ok() {
429 connection.exec(indoc! {"
430 UPDATE threads SET created_at = updated_at WHERE created_at IS NULL
431 "})?()?;
432 }
433 }
434
435 let db = Self {
436 executor,
437 connection: Arc::new(Mutex::new(connection)),
438 };
439
440 Ok(db)
441 }
442
443 fn save_thread_sync(
444 connection: &Arc<Mutex<Connection>>,
445 id: acp::SessionId,
446 thread: DbThread,
447 folder_paths: &PathList,
448 ) -> Result<()> {
449 const COMPRESSION_LEVEL: i32 = 3;
450
451 #[derive(Serialize)]
452 struct SerializedThread {
453 #[serde(flatten)]
454 thread: DbThread,
455 version: &'static str,
456 }
457
458 let title = thread.title.to_string();
459 let updated_at = thread.updated_at.to_rfc3339();
460 let parent_id = thread
461 .subagent_context
462 .as_ref()
463 .map(|ctx| ctx.parent_thread_id.0.clone());
464 let serialized_folder_paths = folder_paths.serialize();
465 let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
466 if folder_paths.is_empty() {
467 (None, None)
468 } else {
469 (
470 Some(serialized_folder_paths.paths),
471 Some(serialized_folder_paths.order),
472 )
473 };
474 let json_data = serde_json::to_string(&SerializedThread {
475 thread,
476 version: DbThread::VERSION,
477 })?;
478
479 let connection = connection.lock();
480
481 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
482 let data_type = DataType::Zstd;
483 let data = compressed;
484
485 // Use the thread's updated_at as created_at for new threads.
486 // This ensures the creation time reflects when the thread was conceptually
487 // created, not when it was saved to the database.
488 let created_at = updated_at.clone();
489
490 let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>, String)>(indoc! {"
491 INSERT INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data, created_at)
492 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
493 ON CONFLICT(id) DO UPDATE SET
494 parent_id = excluded.parent_id,
495 folder_paths = excluded.folder_paths,
496 folder_paths_order = excluded.folder_paths_order,
497 summary = excluded.summary,
498 updated_at = excluded.updated_at,
499 data_type = excluded.data_type,
500 data = excluded.data
501 "})?;
502
503 insert((
504 id.0,
505 parent_id,
506 folder_paths_str,
507 folder_paths_order_str,
508 title,
509 updated_at,
510 data_type,
511 data,
512 created_at,
513 ))?;
514
515 Ok(())
516 }
517
518 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
519 let connection = self.connection.clone();
520
521 self.executor.spawn(async move {
522 let connection = connection.lock();
523
524 let mut select = connection
525 .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, Option<String>)>(indoc! {"
526 SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at FROM threads ORDER BY updated_at DESC, created_at DESC
527 "})?;
528
529 let rows = select(())?;
530 let mut threads = Vec::new();
531
532 for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at) in rows {
533 let folder_paths = folder_paths
534 .map(|paths| {
535 PathList::deserialize(&util::path_list::SerializedPathList {
536 paths,
537 order: folder_paths_order.unwrap_or_default(),
538 })
539 })
540 .unwrap_or_default();
541 let created_at = created_at
542 .as_deref()
543 .map(DateTime::parse_from_rfc3339)
544 .transpose()?
545 .map(|dt| dt.with_timezone(&Utc));
546
547 threads.push(DbThreadMetadata {
548 id: acp::SessionId::new(id),
549 parent_session_id: parent_id.map(acp::SessionId::new),
550 title: summary.into(),
551 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
552 created_at,
553 folder_paths,
554 });
555 }
556
557 Ok(threads)
558 })
559 }
560
561 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
562 let connection = self.connection.clone();
563
564 self.executor.spawn(async move {
565 let connection = connection.lock();
566 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
567 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
568 "})?;
569
570 let rows = select(id.0)?;
571 if let Some((data_type, data)) = rows.into_iter().next() {
572 let json_data = match data_type {
573 DataType::Zstd => {
574 let decompressed = zstd::decode_all(&data[..])?;
575 String::from_utf8(decompressed)?
576 }
577 DataType::Json => String::from_utf8(data)?,
578 };
579 let thread = DbThread::from_json(json_data.as_bytes())?;
580 Ok(Some(thread))
581 } else {
582 Ok(None)
583 }
584 })
585 }
586
587 pub fn save_thread(
588 &self,
589 id: acp::SessionId,
590 thread: DbThread,
591 folder_paths: PathList,
592 ) -> Task<Result<()>> {
593 let connection = self.connection.clone();
594
595 self.executor
596 .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
597 }
598
599 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
600 let connection = self.connection.clone();
601
602 self.executor.spawn(async move {
603 let connection = connection.lock();
604
605 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
606 DELETE FROM threads WHERE id = ?
607 "})?;
608
609 delete(id.0)?;
610
611 Ok(())
612 })
613 }
614
615 pub fn delete_threads(&self) -> Task<Result<()>> {
616 let connection = self.connection.clone();
617
618 self.executor.spawn(async move {
619 let connection = connection.lock();
620
621 let mut delete = connection.exec_bound::<()>(indoc! {"
622 DELETE FROM threads
623 "})?;
624
625 delete(())?;
626
627 Ok(())
628 })
629 }
630}
631
632#[cfg(test)]
633mod tests {
634 use super::*;
635 use chrono::{DateTime, TimeZone, Utc};
636 use collections::HashMap;
637 use gpui::TestAppContext;
638 use std::sync::Arc;
639
640 #[test]
641 fn test_shared_thread_roundtrip() {
642 let original = SharedThread {
643 title: "Test Thread".into(),
644 messages: vec![],
645 updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
646 model: None,
647 version: SharedThread::VERSION.to_string(),
648 };
649
650 let bytes = original.to_bytes().expect("Failed to serialize");
651 let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");
652
653 assert_eq!(restored.title, original.title);
654 assert_eq!(restored.version, original.version);
655 assert_eq!(restored.updated_at, original.updated_at);
656 }
657
658 #[test]
659 fn test_imported_flag_defaults_to_false() {
660 // Simulate deserializing a thread without the imported field (backwards compatibility).
661 let json = r#"{
662 "title": "Old Thread",
663 "messages": [],
664 "updated_at": "2024-01-01T00:00:00Z"
665 }"#;
666
667 let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
668
669 assert!(
670 !db_thread.imported,
671 "Legacy threads without imported field should default to false"
672 );
673 }
674
675 fn session_id(value: &str) -> acp::SessionId {
676 acp::SessionId::new(Arc::<str>::from(value))
677 }
678
679 fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
680 DbThread {
681 title: title.to_string().into(),
682 messages: Vec::new(),
683 updated_at,
684 detailed_summary: None,
685 initial_project_snapshot: None,
686 cumulative_token_usage: Default::default(),
687 request_token_usage: HashMap::default(),
688 model: None,
689 profile: None,
690 imported: false,
691 subagent_context: None,
692 speed: None,
693 thinking_enabled: false,
694 thinking_effort: None,
695 draft_prompt: None,
696 ui_scroll_position: None,
697 }
698 }
699
700 #[gpui::test]
701 async fn test_list_threads_orders_by_created_at(cx: &mut TestAppContext) {
702 let database = ThreadsDatabase::new(cx.executor()).unwrap();
703
704 let older_id = session_id("thread-a");
705 let newer_id = session_id("thread-b");
706
707 let older_thread = make_thread(
708 "Thread A",
709 Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
710 );
711 let newer_thread = make_thread(
712 "Thread B",
713 Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
714 );
715
716 database
717 .save_thread(older_id.clone(), older_thread, PathList::default())
718 .await
719 .unwrap();
720 database
721 .save_thread(newer_id.clone(), newer_thread, PathList::default())
722 .await
723 .unwrap();
724
725 let entries = database.list_threads().await.unwrap();
726 assert_eq!(entries.len(), 2);
727 assert_eq!(entries[0].id, newer_id);
728 assert_eq!(entries[1].id, older_id);
729 }
730
731 #[gpui::test]
732 async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
733 let database = ThreadsDatabase::new(cx.executor()).unwrap();
734
735 let thread_id = session_id("thread-a");
736 let original_thread = make_thread(
737 "Thread A",
738 Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
739 );
740 let updated_thread = make_thread(
741 "Thread B",
742 Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
743 );
744
745 database
746 .save_thread(thread_id.clone(), original_thread, PathList::default())
747 .await
748 .unwrap();
749 database
750 .save_thread(thread_id.clone(), updated_thread, PathList::default())
751 .await
752 .unwrap();
753
754 let entries = database.list_threads().await.unwrap();
755 assert_eq!(entries.len(), 1);
756 assert_eq!(entries[0].id, thread_id);
757 assert_eq!(entries[0].title.as_ref(), "Thread B");
758 assert_eq!(
759 entries[0].updated_at,
760 Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
761 );
762 assert!(
763 entries[0].created_at.is_some(),
764 "created_at should be populated"
765 );
766 }
767
768 #[test]
769 fn test_subagent_context_defaults_to_none() {
770 let json = r#"{
771 "title": "Old Thread",
772 "messages": [],
773 "updated_at": "2024-01-01T00:00:00Z"
774 }"#;
775
776 let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
777
778 assert!(
779 db_thread.subagent_context.is_none(),
780 "Legacy threads without subagent_context should default to None"
781 );
782 }
783
784 #[test]
785 fn test_draft_prompt_defaults_to_none() {
786 let json = r#"{
787 "title": "Old Thread",
788 "messages": [],
789 "updated_at": "2024-01-01T00:00:00Z"
790 }"#;
791
792 let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
793
794 assert!(
795 db_thread.draft_prompt.is_none(),
796 "Legacy threads without draft_prompt field should default to None"
797 );
798 }
799
800 #[gpui::test]
801 async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
802 let database = ThreadsDatabase::new(cx.executor()).unwrap();
803
804 let parent_id = session_id("parent-thread");
805 let child_id = session_id("child-thread");
806
807 let mut child_thread = make_thread(
808 "Subagent Thread",
809 Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
810 );
811 child_thread.subagent_context = Some(crate::SubagentContext {
812 parent_thread_id: parent_id.clone(),
813 depth: 2,
814 });
815
816 database
817 .save_thread(child_id.clone(), child_thread, PathList::default())
818 .await
819 .unwrap();
820
821 let loaded = database
822 .load_thread(child_id)
823 .await
824 .unwrap()
825 .expect("thread should exist");
826
827 let context = loaded
828 .subagent_context
829 .expect("subagent_context should be restored");
830 assert_eq!(context.parent_thread_id, parent_id);
831 assert_eq!(context.depth, 2);
832 }
833
834 #[gpui::test]
835 async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
836 let database = ThreadsDatabase::new(cx.executor()).unwrap();
837
838 let thread_id = session_id("regular-thread");
839 let thread = make_thread(
840 "Regular Thread",
841 Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
842 );
843
844 database
845 .save_thread(thread_id.clone(), thread, PathList::default())
846 .await
847 .unwrap();
848
849 let loaded = database
850 .load_thread(thread_id)
851 .await
852 .unwrap()
853 .expect("thread should exist");
854
855 assert!(
856 loaded.subagent_context.is_none(),
857 "Regular threads should have no subagent_context"
858 );
859 }
860
861 #[gpui::test]
862 async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
863 let database = ThreadsDatabase::new(cx.executor()).unwrap();
864
865 let thread_id = session_id("folder-thread");
866 let thread = make_thread(
867 "Folder Thread",
868 Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
869 );
870
871 let folder_paths = PathList::new(&[
872 std::path::PathBuf::from("/home/user/project-a"),
873 std::path::PathBuf::from("/home/user/project-b"),
874 ]);
875
876 database
877 .save_thread(thread_id.clone(), thread, folder_paths.clone())
878 .await
879 .unwrap();
880
881 let threads = database.list_threads().await.unwrap();
882 assert_eq!(threads.len(), 1);
883 }
884
885 #[gpui::test]
886 async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
887 let database = ThreadsDatabase::new(cx.executor()).unwrap();
888
889 let thread_id = session_id("no-folder-thread");
890 let thread = make_thread(
891 "No Folder Thread",
892 Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
893 );
894
895 database
896 .save_thread(thread_id.clone(), thread, PathList::default())
897 .await
898 .unwrap();
899
900 let threads = database.list_threads().await.unwrap();
901 assert_eq!(threads.len(), 1);
902 }
903
904 #[test]
905 fn test_scroll_position_defaults_to_none() {
906 let json = r#"{
907 "title": "Old Thread",
908 "messages": [],
909 "updated_at": "2024-01-01T00:00:00Z"
910 }"#;
911
912 let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
913
914 assert!(
915 db_thread.ui_scroll_position.is_none(),
916 "Legacy threads without scroll_position field should default to None"
917 );
918 }
919
920 #[gpui::test]
921 async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
922 let database = ThreadsDatabase::new(cx.executor()).unwrap();
923
924 let thread_id = session_id("thread-with-scroll");
925
926 let mut thread = make_thread(
927 "Thread With Scroll",
928 Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
929 );
930 thread.ui_scroll_position = Some(SerializedScrollPosition {
931 item_ix: 42,
932 offset_in_item: 13.5,
933 });
934
935 database
936 .save_thread(thread_id.clone(), thread, PathList::default())
937 .await
938 .unwrap();
939
940 let loaded = database
941 .load_thread(thread_id)
942 .await
943 .unwrap()
944 .expect("thread should exist");
945
946 let scroll = loaded
947 .ui_scroll_position
948 .expect("scroll_position should be restored");
949 assert_eq!(scroll.item_ix, 42);
950 assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
951 }
952}