1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::AgentProfileId;
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use language_model::Speed;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use sqlez::{
15 bindable::{Bind, Column},
16 connection::Connection,
17 statement::Statement,
18};
19use std::sync::Arc;
20use ui::{App, SharedString};
21use util::path_list::PathList;
22use zed_env_vars::ZED_STATELESS;
23
/// Message type stored in a persisted thread; an alias for the crate's `Message`.
pub type DbMessage = crate::Message;
/// Alias for the legacy (agent v1) detailed-summary state type.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Alias for the legacy serialized language-model selection, which the current
/// `DbThread` / `SharedThread` formats reuse for their `model` field.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
27
/// Lightweight description of a stored thread, built by
/// `ThreadsDatabase::list_threads` from the `threads` table columns without
/// decoding the (compressed) thread body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id of the thread; the `id` primary key of the `threads` table.
    pub id: acp::SessionId,
    /// Parent session id, read from the `parent_id` column. It is populated on
    /// save from the thread's subagent context, and is `None` otherwise.
    pub parent_session_id: Option<acp::SessionId>,
    /// Thread title, stored in the `summary` column. The serde alias lets data
    /// that still uses the older `summary` key deserialize.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Last time the thread was saved (parsed from an RFC 3339 string).
    pub updated_at: DateTime<Utc>,
    /// Creation time, or `None` when the row's `created_at` column is NULL.
    pub created_at: Option<DateTime<Utc>>,
    /// The workspace folder paths this thread was created against, sorted
    /// lexicographically. Used for grouping threads by project in the sidebar.
    /// Empty when no folder paths were stored for the thread.
    pub folder_paths: PathList,
}
40
41impl From<&DbThreadMetadata> for acp_thread::AgentSessionInfo {
42 fn from(meta: &DbThreadMetadata) -> Self {
43 Self {
44 session_id: meta.id.clone(),
45 cwd: None,
46 title: Some(meta.title.clone()),
47 updated_at: Some(meta.updated_at),
48 created_at: meta.created_at,
49 meta: None,
50 }
51 }
52}
53
/// A thread in its persisted form. `ThreadsDatabase` serializes it to JSON
/// (plus a `version` field), compresses it, and stores it in the `data` column
/// of the `threads` table.
///
/// Fields marked `#[serde(default)]` may be missing from JSON written before
/// they existed; they then load as their `Default` value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also written to the `summary` column on save.
    pub title: SharedString,
    /// Conversation messages in order.
    pub messages: Vec<DbMessage>,
    /// Last-modified time; also written to the `updated_at` column on save.
    pub updated_at: DateTime<Utc>,
    /// Detailed summary text, if one was generated.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Snapshot of the project captured when the thread started, if any.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage accumulated over the whole thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage recorded per user message, keyed by that message's id.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Selected language model, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Selected agent profile, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// Set to `true` by `SharedThread::to_db_thread`; `false` for threads
    /// upgraded from the legacy format.
    #[serde(default)]
    pub imported: bool,
    /// Present for subagent threads. On save, its `parent_thread_id` is also
    /// written to the `parent_id` column.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
    /// Model speed setting, if one was chosen.
    #[serde(default)]
    pub speed: Option<Speed>,
    /// Whether extended thinking was enabled.
    #[serde(default)]
    pub thinking_enabled: bool,
    /// Thinking-effort setting as a free-form string, if one was chosen.
    #[serde(default)]
    pub thinking_effort: Option<String>,
    /// Unsent prompt content saved with the thread, if any.
    #[serde(default)]
    pub draft_prompt: Option<Vec<acp::ContentBlock>>,
    /// Title to show before a final title is available, if any.
    #[serde(default)]
    pub provisional_title: Option<SharedString>,
    /// Scroll position of the thread view to restore, if saved.
    #[serde(default)]
    pub ui_scroll_position: Option<SerializedScrollPosition>,
}
88
/// Saved scroll position of a thread's view, expressed as an item index plus
/// an offset within that item.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SerializedScrollPosition {
    /// Index of the item the view is scrolled to.
    pub item_ix: usize,
    /// Offset inside that item. NOTE(review): units (pixels vs. fraction of
    /// item height) are not established here — confirm against the view code.
    pub offset_in_item: f32,
}
94
/// Portable snapshot of a thread used for sharing.
///
/// Only the title, messages, timestamp and model are carried over.
/// `SharedThread::to_bytes` encodes it as zstd-compressed JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    pub title: SharedString,
    pub messages: Vec<DbMessage>,
    pub updated_at: DateTime<Utc>,
    /// Falls back to `None` when absent from the serialized payload.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Payload format version. `SharedThread::from_db_thread` sets it to
    /// `SharedThread::VERSION`.
    pub version: String,
}
104
105impl SharedThread {
106 pub const VERSION: &'static str = "1.0.0";
107
108 pub fn from_db_thread(thread: &DbThread) -> Self {
109 Self {
110 title: thread.title.clone(),
111 messages: thread.messages.clone(),
112 updated_at: thread.updated_at,
113 model: thread.model.clone(),
114 version: Self::VERSION.to_string(),
115 }
116 }
117
118 pub fn to_db_thread(self) -> DbThread {
119 DbThread {
120 title: format!("🔗 {}", self.title).into(),
121 messages: self.messages,
122 updated_at: self.updated_at,
123 detailed_summary: None,
124 initial_project_snapshot: None,
125 cumulative_token_usage: Default::default(),
126 request_token_usage: Default::default(),
127 model: self.model,
128 profile: None,
129 imported: true,
130 subagent_context: None,
131 speed: None,
132 thinking_enabled: false,
133 thinking_effort: None,
134 draft_prompt: None,
135 provisional_title: None,
136 ui_scroll_position: None,
137 }
138 }
139
140 pub fn to_bytes(&self) -> Result<Vec<u8>> {
141 const COMPRESSION_LEVEL: i32 = 3;
142 let json = serde_json::to_vec(self)?;
143 let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
144 Ok(compressed)
145 }
146
147 pub fn from_bytes(data: &[u8]) -> Result<Self> {
148 let decompressed = zstd::decode_all(data)?;
149 Ok(serde_json::from_slice(&decompressed)?)
150 }
151}
152
impl DbThread {
    /// Current format version. `ThreadsDatabase` writes it into the `version`
    /// field of every thread it saves.
    pub const VERSION: &'static str = "0.3.0";

