db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use language_model::Speed;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21use util::path_list::PathList;
 22use zed_env_vars::ZED_STATELESS;
 23
/// Message type persisted in the threads database (alias of `crate::Message`).
pub type DbMessage = crate::Message;
/// Detailed-summary state type from the legacy (agent 1) thread format.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Serialized language-model selection type from the legacy thread format.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 27
/// Lightweight per-thread metadata as returned by `ThreadsDatabase::list_threads`.
///
/// Built only from the metadata columns of the `threads` table, so the
/// compressed thread payload is never read or decoded to produce it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session identifier; the primary key of the `threads` table.
    pub id: acp::SessionId,
    /// Session id of the parent thread, set when this thread is a subagent thread.
    pub parent_session_id: Option<acp::SessionId>,
    /// Display title. Deserialization also accepts the older `summary` key.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Last modification time, in UTC.
    pub updated_at: DateTime<Utc>,
    /// The workspace folder paths this thread was created against, sorted
    /// lexicographically. Used for grouping threads by project in the sidebar.
    pub folder_paths: PathList,
}
 39
/// A complete persisted agent thread.
///
/// `ThreadsDatabase::save_thread` stores it as JSON, with an extra top-level
/// `"version"` field set to [`DbThread::VERSION`], compressed with zstd.
/// Fields marked `#[serde(default)]` may be missing from older saved JSON;
/// in that case they take their `Default` value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Display title; also copied into the `summary` column so thread lists
    /// can be built without decoding the payload.
    pub title: SharedString,
    /// Conversation history, in order.
    pub messages: Vec<DbMessage>,
    /// Last modification time (UTC); also stored as RFC 3339 text in the
    /// `updated_at` column, which `list_threads` orders by.
    pub updated_at: DateTime<Utc>,
    /// Detailed summary text, if any. Legacy threads fill it from a
    /// `Generated` summary state.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Per-request token usage keyed by user message id. When upgrading legacy
    /// threads, each assistant message's usage is keyed by the user message
    /// that preceded it.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// `true` for threads created from a [`SharedThread`] via `SharedThread::to_db_thread`.
    #[serde(default)]
    pub imported: bool,
    /// Present for subagent threads. Its `parent_thread_id` is mirrored into
    /// the `parent_id` column when the thread is saved.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
    #[serde(default)]
    pub speed: Option<Speed>,
    #[serde(default)]
    pub thinking_enabled: bool,
    #[serde(default)]
    pub thinking_effort: Option<String>,
}
 68
