1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::AgentProfileId;
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use language_model::Speed;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use sqlez::{
15 bindable::{Bind, Column},
16 connection::Connection,
17 statement::Statement,
18};
19use std::sync::Arc;
20use ui::{App, SharedString};
21use util::path_list::PathList;
22use zed_env_vars::ZED_STATELESS;
23
/// Message type stored in persisted threads (alias of [`crate::Message`]).
pub type DbMessage = crate::Message;
/// Alias of the legacy detailed-summary state type.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Serialized language-model selection stored alongside a thread.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
27
/// Lightweight summary of a stored thread.
///
/// `ThreadsDatabase::list_threads` builds these from the indexed columns of the
/// `threads` table without decoding the (possibly compressed) thread payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id the thread is stored under (the `id` column).
    pub id: acp::SessionId,
    /// Parent session for subagent threads (the `parent_id` column), `None` otherwise.
    pub parent_session_id: Option<acp::SessionId>,
    // `summary` is accepted on deserialization for payloads that still use the
    // older field name.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Last-modified time of the thread.
    pub updated_at: DateTime<Utc>,
    /// The workspace folder paths this thread was created against, sorted
    /// lexicographically. Used for grouping threads by project in the sidebar.
    pub folder_paths: PathList,
}
39
/// Full persisted state of an agent thread.
///
/// Stored as JSON (zstd-compressed when written by `ThreadsDatabase`) in the
/// `data` column. Fields marked `#[serde(default)]` may be missing from
/// payloads written by older versions and fall back to their defaults.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also mirrored into the `summary` column.
    pub title: SharedString,
    /// Conversation messages in order.
    pub messages: Vec<DbMessage>,
    /// Last-modified time; also mirrored into the `updated_at` column.
    pub updated_at: DateTime<Utc>,
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by the user message that started each request.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// `true` for threads created from a [`SharedThread`] import.
    #[serde(default)]
    pub imported: bool,
    /// Present for subagent threads; its `parent_thread_id` is stored in the
    /// `parent_id` column.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
    #[serde(default)]
    pub speed: Option<Speed>,
    #[serde(default)]
    pub thinking_enabled: bool,
    #[serde(default)]
    pub thinking_effort: Option<String>,
    /// Unsent prompt content, if any (presumably the editor draft — confirm with callers).
    #[serde(default)]
    pub draft_prompt: Option<Vec<acp::ContentBlock>>,
    /// Saved scroll position of the thread view, if any.
    #[serde(default)]
    pub ui_scroll_position: Option<SerializedScrollPosition>,
}
72
/// Persisted scroll position of a thread view.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SerializedScrollPosition {
    // NOTE(review): presumably the index of the list item at the top of the
    // viewport — confirm against the view code that writes this.
    pub item_ix: usize,
    // NOTE(review): presumably the pixel offset within `item_ix` — confirm.
    pub offset_in_item: f32,
}
78
/// Portable subset of a [`DbThread`] used when sharing a thread.
///
/// Only the title, messages, timestamp and model are carried over; `version`
/// records the format version of the shared payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    pub title: SharedString,
    pub messages: Vec<DbMessage>,
    pub updated_at: DateTime<Utc>,
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    pub version: String,
}
88
89impl SharedThread {
90 pub const VERSION: &'static str = "1.0.0";
91
92 pub fn from_db_thread(thread: &DbThread) -> Self {
93 Self {
94 title: thread.title.clone(),
95 messages: thread.messages.clone(),
96 updated_at: thread.updated_at,
97 model: thread.model.clone(),
98 version: Self::VERSION.to_string(),
99 }
100 }
101
102 pub fn to_db_thread(self) -> DbThread {
103 DbThread {
104 title: format!("🔗 {}", self.title).into(),
105 messages: self.messages,
106 updated_at: self.updated_at,
107 detailed_summary: None,
108 initial_project_snapshot: None,
109 cumulative_token_usage: Default::default(),
110 request_token_usage: Default::default(),
111 model: self.model,
112 profile: None,
113 imported: true,
114 subagent_context: None,
115 speed: None,
116 thinking_enabled: false,
117 thinking_effort: None,
118 draft_prompt: None,
119 ui_scroll_position: None,
120 }
121 }
122
123 pub fn to_bytes(&self) -> Result<Vec<u8>> {
124 const COMPRESSION_LEVEL: i32 = 3;
125 let json = serde_json::to_vec(self)?;
126 let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
127 Ok(compressed)
128 }
129
130 pub fn from_bytes(data: &[u8]) -> Result<Self> {
131 let decompressed = zstd::decode_all(data)?;
132 Ok(serde_json::from_slice(&decompressed)?)
133 }
134}
135
136impl DbThread {
137 pub const VERSION: &'static str = "0.3.0";
138
139 pub fn from_json(json: &[u8]) -> Result<Self> {
140 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
141 match saved_thread_json.get("version") {
142 Some(serde_json::Value::String(version)) => match version.as_str() {
143 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
144 _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
145 json,
146 )?),
147 },
148 _ => {
149 Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
150 }
151 }
152 }
153
154 fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
155 let mut messages = Vec::new();
156 let mut request_token_usage = HashMap::default();
157
158 let mut last_user_message_id = None;
159 for (ix, msg) in thread.messages.into_iter().enumerate() {
160 let message = match msg.role {
161 language_model::Role::User => {
162 let mut content = Vec::new();
163
164 // Convert segments to content
165 for segment in msg.segments {
166 match segment {
167 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
168 content.push(UserMessageContent::Text(text));
169 }
170 crate::legacy_thread::SerializedMessageSegment::Thinking {
171 text,
172 ..
173 } => {
174 // User messages don't have thinking segments, but handle gracefully
175 content.push(UserMessageContent::Text(text));
176 }
177 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
178 ..
179 } => {
180 // User messages don't have redacted thinking, skip.
181 }
182 }
183 }
184
185 // If no content was added, add context as text if available
186 if content.is_empty() && !msg.context.is_empty() {
187 content.push(UserMessageContent::Text(msg.context));
188 }
189
190 let id = UserMessageId::new();
191 last_user_message_id = Some(id.clone());
192
193 crate::Message::User(UserMessage {
194 // MessageId from old format can't be meaningfully converted, so generate a new one
195 id,
196 content,
197 })
198 }
199 language_model::Role::Assistant => {
200 let mut content = Vec::new();
201
202 // Convert segments to content
203 for segment in msg.segments {
204 match segment {
205 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
206 content.push(AgentMessageContent::Text(text));
207 }
208 crate::legacy_thread::SerializedMessageSegment::Thinking {
209 text,
210 signature,
211 } => {
212 content.push(AgentMessageContent::Thinking { text, signature });
213 }
214 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
215 data,
216 } => {
217 content.push(AgentMessageContent::RedactedThinking(data));
218 }
219 }
220 }
221
222 // Convert tool uses
223 let mut tool_names_by_id = HashMap::default();
224 for tool_use in msg.tool_uses {
225 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
226 content.push(AgentMessageContent::ToolUse(
227 language_model::LanguageModelToolUse {
228 id: tool_use.id,
229 name: tool_use.name.into(),
230 raw_input: serde_json::to_string(&tool_use.input)
231 .unwrap_or_default(),
232 input: tool_use.input,
233 is_input_complete: true,
234 thought_signature: None,
235 },
236 ));
237 }
238
239 // Convert tool results
240 let mut tool_results = IndexMap::default();
241 for tool_result in msg.tool_results {
242 let name = tool_names_by_id
243 .remove(&tool_result.tool_use_id)
244 .unwrap_or_else(|| SharedString::from("unknown"));
245 tool_results.insert(
246 tool_result.tool_use_id.clone(),
247 language_model::LanguageModelToolResult {
248 tool_use_id: tool_result.tool_use_id,
249 tool_name: name.into(),
250 is_error: tool_result.is_error,
251 content: tool_result.content,
252 output: tool_result.output,
253 },
254 );
255 }
256
257 if let Some(last_user_message_id) = &last_user_message_id
258 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
259 {
260 request_token_usage.insert(last_user_message_id.clone(), token_usage);
261 }
262
263 crate::Message::Agent(AgentMessage {
264 content,
265 tool_results,
266 reasoning_details: None,
267 })
268 }
269 language_model::Role::System => {
270 // Skip system messages as they're not supported in the new format
271 continue;
272 }
273 };
274
275 messages.push(message);
276 }
277
278 Ok(Self {
279 title: thread.summary,
280 messages,
281 updated_at: thread.updated_at,
282 detailed_summary: match thread.detailed_summary_state {
283 crate::legacy_thread::DetailedSummaryState::NotGenerated
284 | crate::legacy_thread::DetailedSummaryState::Generating => None,
285 crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
286 },
287 initial_project_snapshot: thread.initial_project_snapshot,
288 cumulative_token_usage: thread.cumulative_token_usage,
289 request_token_usage,
290 model: thread.model,
291 profile: thread.profile,
292 imported: false,
293 subagent_context: None,
294 speed: None,
295 thinking_enabled: false,
296 thinking_effort: None,
297 draft_prompt: None,
298 ui_scroll_position: None,
299 })
300 }
301}
302
/// Encoding of the `data` column in the `threads` table.
///
/// Stored in SQLite as the lowercase tags `"json"` and `"zstd"`; the serde
/// representation uses the same tags.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    /// Uncompressed UTF-8 JSON.
    #[serde(rename = "json")]
    Json,
    /// zstd-compressed JSON (the format written when saving threads).
    #[serde(rename = "zstd")]
    Zstd,
}
310
311impl Bind for DataType {
312 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
313 let value = match self {
314 DataType::Json => "json",
315 DataType::Zstd => "zstd",
316 };
317 value.bind(statement, start_index)
318 }
319}
320
321impl Column for DataType {
322 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
323 let (value, next_index) = String::column(statement, start_index)?;
324 let data_type = match value.as_str() {
325 "json" => DataType::Json,
326 "zstd" => DataType::Zstd,
327 _ => anyhow::bail!("Unknown data type: {}", value),
328 };
329 Ok((data_type, next_index))
330 }
331}
332
/// SQLite-backed store for agent threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every query is spawned.
    executor: BackgroundExecutor,
    /// Single connection shared by all queries, serialized by a mutex.
    connection: Arc<Mutex<Connection>>,
}
337
/// GPUI global holding the shared task that opens the [`ThreadsDatabase`], so
/// that repeated `ThreadsDatabase::connect` calls reuse one connection attempt.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
341
342impl ThreadsDatabase {
343 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
344 if cx.has_global::<GlobalThreadsDatabase>() {
345 return cx.global::<GlobalThreadsDatabase>().0.clone();
346 }
347 let executor = cx.background_executor().clone();
348 let task = executor
349 .spawn({
350 let executor = executor.clone();
351 async move {
352 match ThreadsDatabase::new(executor) {
353 Ok(db) => Ok(Arc::new(db)),
354 Err(err) => Err(Arc::new(err)),
355 }
356 }
357 })
358 .shared();
359
360 cx.set_global(GlobalThreadsDatabase(task.clone()));
361 task
362 }
363
364 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
365 let connection = if *ZED_STATELESS {
366 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
367 } else if cfg!(any(feature = "test-support", test)) {
368 // rust stores the name of the test on the current thread.
369 // We use this to automatically create a database that will
370 // be shared within the test (for the test_retrieve_old_thread)
371 // but not with concurrent tests.
372 let thread = std::thread::current();
373 let test_name = thread.name();
374 Connection::open_memory(Some(&format!(
375 "THREAD_FALLBACK_{}",
376 test_name.unwrap_or_default()
377 )))
378 } else {
379 let threads_dir = paths::data_dir().join("threads");
380 std::fs::create_dir_all(&threads_dir)?;
381 let sqlite_path = threads_dir.join("threads.db");
382 Connection::open_file(&sqlite_path.to_string_lossy())
383 };
384
385 connection.exec("PRAGMA journal_mode=WAL;")?()?;
386 connection.exec("PRAGMA busy_timeout=1000;")?()?;
387
388 connection.exec(indoc! {"
389 CREATE TABLE IF NOT EXISTS threads (
390 id TEXT PRIMARY KEY,
391 summary TEXT NOT NULL,
392 updated_at TEXT NOT NULL,
393 data_type TEXT NOT NULL,
394 data BLOB NOT NULL
395 )
396 "})?()
397 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
398
399 if let Ok(mut s) = connection.exec(indoc! {"
400 ALTER TABLE threads ADD COLUMN parent_id TEXT
401 "})
402 {
403 s().ok();
404 }
405
406 if let Ok(mut s) = connection.exec(indoc! {"
407 ALTER TABLE threads ADD COLUMN folder_paths TEXT;
408 ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
409 "})
410 {
411 s().ok();
412 }
413
414 let db = Self {
415 executor,
416 connection: Arc::new(Mutex::new(connection)),
417 };
418
419 Ok(db)
420 }
421
422 fn save_thread_sync(
423 connection: &Arc<Mutex<Connection>>,
424 id: acp::SessionId,
425 thread: DbThread,
426 folder_paths: &PathList,
427 ) -> Result<()> {
428 const COMPRESSION_LEVEL: i32 = 3;
429
430 #[derive(Serialize)]
431 struct SerializedThread {
432 #[serde(flatten)]
433 thread: DbThread,
434 version: &'static str,
435 }
436
437 let title = thread.title.to_string();
438 let updated_at = thread.updated_at.to_rfc3339();
439 let parent_id = thread
440 .subagent_context
441 .as_ref()
442 .map(|ctx| ctx.parent_thread_id.0.clone());
443 let serialized_folder_paths = folder_paths.serialize();
444 let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
445 if folder_paths.is_empty() {
446 (None, None)
447 } else {
448 (
449 Some(serialized_folder_paths.paths),
450 Some(serialized_folder_paths.order),
451 )
452 };
453 let json_data = serde_json::to_string(&SerializedThread {
454 thread,
455 version: DbThread::VERSION,
456 })?;
457
458 let connection = connection.lock();
459
460 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
461 let data_type = DataType::Zstd;
462 let data = compressed;
463
464 let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>)>(indoc! {"
465 INSERT OR REPLACE INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
466 "})?;
467
468 insert((
469 id.0,
470 parent_id,
471 folder_paths_str,
472 folder_paths_order_str,
473 title,
474 updated_at,
475 data_type,
476 data,
477 ))?;
478
479 Ok(())
480 }
481
482 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
483 let connection = self.connection.clone();
484
485 self.executor.spawn(async move {
486 let connection = connection.lock();
487
488 let mut select = connection
489 .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String)>(indoc! {"
490 SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at FROM threads ORDER BY updated_at DESC
491 "})?;
492
493 let rows = select(())?;
494 let mut threads = Vec::new();
495
496 for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at) in rows {
497 let folder_paths = folder_paths
498 .map(|paths| {
499 PathList::deserialize(&util::path_list::SerializedPathList {
500 paths,
501 order: folder_paths_order.unwrap_or_default(),
502 })
503 })
504 .unwrap_or_default();
505 threads.push(DbThreadMetadata {
506 id: acp::SessionId::new(id),
507 parent_session_id: parent_id.map(acp::SessionId::new),
508 title: summary.into(),
509 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
510 folder_paths,
511 });
512 }
513
514 Ok(threads)
515 })
516 }
517
518 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
519 let connection = self.connection.clone();
520
521 self.executor.spawn(async move {
522 let connection = connection.lock();
523 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
524 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
525 "})?;
526
527 let rows = select(id.0)?;
528 if let Some((data_type, data)) = rows.into_iter().next() {
529 let json_data = match data_type {
530 DataType::Zstd => {
531 let decompressed = zstd::decode_all(&data[..])?;
532 String::from_utf8(decompressed)?
533 }
534 DataType::Json => String::from_utf8(data)?,
535 };
536 let thread = DbThread::from_json(json_data.as_bytes())?;
537 Ok(Some(thread))
538 } else {
539 Ok(None)
540 }
541 })
542 }
543
544 pub fn save_thread(
545 &self,
546 id: acp::SessionId,
547 thread: DbThread,
548 folder_paths: PathList,
549 ) -> Task<Result<()>> {
550 let connection = self.connection.clone();
551
552 self.executor
553 .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
554 }
555
556 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
557 let connection = self.connection.clone();
558
559 self.executor.spawn(async move {
560 let connection = connection.lock();
561
562 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
563 DELETE FROM threads WHERE id = ?
564 "})?;
565
566 delete(id.0)?;
567
568 Ok(())
569 })
570 }
571
572 pub fn delete_threads(&self) -> Task<Result<()>> {
573 let connection = self.connection.clone();
574
575 self.executor.spawn(async move {
576 let connection = connection.lock();
577
578 let mut delete = connection.exec_bound::<()>(indoc! {"
579 DELETE FROM threads
580 "})?;
581
582 delete(())?;
583
584 Ok(())
585 })
586 }
587}
588
// Unit tests for thread serialization and `ThreadsDatabase`. In test builds
// each test gets its own in-memory database, named after the test thread.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    // `SharedThread` must survive a JSON + zstd round trip unchanged.
    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    /// Builds a session id from a string literal.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Builds an otherwise-empty thread with the given title and timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    // `list_threads` returns the most recently updated thread first.
    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    // Saving twice under the same id replaces the row, including its metadata columns.
    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
    }

    // Payloads without `subagent_context` deserialize with it set to `None`.
    #[test]
    fn test_subagent_context_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    // Payloads without `draft_prompt` deserialize with it set to `None`.
    #[test]
    fn test_draft_prompt_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.draft_prompt.is_none(),
            "Legacy threads without draft_prompt field should default to None"
        );
    }

    // A subagent thread's parent id and depth survive save + load.
    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let mut child_thread = make_thread(
            "Subagent Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 2,
        });

        database
            .save_thread(child_id.clone(), child_thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let context = loaded
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    // A thread saved without subagent context loads without one.
    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("regular-thread");
        let thread = make_thread(
            "Regular Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }

    // Non-empty folder paths saved with a thread come back from `list_threads`.
    #[gpui::test]
    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("folder-thread");
        let thread = make_thread(
            "Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        let folder_paths = PathList::new(&[
            std::path::PathBuf::from("/home/user/project-a"),
            std::path::PathBuf::from("/home/user/project-b"),
        ]);

        database
            .save_thread(thread_id.clone(), thread, folder_paths.clone())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert_eq!(threads[0].folder_paths, folder_paths);
    }

    // Saving with empty folder paths lists back as an empty `PathList`.
    #[gpui::test]
    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("no-folder-thread");
        let thread = make_thread(
            "No Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert!(threads[0].folder_paths.is_empty());
    }

    // Payloads without `ui_scroll_position` deserialize with it set to `None`.
    #[test]
    fn test_scroll_position_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.ui_scroll_position.is_none(),
            "Legacy threads without scroll_position field should default to None"
        );
    }

    // A saved scroll position survives save + load.
    #[gpui::test]
    async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-with-scroll");

        let mut thread = make_thread(
            "Thread With Scroll",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        thread.ui_scroll_position = Some(SerializedScrollPosition {
            item_ix: 42,
            offset_in_item: 13.5,
        });

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let scroll = loaded
            .ui_scroll_position
            .expect("scroll_position should be restored");
        assert_eq!(scroll.item_ix, 42);
        assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
    }
}