1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::AgentProfileId;
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use language_model::Speed;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use sqlez::{
15 bindable::{Bind, Column},
16 connection::Connection,
17 statement::Statement,
18};
19use std::sync::Arc;
20use ui::{App, SharedString};
21use util::path_list::PathList;
22use zed_env_vars::ZED_STATELESS;
23
/// Message type persisted inside a [`DbThread`]; an alias for `crate::Message`.
pub type DbMessage = crate::Message;
/// Alias for the legacy thread's detailed-summary state (`legacy_thread::DetailedSummaryState`).
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Alias for the legacy serialized language-model reference, reused as the stored model field.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
27
/// Lightweight per-thread row returned by `ThreadsDatabase::list_threads`, built from the
/// indexed columns of the `threads` table without decoding the thread body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id of the thread (the table's primary key).
    pub id: acp::SessionId,
    /// Session id of the parent thread, when this thread was saved with a subagent context.
    pub parent_session_id: Option<acp::SessionId>,
    /// Thread title. The `summary` alias lets data written under the older key still
    /// deserialize.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Time of the most recent save.
    pub updated_at: DateTime<Utc>,
    /// Creation time; `None` when the stored row has no `created_at` value.
    pub created_at: Option<DateTime<Utc>>,
    /// The workspace folder paths this thread was created against, sorted
    /// lexicographically. Used for grouping threads by project in the sidebar.
    pub folder_paths: PathList,
}
40
41impl From<&DbThreadMetadata> for acp_thread::AgentSessionInfo {
42 fn from(meta: &DbThreadMetadata) -> Self {
43 Self {
44 session_id: meta.id.clone(),
45 cwd: None,
46 title: Some(meta.title.clone()),
47 updated_at: Some(meta.updated_at),
48 created_at: meta.created_at,
49 meta: None,
50 }
51 }
52}
53
/// Full persisted representation of an agent thread, stored as (compressed) JSON in the
/// `data` column of the `threads` table.
///
/// Every field after `updated_at` carries `#[serde(default)]`, so JSON written before that
/// field existed still deserializes, with the field set to its default value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also mirrored into the `summary` column on save.
    pub title: SharedString,
    /// Conversation messages in order.
    pub messages: Vec<DbMessage>,
    /// Time of the most recent update; mirrored into the `updated_at` column on save.
    pub updated_at: DateTime<Utc>,
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by the user message that started each request.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// `true` for threads created from a `SharedThread` import.
    #[serde(default)]
    pub imported: bool,
    /// Present for subagent threads; its parent id is mirrored into the `parent_id` column.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
    #[serde(default)]
    pub speed: Option<Speed>,
    #[serde(default)]
    pub thinking_enabled: bool,
    #[serde(default)]
    pub thinking_effort: Option<String>,
    /// Unsent prompt content, if any was saved with the thread.
    #[serde(default)]
    pub draft_prompt: Option<Vec<acp::ContentBlock>>,
    /// Saved UI scroll position, restored when the thread is reopened.
    #[serde(default)]
    pub ui_scroll_position: Option<SerializedScrollPosition>,
}
86
/// Persisted scroll state for a thread view.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SerializedScrollPosition {
    /// Index of the item the position is anchored to (presumably an entry in the thread's
    /// list view — confirm against the UI code).
    pub item_ix: usize,
    /// Offset within that item (units assumed to be pixels — TODO confirm).
    pub offset_in_item: f32,
}
92
/// Portable subset of a [`DbThread`] used for sharing: title, messages, timestamp and model,
/// tagged with a format `version`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    pub title: SharedString,
    pub messages: Vec<DbMessage>,
    pub updated_at: DateTime<Utc>,
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Format version; set to `SharedThread::VERSION` when created via `from_db_thread`.
    pub version: String,
}
102
103impl SharedThread {
104 pub const VERSION: &'static str = "1.0.0";
105
106 pub fn from_db_thread(thread: &DbThread) -> Self {
107 Self {
108 title: thread.title.clone(),
109 messages: thread.messages.clone(),
110 updated_at: thread.updated_at,
111 model: thread.model.clone(),
112 version: Self::VERSION.to_string(),
113 }
114 }
115
116 pub fn to_db_thread(self) -> DbThread {
117 DbThread {
118 title: format!("🔗 {}", self.title).into(),
119 messages: self.messages,
120 updated_at: self.updated_at,
121 detailed_summary: None,
122 initial_project_snapshot: None,
123 cumulative_token_usage: Default::default(),
124 request_token_usage: Default::default(),
125 model: self.model,
126 profile: None,
127 imported: true,
128 subagent_context: None,
129 speed: None,
130 thinking_enabled: false,
131 thinking_effort: None,
132 draft_prompt: None,
133 ui_scroll_position: None,
134 }
135 }
136
137 pub fn to_bytes(&self) -> Result<Vec<u8>> {
138 const COMPRESSION_LEVEL: i32 = 3;
139 let json = serde_json::to_vec(self)?;
140 let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
141 Ok(compressed)
142 }
143
144 pub fn from_bytes(data: &[u8]) -> Result<Self> {
145 let decompressed = zstd::decode_all(data)?;
146 Ok(serde_json::from_slice(&decompressed)?)
147 }
148}
149
impl DbThread {
    /// Format version written alongside every thread saved by `ThreadsDatabase`.
    ///
    /// `from_json` deserializes a blob directly only when its `"version"` field equals this
    /// value; anything else is treated as the legacy format and upgraded.
    pub const VERSION: &'static str = "0.3.0";