/// A portable snapshot of a thread used for sharing.
///
/// Carries only the title, messages, timestamp and model. It is encoded as
/// zstd-compressed JSON by `SharedThread::to_bytes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    pub title: SharedString,
    pub messages: Vec<DbMessage>,
    pub updated_at: DateTime<Utc>,
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Format version; `from_db_thread` sets it to `SharedThread::VERSION`.
    pub version: String,
}
 78
 79impl SharedThread {
 80    pub const VERSION: &'static str = "1.0.0";
 81
 82    pub fn from_db_thread(thread: &DbThread) -> Self {
 83        Self {
 84            title: thread.title.clone(),
 85            messages: thread.messages.clone(),
 86            updated_at: thread.updated_at,
 87            model: thread.model.clone(),
 88            version: Self::VERSION.to_string(),
 89        }
 90    }
 91
 92    pub fn to_db_thread(self) -> DbThread {
 93        DbThread {
 94            title: format!("🔗 {}", self.title).into(),
 95            messages: self.messages,
 96            updated_at: self.updated_at,
 97            detailed_summary: None,
 98            initial_project_snapshot: None,
 99            cumulative_token_usage: Default::default(),
100            request_token_usage: Default::default(),
101            model: self.model,
102            profile: None,
103            imported: true,
104            subagent_context: None,
105            speed: None,
106            thinking_enabled: false,
107            thinking_effort: None,
108        }
109    }
110
111    pub fn to_bytes(&self) -> Result<Vec<u8>> {
112        const COMPRESSION_LEVEL: i32 = 3;
113        let json = serde_json::to_vec(self)?;
114        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
115        Ok(compressed)
116    }
117
118    pub fn from_bytes(data: &[u8]) -> Result<Self> {
119        let decompressed = zstd::decode_all(data)?;
120        Ok(serde_json::from_slice(&decompressed)?)
121    }
122}
123
124impl DbThread {
125    pub const VERSION: &'static str = "0.3.0";
126
127    pub fn from_json(json: &[u8]) -> Result<Self> {
128        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
129        match saved_thread_json.get("version") {
130            Some(serde_json::Value::String(version)) => match version.as_str() {
131                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
132                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
133                    json,
134                )?),
135            },
136            _ => {
137                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
138            }
139        }
140    }
141
142    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
143        let mut messages = Vec::new();
144        let mut request_token_usage = HashMap::default();
145
146        let mut last_user_message_id = None;
147        for (ix, msg) in thread.messages.into_iter().enumerate() {
148            let message = match msg.role {
149                language_model::Role::User => {
150                    let mut content = Vec::new();
151
152                    // Convert segments to content
153                    for segment in msg.segments {
154                        match segment {
155                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
156                                content.push(UserMessageContent::Text(text));
157                            }
158                            crate::legacy_thread::SerializedMessageSegment::Thinking {
159                                text,
160                                ..
161                            } => {
162                                // User messages don't have thinking segments, but handle gracefully
163                                content.push(UserMessageContent::Text(text));
164                            }
165                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
166                                ..
167                            } => {
168                                // User messages don't have redacted thinking, skip.
169                            }
170                        }
171                    }
172
173                    // If no content was added, add context as text if available
174                    if content.is_empty() && !msg.context.is_empty() {
175                        content.push(UserMessageContent::Text(msg.context));
176                    }
177
178                    let id = UserMessageId::new();
179                    last_user_message_id = Some(id.clone());
180
181                    crate::Message::User(UserMessage {
182                        // MessageId from old format can't be meaningfully converted, so generate a new one
183                        id,
184                        content,
185                    })
186                }
187                language_model::Role::Assistant => {
188                    let mut content = Vec::new();
189
190                    // Convert segments to content
191                    for segment in msg.segments {
192                        match segment {
193                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
194                                content.push(AgentMessageContent::Text(text));
195                            }
196                            crate::legacy_thread::SerializedMessageSegment::Thinking {
197                                text,
198                                signature,
199                            } => {
200                                content.push(AgentMessageContent::Thinking { text, signature });
201                            }
202                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
203                                data,
204                            } => {
205                                content.push(AgentMessageContent::RedactedThinking(data));
206                            }
207                        }
208                    }
209
210                    // Convert tool uses
211                    let mut tool_names_by_id = HashMap::default();
212                    for tool_use in msg.tool_uses {
213                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
214                        content.push(AgentMessageContent::ToolUse(
215                            language_model::LanguageModelToolUse {
216                                id: tool_use.id,
217                                name: tool_use.name.into(),
218                                raw_input: serde_json::to_string(&tool_use.input)
219                                    .unwrap_or_default(),
220                                input: tool_use.input,
221                                is_input_complete: true,
222                                thought_signature: None,
223                            },
224                        ));
225                    }
226
227                    // Convert tool results
228                    let mut tool_results = IndexMap::default();
229                    for tool_result in msg.tool_results {
230                        let name = tool_names_by_id
231                            .remove(&tool_result.tool_use_id)
232                            .unwrap_or_else(|| SharedString::from("unknown"));
233                        tool_results.insert(
234                            tool_result.tool_use_id.clone(),
235                            language_model::LanguageModelToolResult {
236                                tool_use_id: tool_result.tool_use_id,
237                                tool_name: name.into(),
238                                is_error: tool_result.is_error,
239                                content: tool_result.content,
240                                output: tool_result.output,
241                            },
242                        );
243                    }
244
245                    if let Some(last_user_message_id) = &last_user_message_id
246                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
247                    {
248                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
249                    }
250
251                    crate::Message::Agent(AgentMessage {
252                        content,
253                        tool_results,
254                        reasoning_details: None,
255                    })
256                }
257                language_model::Role::System => {
258                    // Skip system messages as they're not supported in the new format
259                    continue;
260                }
261            };
262
263            messages.push(message);
264        }
265
266        Ok(Self {
267            title: thread.summary,
268            messages,
269            updated_at: thread.updated_at,
270            detailed_summary: match thread.detailed_summary_state {
271                crate::legacy_thread::DetailedSummaryState::NotGenerated
272                | crate::legacy_thread::DetailedSummaryState::Generating => None,
273                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
274            },
275            initial_project_snapshot: thread.initial_project_snapshot,
276            cumulative_token_usage: thread.cumulative_token_usage,
277            request_token_usage,
278            model: thread.model,
279            profile: thread.profile,
280            imported: false,
281            subagent_context: None,
282            speed: None,
283            thinking_enabled: false,
284            thinking_effort: None,
285        })
286    }
287}
288
289#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
290pub enum DataType {
291    #[serde(rename = "json")]
292    Json,
293    #[serde(rename = "zstd")]
294    Zstd,
295}
296
297impl Bind for DataType {
298    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
299        let value = match self {
300            DataType::Json => "json",
301            DataType::Zstd => "zstd",
302        };
303        value.bind(statement, start_index)
304    }
305}
306
307impl Column for DataType {
308    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
309        let (value, next_index) = String::column(statement, start_index)?;
310        let data_type = match value.as_str() {
311            "json" => DataType::Json,
312            "zstd" => DataType::Zstd,
313            _ => anyhow::bail!("Unknown data type: {}", value),
314        };
315        Ok((data_type, next_index))
316    }
317}
318
/// Persistent store for agent threads, backed by a single `sqlez` connection.
///
/// Queries are spawned on `executor`. The connection sits behind a mutex, so
/// concurrent tasks take turns using it.
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    connection: Arc<Mutex<Connection>>,
}