    /// Parses a thread from its serialized JSON bytes.
    ///
    /// If the top-level `version` field is a string equal to [`Self::VERSION`],
    /// the JSON is deserialized directly. In every other case the bytes are
    /// parsed as a legacy (agent v1) thread and upgraded: a different version
    /// string, a non-string `version`, a missing `version`, or a non-object
    /// document.
    ///
    /// # Errors
    /// Returns an error if `json` is not valid JSON, or if deserializing the
    /// current or the legacy format fails.
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
                    json,
                )?),
            },
            _ => {
                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
            }
        }
    }

    /// Converts a legacy (agent v1) thread into the current format.
    ///
    /// - User messages: text and thinking segments become text content, and
    ///   redacted thinking is dropped. If nothing remains, the message's
    ///   non-empty `context` is used as text. Each gets a freshly generated id.
    /// - Assistant messages: segments map one-to-one, then tool uses are
    ///   appended, and tool results are keyed by tool-use id.
    /// - System messages are dropped.
    /// - Fields the legacy format lacks are left at their defaults, and
    ///   `imported` is `false`.
    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
        let mut messages = Vec::new();
        let mut request_token_usage = HashMap::default();

        // Id of the most recently converted user message; assistant token
        // usage is attributed to it.
        let mut last_user_message_id = None;
        // `ix` is the position in the legacy message list, counting system
        // messages even though they are skipped below.
        for (ix, msg) in thread.messages.into_iter().enumerate() {
            let message = match msg.role {
                language_model::Role::User => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                ..
                            } => {
                                // User messages don't have thinking segments, but handle gracefully
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                ..
                            } => {
                                // User messages don't have redacted thinking, skip.
                            }
                        }
                    }

                    // If no content was added, add context as text if available
                    if content.is_empty() && !msg.context.is_empty() {
                        content.push(UserMessageContent::Text(msg.context));
                    }

                    let id = UserMessageId::new();
                    last_user_message_id = Some(id.clone());

                    crate::Message::User(UserMessage {
                        // MessageId from old format can't be meaningfully converted, so generate a new one
                        id,
                        content,
                    })
                }
                language_model::Role::Assistant => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(AgentMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                signature,
                            } => {
                                content.push(AgentMessageContent::Thinking { text, signature });
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                data,
                            } => {
                                content.push(AgentMessageContent::RedactedThinking(data));
                            }
                        }
                    }

                    // Convert tool uses. Names are remembered per message so the
                    // matching tool results below can be labeled.
                    let mut tool_names_by_id = HashMap::default();
                    for tool_use in msg.tool_uses {
                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
                        content.push(AgentMessageContent::ToolUse(
                            language_model::LanguageModelToolUse {
                                id: tool_use.id,
                                name: tool_use.name.into(),
                                // Re-serialize the parsed input; empty string if that fails.
                                raw_input: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                                input: tool_use.input,
                                is_input_complete: true,
                                thought_signature: None,
                            },
                        ));
                    }

                    // Convert tool results. A result whose tool use is not in this
                    // message is labeled "unknown".
                    let mut tool_results = IndexMap::default();
                    for tool_result in msg.tool_results {
                        let name = tool_names_by_id
                            .remove(&tool_result.tool_use_id)
                            .unwrap_or_else(|| SharedString::from("unknown"));
                        tool_results.insert(
                            tool_result.tool_use_id.clone(),
                            language_model::LanguageModelToolResult {
                                tool_use_id: tool_result.tool_use_id,
                                tool_name: name.into(),
                                is_error: tool_result.is_error,
                                content: tool_result.content,
                                output: tool_result.output,
                            },
                        );
                    }

                    // Attribute the legacy usage entry at this message's index to the
                    // preceding user message. If several assistant messages follow one
                    // user message, the last one's entry wins (insert overwrites).
                    // NOTE(review): assumes legacy `request_token_usage` is indexed by
                    // message position — confirm against the legacy writer.
                    if let Some(last_user_message_id) = &last_user_message_id
                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
                    {
                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
                    }

                    crate::Message::Agent(AgentMessage {
                        content,
                        tool_results,
                        reasoning_details: None,
                    })
                }
                language_model::Role::System => {
                    // Skip system messages as they're not supported in the new format
                    continue;
                }
            };

            messages.push(message);
        }

        Ok(Self {
            title: thread.summary,
            messages,
            updated_at: thread.updated_at,
            // Only a fully generated summary is kept.
            detailed_summary: match thread.detailed_summary_state {
                crate::legacy_thread::DetailedSummaryState::NotGenerated
                | crate::legacy_thread::DetailedSummaryState::Generating => None,
                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
            },
            initial_project_snapshot: thread.initial_project_snapshot,
            cumulative_token_usage: thread.cumulative_token_usage,
            request_token_usage,
            model: thread.model,
            profile: thread.profile,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            provisional_title: None,
            ui_scroll_position: None,
        })
    }
}
320
/// Encoding of a row's `data` column, stored in the `data_type` column as
/// `"json"` (raw JSON bytes) or `"zstd"` (zstd-compressed JSON).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    #[serde(rename = "json")]
    Json,
    #[serde(rename = "zstd")]
    Zstd,
}
328
329impl Bind for DataType {
330 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
331 let value = match self {
332 DataType::Json => "json",
333 DataType::Zstd => "zstd",
334 };
335 value.bind(statement, start_index)
336 }
337}
338
339impl Column for DataType {
340 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
341 let (value, next_index) = String::column(statement, start_index)?;
342 let data_type = match value.as_str() {
343 "json" => DataType::Json,
344 "zstd" => DataType::Zstd,
345 _ => anyhow::bail!("Unknown data type: {}", value),
346 };
347 Ok((data_type, next_index))
348 }
349}
350
/// SQLite-backed store for agent threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every query is spawned.
    executor: BackgroundExecutor,
    /// Single connection shared by all queries and serialized by the mutex.
    connection: Arc<Mutex<Connection>>,
}
355
/// App-global cache of the shared "open database" task, so that
/// `ThreadsDatabase::connect` only starts opening the database once per app.
/// The error is wrapped in `Arc` so the shared result can be cloned.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
359
360impl ThreadsDatabase {
361 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
362 if cx.has_global::<GlobalThreadsDatabase>() {
363 return cx.global::<GlobalThreadsDatabase>().0.clone();
364 }
365 let executor = cx.background_executor().clone();
366 let task = executor
367 .spawn({
368 let executor = executor.clone();
369 async move {
370 match ThreadsDatabase::new(executor) {
371 Ok(db) => Ok(Arc::new(db)),
372 Err(err) => Err(Arc::new(err)),
373 }
374 }
375 })
376 .shared();
377
378 cx.set_global(GlobalThreadsDatabase(task.clone()));
379 task
380 }
381
382 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
383 let connection = if *ZED_STATELESS {
384 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
385 } else if cfg!(any(feature = "test-support", test)) {
386 // rust stores the name of the test on the current thread.
387 // We use this to automatically create a database that will
388 // be shared within the test (for the test_retrieve_old_thread)
389 // but not with concurrent tests.
390 let thread = std::thread::current();
391 let test_name = thread.name();
392 Connection::open_memory(Some(&format!(
393 "THREAD_FALLBACK_{}",
394 test_name.unwrap_or_default()
395 )))
396 } else {
397 let threads_dir = paths::data_dir().join("threads");
398 std::fs::create_dir_all(&threads_dir)?;
399 let sqlite_path = threads_dir.join("threads.db");
400 Connection::open_file(&sqlite_path.to_string_lossy())
401 };
402
403 connection.exec(indoc! {"
404 CREATE TABLE IF NOT EXISTS threads (
405 id TEXT PRIMARY KEY,
406 summary TEXT NOT NULL,
407 updated_at TEXT NOT NULL,
408 data_type TEXT NOT NULL,
409 data BLOB NOT NULL
410 )
411 "})?()
412 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
413
414 if let Ok(mut s) = connection.exec(indoc! {"
415 ALTER TABLE threads ADD COLUMN parent_id TEXT
416 "})
417 {
418 s().ok();
419 }
420
421 if let Ok(mut s) = connection.exec(indoc! {"
422 ALTER TABLE threads ADD COLUMN folder_paths TEXT;
423 ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
424 "})
425 {
426 s().ok();
427 }
428
429 if let Ok(mut s) = connection.exec(indoc! {"
430 ALTER TABLE threads ADD COLUMN created_at TEXT;
431 "})
432 {
433 if s().is_ok() {
434 connection.exec(indoc! {"
435 UPDATE threads SET created_at = updated_at WHERE created_at IS NULL
436 "})?()?;
437 }
438 }
439
440 let db = Self {
441 executor,
442 connection: Arc::new(Mutex::new(connection)),
443 };
444
445 Ok(db)
446 }
447
448 fn save_thread_sync(
449 connection: &Arc<Mutex<Connection>>,
450 id: acp::SessionId,
451 thread: DbThread,
452 folder_paths: &PathList,
453 ) -> Result<()> {
454 const COMPRESSION_LEVEL: i32 = 3;
455
456 #[derive(Serialize)]
457 struct SerializedThread {
458 #[serde(flatten)]
459 thread: DbThread,
460 version: &'static str,
461 }
462
463 let title = thread.title.to_string();
464 let updated_at = thread.updated_at.to_rfc3339();
465 let parent_id = thread
466 .subagent_context
467 .as_ref()
468 .map(|ctx| ctx.parent_thread_id.0.clone());
469 let serialized_folder_paths = folder_paths.serialize();
470 let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
471 if folder_paths.is_empty() {
472 (None, None)
473 } else {
474 (
475 Some(serialized_folder_paths.paths),
476 Some(serialized_folder_paths.order),
477 )
478 };
479 let json_data = serde_json::to_string(&SerializedThread {
480 thread,
481 version: DbThread::VERSION,
482 })?;
483
484 let connection = connection.lock();
485
486 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
487 let data_type = DataType::Zstd;
488 let data = compressed;
489
490 // Use the thread's updated_at as created_at for new threads.
491 // This ensures the creation time reflects when the thread was conceptually
492 // created, not when it was saved to the database.
493 let created_at = updated_at.clone();
494
495 let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>, String)>(indoc! {"
496 INSERT INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data, created_at)
497 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
498 ON CONFLICT(id) DO UPDATE SET
499 parent_id = excluded.parent_id,
500 folder_paths = excluded.folder_paths,
501 folder_paths_order = excluded.folder_paths_order,
502 summary = excluded.summary,
503 updated_at = excluded.updated_at,
504 data_type = excluded.data_type,
505 data = excluded.data
506 "})?;
507
508 insert((
509 id.0,
510 parent_id,
511 folder_paths_str,
512 folder_paths_order_str,
513 title,
514 updated_at,
515 data_type,
516 data,
517 created_at,
518 ))?;
519
520 Ok(())
521 }
522
523 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
524 let connection = self.connection.clone();
525
526 self.executor.spawn(async move {
527 let connection = connection.lock();
528
529 let mut select = connection
530 .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, Option<String>)>(indoc! {"
531 SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at FROM threads ORDER BY updated_at DESC, created_at DESC
532 "})?;
533
534 let rows = select(())?;
535 let mut threads = Vec::new();
536
537 for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at) in rows {
538 let folder_paths = folder_paths
539 .map(|paths| {
540 PathList::deserialize(&util::path_list::SerializedPathList {
541 paths,
542 order: folder_paths_order.unwrap_or_default(),
543 })
544 })
545 .unwrap_or_default();
546 let created_at = created_at
547 .as_deref()
548 .map(DateTime::parse_from_rfc3339)
549 .transpose()?
550 .map(|dt| dt.with_timezone(&Utc));
551
552 threads.push(DbThreadMetadata {
553 id: acp::SessionId::new(id),
554 parent_session_id: parent_id.map(acp::SessionId::new),
555 title: summary.into(),
556 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
557 created_at,
558 folder_paths,
559 });
560 }
561
562 Ok(threads)
563 })
564 }
565
566 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
567 let connection = self.connection.clone();
568
569 self.executor.spawn(async move {
570 let connection = connection.lock();
571 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
572 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
573 "})?;
574
575 let rows = select(id.0)?;
576 if let Some((data_type, data)) = rows.into_iter().next() {
577 let json_data = match data_type {
578 DataType::Zstd => {
579 let decompressed = zstd::decode_all(&data[..])?;
580 String::from_utf8(decompressed)?
581 }
582 DataType::Json => String::from_utf8(data)?,
583 };
584 let thread = DbThread::from_json(json_data.as_bytes())?;
585 Ok(Some(thread))
586 } else {
587 Ok(None)
588 }
589 })
590 }
591
592 pub fn save_thread(
593 &self,
594 id: acp::SessionId,
595 thread: DbThread,
596 folder_paths: PathList,
597 ) -> Task<Result<()>> {
598 let connection = self.connection.clone();
599
600 self.executor
601 .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
602 }
603
604 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
605 let connection = self.connection.clone();
606
607 self.executor.spawn(async move {
608 let connection = connection.lock();
609
610 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
611 DELETE FROM threads WHERE id = ?
612 "})?;
613
614 delete(id.0)?;
615
616 Ok(())
617 })
618 }
619
620 pub fn delete_threads(&self) -> Task<Result<()>> {
621 let connection = self.connection.clone();
622
623 self.executor.spawn(async move {
624 let connection = connection.lock();
625
626 let mut delete = connection.exec_bound::<()>(indoc! {"
627 DELETE FROM threads
628 "})?;
629
630 delete(())?;
631
632 Ok(())
633 })
634 }
635}
636
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    #[test]
    fn test_from_json_parses_current_version_directly() {
        // A payload tagged with the current version must bypass the legacy upgrade path.
        let json = format!(
            r#"{{"version":"{}","title":"Current Thread","messages":[],"updated_at":"2024-01-01T00:00:00Z"}}"#,
            DbThread::VERSION
        );

        let thread = DbThread::from_json(json.as_bytes()).expect("Failed to parse");

        assert_eq!(thread.title.as_ref(), "Current Thread");
        assert!(thread.messages.is_empty());
        assert_eq!(
            thread.updated_at,
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap()
        );
    }

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            provisional_title: None,
            ui_scroll_position: None,
        }
    }

    /// `list_threads` orders primarily by `updated_at`, newest first.
    #[gpui::test]
    async fn test_list_threads_orders_newest_updated_first(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let first_saved_at = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
        let original_thread = make_thread("Thread A", first_saved_at);
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
        // The upsert must not overwrite created_at: it keeps the first save's timestamp.
        assert_eq!(
            entries[0].created_at,
            Some(first_saved_at),
            "created_at should be preserved across re-saves"
        );
    }

    #[gpui::test]
    async fn test_load_missing_thread_returns_none(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let loaded = database
            .load_thread(session_id("does-not-exist"))
            .await
            .unwrap();

        assert!(loaded.is_none());
    }

    #[gpui::test]
    async fn test_delete_thread_and_delete_threads(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let kept_id = session_id("kept-thread");
        let deleted_id = session_id("deleted-thread");
        let timestamp = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();

        database
            .save_thread(
                kept_id.clone(),
                make_thread("Kept", timestamp),
                PathList::default(),
            )
            .await
            .unwrap();
        database
            .save_thread(
                deleted_id.clone(),
                make_thread("Deleted", timestamp),
                PathList::default(),
            )
            .await
            .unwrap();

        database.delete_thread(deleted_id.clone()).await.unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, kept_id);
        assert!(database.load_thread(deleted_id).await.unwrap().is_none());

        database.delete_threads().await.unwrap();
        assert!(database.list_threads().await.unwrap().is_empty());
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[test]
    fn test_draft_prompt_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.draft_prompt.is_none(),
            "Legacy threads without draft_prompt field should default to None"
        );
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let mut child_thread = make_thread(
            "Subagent Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 2,
        });

        database
            .save_thread(child_id.clone(), child_thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let context = loaded
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("regular-thread");
        let thread = make_thread(
            "Regular Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }

    #[gpui::test]
    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("folder-thread");
        let thread = make_thread(
            "Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        let folder_paths = PathList::new(&[
            std::path::PathBuf::from("/home/user/project-a"),
            std::path::PathBuf::from("/home/user/project-b"),
        ]);

        database
            .save_thread(thread_id.clone(), thread, folder_paths.clone())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert_eq!(threads[0].folder_paths, folder_paths);
    }

    #[gpui::test]
    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("no-folder-thread");
        let thread = make_thread(
            "No Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert!(threads[0].folder_paths.is_empty());
    }

    #[test]
    fn test_scroll_position_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.ui_scroll_position.is_none(),
            "Legacy threads without scroll_position field should default to None"
        );
    }

    #[gpui::test]
    async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-with-scroll");

        let mut thread = make_thread(
            "Thread With Scroll",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        thread.ui_scroll_position = Some(SerializedScrollPosition {
            item_ix: 42,
            offset_in_item: 13.5,
        });

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let scroll = loaded
            .ui_scroll_position
            .expect("scroll_position should be restored");
        assert_eq!(scroll.item_ix, 42);
        assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
    }
}