    /// Parses a saved thread from its JSON bytes.
    ///
    /// If the top-level `"version"` field is a string equal to [`DbThread::VERSION`], the JSON
    /// is deserialized directly. In every other case (another version string, a non-string
    /// version, or no version at all) the bytes are parsed as a legacy `SerializedThread` and
    /// converted with `upgrade_from_agent_1`.
    ///
    /// # Errors
    /// Fails if `json` is not valid JSON or does not match the selected format.
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
                    json,
                )?),
            },
            _ => {
                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
            }
        }
    }

    /// Converts a legacy (agent 1) thread into the current format.
    ///
    /// - System messages are dropped.
    /// - User messages get a freshly generated `UserMessageId`. Thinking segments become
    ///   plain text and redacted-thinking segments are skipped. If nothing remains, the
    ///   message's `context` is used as its text.
    /// - Assistant messages keep their text, thinking and redacted-thinking segments. Their
    ///   tool uses and results are carried over; a result whose id matches no tool use in
    ///   the same message is named `"unknown"`.
    /// - Token usage is looked up by each assistant message's position in the legacy message
    ///   list. It is stored under the most recent preceding user message, so a later
    ///   assistant message overwrites an earlier one for the same user message.
    /// - Fields the legacy format lacks get defaults, and `imported` is `false`.
    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
        let mut messages = Vec::new();
        let mut request_token_usage = HashMap::default();

        let mut last_user_message_id = None;
        // `ix` is taken before system messages are skipped, so it stays aligned with the
        // legacy message list that `thread.request_token_usage` is indexed against.
        for (ix, msg) in thread.messages.into_iter().enumerate() {
            let message = match msg.role {
                language_model::Role::User => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                ..
                            } => {
                                // User messages don't have thinking segments, but handle gracefully
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                ..
                            } => {
                                // User messages don't have redacted thinking, skip.
                            }
                        }
                    }

                    // If no content was added, add context as text if available
                    if content.is_empty() && !msg.context.is_empty() {
                        content.push(UserMessageContent::Text(msg.context));
                    }

                    let id = UserMessageId::new();
                    // Remembered so following assistant messages can attribute token usage.
                    last_user_message_id = Some(id.clone());

                    crate::Message::User(UserMessage {
                        // MessageId from old format can't be meaningfully converted, so generate a new one
                        id,
                        content,
                    })
                }
                language_model::Role::Assistant => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(AgentMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                signature,
                            } => {
                                content.push(AgentMessageContent::Thinking { text, signature });
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                data,
                            } => {
                                content.push(AgentMessageContent::RedactedThinking(data));
                            }
                        }
                    }

                    // Convert tool uses, remembering each tool's name so the matching
                    // result below can be labelled with it.
                    let mut tool_names_by_id = HashMap::default();
                    for tool_use in msg.tool_uses {
                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
                        content.push(AgentMessageContent::ToolUse(
                            language_model::LanguageModelToolUse {
                                id: tool_use.id,
                                name: tool_use.name.into(),
                                raw_input: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                                input: tool_use.input,
                                is_input_complete: true,
                                thought_signature: None,
                            },
                        ));
                    }

                    // Convert tool results (the name is consumed, so a repeated id falls
                    // back to "unknown").
                    let mut tool_results = IndexMap::default();
                    for tool_result in msg.tool_results {
                        let name = tool_names_by_id
                            .remove(&tool_result.tool_use_id)
                            .unwrap_or_else(|| SharedString::from("unknown"));
                        tool_results.insert(
                            tool_result.tool_use_id.clone(),
                            language_model::LanguageModelToolResult {
                                tool_use_id: tool_result.tool_use_id,
                                tool_name: name.into(),
                                is_error: tool_result.is_error,
                                content: tool_result.content,
                                output: tool_result.output,
                            },
                        );
                    }

                    // Attribute this message's legacy token usage to the latest user message.
                    if let Some(last_user_message_id) = &last_user_message_id
                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
                    {
                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
                    }

                    crate::Message::Agent(AgentMessage {
                        content,
                        tool_results,
                        reasoning_details: None,
                    })
                }
                language_model::Role::System => {
                    // Skip system messages as they're not supported in the new format
                    continue;
                }
            };

            messages.push(message);
        }

        Ok(Self {
            title: thread.summary,
            messages,
            updated_at: thread.updated_at,
            // Only a fully generated summary is carried over.
            detailed_summary: match thread.detailed_summary_state {
                crate::legacy_thread::DetailedSummaryState::NotGenerated
                | crate::legacy_thread::DetailedSummaryState::Generating => None,
                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
            },
            initial_project_snapshot: thread.initial_project_snapshot,
            cumulative_token_usage: thread.cumulative_token_usage,
            request_token_usage,
            model: thread.model,
            profile: thread.profile,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        })
    }
}
316
/// Encoding of the `data` blob in the `threads` table.
///
/// Stored in the `data_type` column as the lowercase tags `"json"` / `"zstd"`; the serde
/// renames use the same tags.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    /// Uncompressed UTF-8 JSON.
    #[serde(rename = "json")]
    Json,
    /// zstd-compressed UTF-8 JSON (what `save_thread` writes).
    #[serde(rename = "zstd")]
    Zstd,
}
324
325impl Bind for DataType {
326 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
327 let value = match self {
328 DataType::Json => "json",
329 DataType::Zstd => "zstd",
330 };
331 value.bind(statement, start_index)
332 }
333}
334
335impl Column for DataType {
336 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
337 let (value, next_index) = String::column(statement, start_index)?;
338 let data_type = match value.as_str() {
339 "json" => DataType::Json,
340 "zstd" => DataType::Zstd,
341 _ => anyhow::bail!("Unknown data type: {}", value),
342 };
343 Ok((data_type, next_index))
344 }
345}
346
/// SQLite-backed store of agent threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every query task is spawned.
    executor: BackgroundExecutor,
    /// Single connection shared by all tasks; each operation locks it for its duration.
    connection: Arc<Mutex<Connection>>,
}
351
/// GPUI global holding the shared task that opens the threads database, so that
/// `ThreadsDatabase::connect` hands out the same task to every caller.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
355
356impl ThreadsDatabase {
357 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
358 if cx.has_global::<GlobalThreadsDatabase>() {
359 return cx.global::<GlobalThreadsDatabase>().0.clone();
360 }
361 let executor = cx.background_executor().clone();
362 let task = executor
363 .spawn({
364 let executor = executor.clone();
365 async move {
366 match ThreadsDatabase::new(executor) {
367 Ok(db) => Ok(Arc::new(db)),
368 Err(err) => Err(Arc::new(err)),
369 }
370 }
371 })
372 .shared();
373
374 cx.set_global(GlobalThreadsDatabase(task.clone()));
375 task
376 }
377
378 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
379 let connection = if *ZED_STATELESS {
380 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
381 } else if cfg!(any(feature = "test-support", test)) {
382 // rust stores the name of the test on the current thread.
383 // We use this to automatically create a database that will
384 // be shared within the test (for the test_retrieve_old_thread)
385 // but not with concurrent tests.
386 let thread = std::thread::current();
387 let test_name = thread.name();
388 Connection::open_memory(Some(&format!(
389 "THREAD_FALLBACK_{}",
390 test_name.unwrap_or_default()
391 )))
392 } else {
393 let threads_dir = paths::data_dir().join("threads");
394 std::fs::create_dir_all(&threads_dir)?;
395 let sqlite_path = threads_dir.join("threads.db");
396 Connection::open_file(&sqlite_path.to_string_lossy())
397 };
398
399 connection.exec(indoc! {"
400 CREATE TABLE IF NOT EXISTS threads (
401 id TEXT PRIMARY KEY,
402 summary TEXT NOT NULL,
403 updated_at TEXT NOT NULL,
404 data_type TEXT NOT NULL,
405 data BLOB NOT NULL
406 )
407 "})?()
408 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
409
410 if let Ok(mut s) = connection.exec(indoc! {"
411 ALTER TABLE threads ADD COLUMN parent_id TEXT
412 "})
413 {
414 s().ok();
415 }
416
417 if let Ok(mut s) = connection.exec(indoc! {"
418 ALTER TABLE threads ADD COLUMN folder_paths TEXT;
419 ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
420 "})
421 {
422 s().ok();
423 }
424
425 if let Ok(mut s) = connection.exec(indoc! {"
426 ALTER TABLE threads ADD COLUMN created_at TEXT;
427 "})
428 {
429 if s().is_ok() {
430 connection.exec(indoc! {"
431 UPDATE threads SET created_at = updated_at WHERE created_at IS NULL
432 "})?()?;
433 }
434 }
435
436 let db = Self {
437 executor,
438 connection: Arc::new(Mutex::new(connection)),
439 };
440
441 Ok(db)
442 }
443
444 fn save_thread_sync(
445 connection: &Arc<Mutex<Connection>>,
446 id: acp::SessionId,
447 thread: DbThread,
448 folder_paths: &PathList,
449 ) -> Result<()> {
450 const COMPRESSION_LEVEL: i32 = 3;
451
452 #[derive(Serialize)]
453 struct SerializedThread {
454 #[serde(flatten)]
455 thread: DbThread,
456 version: &'static str,
457 }
458
459 let title = thread.title.to_string();
460 let updated_at = thread.updated_at.to_rfc3339();
461 let parent_id = thread
462 .subagent_context
463 .as_ref()
464 .map(|ctx| ctx.parent_thread_id.0.clone());
465 let serialized_folder_paths = folder_paths.serialize();
466 let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
467 if folder_paths.is_empty() {
468 (None, None)
469 } else {
470 (
471 Some(serialized_folder_paths.paths),
472 Some(serialized_folder_paths.order),
473 )
474 };
475 let json_data = serde_json::to_string(&SerializedThread {
476 thread,
477 version: DbThread::VERSION,
478 })?;
479
480 let connection = connection.lock();
481
482 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
483 let data_type = DataType::Zstd;
484 let data = compressed;
485
486 // Use the thread's updated_at as created_at for new threads.
487 // This ensures the creation time reflects when the thread was conceptually
488 // created, not when it was saved to the database.
489 let created_at = updated_at.clone();
490
491 let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>, String)>(indoc! {"
492 INSERT INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data, created_at)
493 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
494 ON CONFLICT(id) DO UPDATE SET
495 parent_id = excluded.parent_id,
496 folder_paths = excluded.folder_paths,
497 folder_paths_order = excluded.folder_paths_order,
498 summary = excluded.summary,
499 updated_at = excluded.updated_at,
500 data_type = excluded.data_type,
501 data = excluded.data
502 "})?;
503
504 insert((
505 id.0,
506 parent_id,
507 folder_paths_str,
508 folder_paths_order_str,
509 title,
510 updated_at,
511 data_type,
512 data,
513 created_at,
514 ))?;
515
516 Ok(())
517 }
518
519 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
520 let connection = self.connection.clone();
521
522 self.executor.spawn(async move {
523 let connection = connection.lock();
524
525 let mut select = connection
526 .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, Option<String>)>(indoc! {"
527 SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at FROM threads ORDER BY updated_at DESC, created_at DESC
528 "})?;
529
530 let rows = select(())?;
531 let mut threads = Vec::new();
532
533 for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at) in rows {
534 let folder_paths = folder_paths
535 .map(|paths| {
536 PathList::deserialize(&util::path_list::SerializedPathList {
537 paths,
538 order: folder_paths_order.unwrap_or_default(),
539 })
540 })
541 .unwrap_or_default();
542 let created_at = created_at
543 .as_deref()
544 .map(DateTime::parse_from_rfc3339)
545 .transpose()?
546 .map(|dt| dt.with_timezone(&Utc));
547
548 threads.push(DbThreadMetadata {
549 id: acp::SessionId::new(id),
550 parent_session_id: parent_id.map(acp::SessionId::new),
551 title: summary.into(),
552 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
553 created_at,
554 folder_paths,
555 });
556 }
557
558 Ok(threads)
559 })
560 }
561
562 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
563 let connection = self.connection.clone();
564
565 self.executor.spawn(async move {
566 let connection = connection.lock();
567 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
568 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
569 "})?;
570
571 let rows = select(id.0)?;
572 if let Some((data_type, data)) = rows.into_iter().next() {
573 let json_data = match data_type {
574 DataType::Zstd => {
575 let decompressed = zstd::decode_all(&data[..])?;
576 String::from_utf8(decompressed)?
577 }
578 DataType::Json => String::from_utf8(data)?,
579 };
580 let thread = DbThread::from_json(json_data.as_bytes())?;
581 Ok(Some(thread))
582 } else {
583 Ok(None)
584 }
585 })
586 }
587
588 pub fn save_thread(
589 &self,
590 id: acp::SessionId,
591 thread: DbThread,
592 folder_paths: PathList,
593 ) -> Task<Result<()>> {
594 let connection = self.connection.clone();
595
596 self.executor
597 .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
598 }
599
600 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
601 let connection = self.connection.clone();
602
603 self.executor.spawn(async move {
604 let connection = connection.lock();
605
606 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
607 DELETE FROM threads WHERE id = ?
608 "})?;
609
610 delete(id.0)?;
611
612 Ok(())
613 })
614 }
615
616 pub fn delete_threads(&self) -> Task<Result<()>> {
617 let connection = self.connection.clone();
618
619 self.executor.spawn(async move {
620 let connection = connection.lock();
621
622 let mut delete = connection.exec_bound::<()>(indoc! {"
623 DELETE FROM threads
624 "})?;
625
626 delete(())?;
627
628 Ok(())
629 })
630 }
631}
632
#[cfg(test)]
mod tests {
    //! Tests for thread (de)serialization and the SQLite-backed `ThreadsDatabase`.
    //!
    //! Each test opens its own in-memory database: in test builds `ThreadsDatabase::new`
    //! names the database after the current thread, which carries the test's name.

    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    /// Builds a session id from a string literal.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Builds an empty thread with the given title and timestamp and all other fields defaulted.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_created_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at_before_created_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        database
            .save_thread(
                first_id.clone(),
                make_thread("Thread A", Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap()),
                PathList::default(),
            )
            .await
            .unwrap();
        database
            .save_thread(
                second_id.clone(),
                make_thread("Thread B", Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()),
                PathList::default(),
            )
            .await
            .unwrap();
        // Re-save the older thread so it was created first but updated last.
        database
            .save_thread(
                first_id.clone(),
                make_thread("Thread A", Utc.with_ymd_and_hms(2024, 1, 3, 0, 0, 0).unwrap()),
                PathList::default(),
            )
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, first_id, "most recently updated thread comes first");
        assert_eq!(entries[1].id, second_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
        // The upsert must not overwrite created_at with the second save's timestamp.
        assert_eq!(
            entries[0].created_at,
            Some(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap()),
            "created_at should keep the timestamp from the first save"
        );
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[test]
    fn test_draft_prompt_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.draft_prompt.is_none(),
            "Legacy threads without draft_prompt field should default to None"
        );
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let mut child_thread = make_thread(
            "Subagent Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 2,
        });

        database
            .save_thread(child_id.clone(), child_thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let context = loaded
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("regular-thread");
        let thread = make_thread(
            "Regular Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }

    #[gpui::test]
    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("folder-thread");
        let thread = make_thread(
            "Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        let folder_paths = PathList::new(&[
            std::path::PathBuf::from("/home/user/project-a"),
            std::path::PathBuf::from("/home/user/project-b"),
        ]);

        database
            .save_thread(thread_id.clone(), thread, folder_paths.clone())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert_eq!(threads[0].folder_paths, folder_paths);
    }

    #[gpui::test]
    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("no-folder-thread");
        let thread = make_thread(
            "No Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert!(threads[0].folder_paths.is_empty());
    }

    #[test]
    fn test_scroll_position_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.ui_scroll_position.is_none(),
            "Legacy threads without scroll_position field should default to None"
        );
    }

    #[gpui::test]
    async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-with-scroll");

        let mut thread = make_thread(
            "Thread With Scroll",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        thread.ui_scroll_position = Some(SerializedScrollPosition {
            item_ix: 42,
            offset_in_item: 13.5,
        });

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let scroll = loaded
            .ui_scroll_position
            .expect("scroll_position should be restored");
        assert_eq!(scroll.item_ix, 42);
        assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
    }
}