/// App-global slot holding the shared (possibly still pending) result of
/// opening the threads database, so every `ThreadsDatabase::connect` caller
/// receives the same instance.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
327
328impl ThreadsDatabase {
329    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
330        if cx.has_global::<GlobalThreadsDatabase>() {
331            return cx.global::<GlobalThreadsDatabase>().0.clone();
332        }
333        let executor = cx.background_executor().clone();
334        let task = executor
335            .spawn({
336                let executor = executor.clone();
337                async move {
338                    match ThreadsDatabase::new(executor) {
339                        Ok(db) => Ok(Arc::new(db)),
340                        Err(err) => Err(Arc::new(err)),
341                    }
342                }
343            })
344            .shared();
345
346        cx.set_global(GlobalThreadsDatabase(task.clone()));
347        task
348    }
349
350    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
351        let connection = if *ZED_STATELESS {
352            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
353        } else if cfg!(any(feature = "test-support", test)) {
354            // rust stores the name of the test on the current thread.
355            // We use this to automatically create a database that will
356            // be shared within the test (for the test_retrieve_old_thread)
357            // but not with concurrent tests.
358            let thread = std::thread::current();
359            let test_name = thread.name();
360            Connection::open_memory(Some(&format!(
361                "THREAD_FALLBACK_{}",
362                test_name.unwrap_or_default()
363            )))
364        } else {
365            let threads_dir = paths::data_dir().join("threads");
366            std::fs::create_dir_all(&threads_dir)?;
367            let sqlite_path = threads_dir.join("threads.db");
368            Connection::open_file(&sqlite_path.to_string_lossy())
369        };
370
371        connection.exec(indoc! {"
372            CREATE TABLE IF NOT EXISTS threads (
373                id TEXT PRIMARY KEY,
374                summary TEXT NOT NULL,
375                updated_at TEXT NOT NULL,
376                data_type TEXT NOT NULL,
377                data BLOB NOT NULL
378            )
379        "})?()
380        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
381
382        if let Ok(mut s) = connection.exec(indoc! {"
383            ALTER TABLE threads ADD COLUMN parent_id TEXT
384        "})
385        {
386            s().ok();
387        }
388
389        if let Ok(mut s) = connection.exec(indoc! {"
390            ALTER TABLE threads ADD COLUMN folder_paths TEXT;
391            ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
392        "})
393        {
394            s().ok();
395        }
396
397        let db = Self {
398            executor,
399            connection: Arc::new(Mutex::new(connection)),
400        };
401
402        Ok(db)
403    }
404
405    fn save_thread_sync(
406        connection: &Arc<Mutex<Connection>>,
407        id: acp::SessionId,
408        thread: DbThread,
409        folder_paths: &PathList,
410    ) -> Result<()> {
411        const COMPRESSION_LEVEL: i32 = 3;
412
413        #[derive(Serialize)]
414        struct SerializedThread {
415            #[serde(flatten)]
416            thread: DbThread,
417            version: &'static str,
418        }
419
420        let title = thread.title.to_string();
421        let updated_at = thread.updated_at.to_rfc3339();
422        let parent_id = thread
423            .subagent_context
424            .as_ref()
425            .map(|ctx| ctx.parent_thread_id.0.clone());
426        let serialized_folder_paths = folder_paths.serialize();
427        let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
428            if folder_paths.is_empty() {
429                (None, None)
430            } else {
431                (
432                    Some(serialized_folder_paths.paths),
433                    Some(serialized_folder_paths.order),
434                )
435            };
436        let json_data = serde_json::to_string(&SerializedThread {
437            thread,
438            version: DbThread::VERSION,
439        })?;
440
441        let connection = connection.lock();
442
443        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
444        let data_type = DataType::Zstd;
445        let data = compressed;
446
447        let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>)>(indoc! {"
448            INSERT OR REPLACE INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
449        "})?;
450
451        insert((
452            id.0,
453            parent_id,
454            folder_paths_str,
455            folder_paths_order_str,
456            title,
457            updated_at,
458            data_type,
459            data,
460        ))?;
461
462        Ok(())
463    }
464
465    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
466        let connection = self.connection.clone();
467
468        self.executor.spawn(async move {
469            let connection = connection.lock();
470
471            let mut select = connection
472                .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String)>(indoc! {"
473                SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at FROM threads ORDER BY updated_at DESC
474            "})?;
475
476            let rows = select(())?;
477            let mut threads = Vec::new();
478
479            for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at) in rows {
480                let folder_paths = folder_paths
481                    .map(|paths| {
482                        PathList::deserialize(&util::path_list::SerializedPathList {
483                            paths,
484                            order: folder_paths_order.unwrap_or_default(),
485                        })
486                    })
487                    .unwrap_or_default();
488                threads.push(DbThreadMetadata {
489                    id: acp::SessionId::new(id),
490                    parent_session_id: parent_id.map(acp::SessionId::new),
491                    title: summary.into(),
492                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
493                    folder_paths,
494                });
495            }
496
497            Ok(threads)
498        })
499    }
500
501    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
502        let connection = self.connection.clone();
503
504        self.executor.spawn(async move {
505            let connection = connection.lock();
506            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
507                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
508            "})?;
509
510            let rows = select(id.0)?;
511            if let Some((data_type, data)) = rows.into_iter().next() {
512                let json_data = match data_type {
513                    DataType::Zstd => {
514                        let decompressed = zstd::decode_all(&data[..])?;
515                        String::from_utf8(decompressed)?
516                    }
517                    DataType::Json => String::from_utf8(data)?,
518                };
519                let thread = DbThread::from_json(json_data.as_bytes())?;
520                Ok(Some(thread))
521            } else {
522                Ok(None)
523            }
524        })
525    }
526
527    pub fn save_thread(
528        &self,
529        id: acp::SessionId,
530        thread: DbThread,
531        folder_paths: PathList,
532    ) -> Task<Result<()>> {
533        let connection = self.connection.clone();
534
535        self.executor
536            .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
537    }
538
539    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
540        let connection = self.connection.clone();
541
542        self.executor.spawn(async move {
543            let connection = connection.lock();
544
545            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
546                DELETE FROM threads WHERE id = ?
547            "})?;
548
549            delete(id.0)?;
550
551            Ok(())
552        })
553    }
554
555    pub fn delete_threads(&self) -> Task<Result<()>> {
556        let connection = self.connection.clone();
557
558        self.executor.spawn(async move {
559            let connection = connection.lock();
560
561            let mut delete = connection.exec_bound::<()>(indoc! {"
562                DELETE FROM threads
563            "})?;
564
565            delete(())?;
566
567            Ok(())
568        })
569    }
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_shared_thread_to_db_thread_marks_thread_as_imported() {
        let updated_at = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
        let shared = SharedThread {
            title: "Shared Thread".into(),
            messages: vec![],
            updated_at,
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let db_thread = shared.to_db_thread();

        assert!(db_thread.imported, "Imported threads must be flagged");
        assert_eq!(db_thread.title.as_ref(), "🔗 Shared Thread");
        assert_eq!(db_thread.updated_at, updated_at);
        assert!(db_thread.messages.is_empty());
        assert!(db_thread.subagent_context.is_none());
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
        }
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
    }

    #[gpui::test]
    async fn test_load_missing_thread_returns_none(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let loaded = database
            .load_thread(session_id("missing-thread"))
            .await
            .unwrap();

        assert!(loaded.is_none(), "Unknown ids should load as None");
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_that_thread(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let kept_id = session_id("kept-thread");
        let deleted_id = session_id("deleted-thread");

        database
            .save_thread(
                kept_id.clone(),
                make_thread("Kept", Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap()),
                PathList::default(),
            )
            .await
            .unwrap();
        database
            .save_thread(
                deleted_id.clone(),
                make_thread("Deleted", Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()),
                PathList::default(),
            )
            .await
            .unwrap();

        database.delete_thread(deleted_id.clone()).await.unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert_eq!(threads[0].id, kept_id);
        assert!(database.load_thread(deleted_id).await.unwrap().is_none());
    }

    #[gpui::test]
    async fn test_delete_threads_removes_everything(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        database
            .save_thread(
                session_id("first-thread"),
                make_thread("First", Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap()),
                PathList::default(),
            )
            .await
            .unwrap();
        database
            .save_thread(
                session_id("second-thread"),
                make_thread("Second", Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()),
                PathList::default(),
            )
            .await
            .unwrap();

        database.delete_threads().await.unwrap();

        let threads = database.list_threads().await.unwrap();
        assert!(threads.is_empty(), "All threads should be deleted");
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let mut child_thread = make_thread(
            "Subagent Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 2,
        });

        database
            .save_thread(child_id.clone(), child_thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let context = loaded
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("regular-thread");
        let thread = make_thread(
            "Regular Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }

    #[gpui::test]
    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("folder-thread");
        let thread = make_thread(
            "Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        let folder_paths = PathList::new(&[
            std::path::PathBuf::from("/home/user/project-a"),
            std::path::PathBuf::from("/home/user/project-b"),
        ]);

        database
            .save_thread(thread_id.clone(), thread, folder_paths.clone())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert_eq!(threads[0].folder_paths, folder_paths);
    }

    #[gpui::test]
    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("no-folder-thread");
        let thread = make_thread(
            "No Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert!(threads[0].folder_paths.is_empty());
    }
}
823}