db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use language_model::Speed;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21use util::path_list::PathList;
 22use zed_env_vars::ZED_STATELESS;
 23
 24pub type DbMessage = crate::Message;
 25pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
 26pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 27
/// Per-thread metadata read from the `threads` table's metadata columns, without
/// decoding the thread payload itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id of the thread (the row's `id` column).
    pub id: acp::SessionId,
    /// Session id of the parent thread, recorded for subagent threads.
    pub parent_session_id: Option<acp::SessionId>,
    /// Thread title. `summary` is accepted as an alternative key when deserializing.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Time of the thread's last update.
    pub updated_at: DateTime<Utc>,
    /// The workspace folder paths this thread was created against, sorted
    /// lexicographically. Used for grouping threads by project in the sidebar.
    pub folder_paths: PathList,
}
 39
/// A complete agent thread as persisted in the `threads` table.
///
/// Fields marked `#[serde(default)]` may be absent from documents written before they
/// were introduced; they then load with their default values.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also written to the table's `summary` column.
    pub title: SharedString,
    /// The thread's messages.
    pub messages: Vec<DbMessage>,
    /// Time of the last update; also written to the `updated_at` column as RFC 3339 text.
    pub updated_at: DateTime<Utc>,
    /// Text of a generated detailed summary, if one exists.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Project snapshot associated with the thread, if one was recorded.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage accumulated across the thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage recorded per user message.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model selected for the thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Agent profile selected for the thread, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// `true` when the thread was created from a shared thread rather than locally.
    #[serde(default)]
    pub imported: bool,
    /// Present when this thread is a subagent of another thread; the parent's id is also
    /// written to the `parent_id` column.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
    /// Model speed setting for the thread, if any.
    #[serde(default)]
    pub speed: Option<Speed>,
    /// Whether thinking is enabled for the thread.
    #[serde(default)]
    pub thinking_enabled: bool,
    /// Thinking effort setting, if any.
    #[serde(default)]
    pub thinking_effort: Option<String>,
    /// Draft prompt content saved with the thread, if any.
    #[serde(default)]
    pub draft_prompt: Option<Vec<acp::ContentBlock>>,
}
 70
/// A portable snapshot of a thread for sharing: just the title, messages, update time,
/// and model, tagged with a format version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    /// Title of the original thread.
    pub title: SharedString,
    /// Messages of the original thread.
    pub messages: Vec<DbMessage>,
    /// Last update time of the original thread.
    pub updated_at: DateTime<Utc>,
    /// Model selected in the original thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Format version of the snapshot (`SharedThread::VERSION` when created locally).
    pub version: String,
}
 80
 81impl SharedThread {
 82    pub const VERSION: &'static str = "1.0.0";
 83
 84    pub fn from_db_thread(thread: &DbThread) -> Self {
 85        Self {
 86            title: thread.title.clone(),
 87            messages: thread.messages.clone(),
 88            updated_at: thread.updated_at,
 89            model: thread.model.clone(),
 90            version: Self::VERSION.to_string(),
 91        }
 92    }
 93
 94    pub fn to_db_thread(self) -> DbThread {
 95        DbThread {
 96            title: format!("🔗 {}", self.title).into(),
 97            messages: self.messages,
 98            updated_at: self.updated_at,
 99            detailed_summary: None,
100            initial_project_snapshot: None,
101            cumulative_token_usage: Default::default(),
102            request_token_usage: Default::default(),
103            model: self.model,
104            profile: None,
105            imported: true,
106            subagent_context: None,
107            speed: None,
108            thinking_enabled: false,
109            thinking_effort: None,
110            draft_prompt: None,
111        }
112    }
113
114    pub fn to_bytes(&self) -> Result<Vec<u8>> {
115        const COMPRESSION_LEVEL: i32 = 3;
116        let json = serde_json::to_vec(self)?;
117        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
118        Ok(compressed)
119    }
120
121    pub fn from_bytes(data: &[u8]) -> Result<Self> {
122        let decompressed = zstd::decode_all(data)?;
123        Ok(serde_json::from_slice(&decompressed)?)
124    }
125}
126
127impl DbThread {
128    pub const VERSION: &'static str = "0.3.0";
129
130    pub fn from_json(json: &[u8]) -> Result<Self> {
131        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
132        match saved_thread_json.get("version") {
133            Some(serde_json::Value::String(version)) => match version.as_str() {
134                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
135                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
136                    json,
137                )?),
138            },
139            _ => {
140                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
141            }
142        }
143    }
144
145    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
146        let mut messages = Vec::new();
147        let mut request_token_usage = HashMap::default();
148
149        let mut last_user_message_id = None;
150        for (ix, msg) in thread.messages.into_iter().enumerate() {
151            let message = match msg.role {
152                language_model::Role::User => {
153                    let mut content = Vec::new();
154
155                    // Convert segments to content
156                    for segment in msg.segments {
157                        match segment {
158                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
159                                content.push(UserMessageContent::Text(text));
160                            }
161                            crate::legacy_thread::SerializedMessageSegment::Thinking {
162                                text,
163                                ..
164                            } => {
165                                // User messages don't have thinking segments, but handle gracefully
166                                content.push(UserMessageContent::Text(text));
167                            }
168                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
169                                ..
170                            } => {
171                                // User messages don't have redacted thinking, skip.
172                            }
173                        }
174                    }
175
176                    // If no content was added, add context as text if available
177                    if content.is_empty() && !msg.context.is_empty() {
178                        content.push(UserMessageContent::Text(msg.context));
179                    }
180
181                    let id = UserMessageId::new();
182                    last_user_message_id = Some(id.clone());
183
184                    crate::Message::User(UserMessage {
185                        // MessageId from old format can't be meaningfully converted, so generate a new one
186                        id,
187                        content,
188                    })
189                }
190                language_model::Role::Assistant => {
191                    let mut content = Vec::new();
192
193                    // Convert segments to content
194                    for segment in msg.segments {
195                        match segment {
196                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
197                                content.push(AgentMessageContent::Text(text));
198                            }
199                            crate::legacy_thread::SerializedMessageSegment::Thinking {
200                                text,
201                                signature,
202                            } => {
203                                content.push(AgentMessageContent::Thinking { text, signature });
204                            }
205                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
206                                data,
207                            } => {
208                                content.push(AgentMessageContent::RedactedThinking(data));
209                            }
210                        }
211                    }
212
213                    // Convert tool uses
214                    let mut tool_names_by_id = HashMap::default();
215                    for tool_use in msg.tool_uses {
216                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
217                        content.push(AgentMessageContent::ToolUse(
218                            language_model::LanguageModelToolUse {
219                                id: tool_use.id,
220                                name: tool_use.name.into(),
221                                raw_input: serde_json::to_string(&tool_use.input)
222                                    .unwrap_or_default(),
223                                input: tool_use.input,
224                                is_input_complete: true,
225                                thought_signature: None,
226                            },
227                        ));
228                    }
229
230                    // Convert tool results
231                    let mut tool_results = IndexMap::default();
232                    for tool_result in msg.tool_results {
233                        let name = tool_names_by_id
234                            .remove(&tool_result.tool_use_id)
235                            .unwrap_or_else(|| SharedString::from("unknown"));
236                        tool_results.insert(
237                            tool_result.tool_use_id.clone(),
238                            language_model::LanguageModelToolResult {
239                                tool_use_id: tool_result.tool_use_id,
240                                tool_name: name.into(),
241                                is_error: tool_result.is_error,
242                                content: tool_result.content,
243                                output: tool_result.output,
244                            },
245                        );
246                    }
247
248                    if let Some(last_user_message_id) = &last_user_message_id
249                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
250                    {
251                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
252                    }
253
254                    crate::Message::Agent(AgentMessage {
255                        content,
256                        tool_results,
257                        reasoning_details: None,
258                    })
259                }
260                language_model::Role::System => {
261                    // Skip system messages as they're not supported in the new format
262                    continue;
263                }
264            };
265
266            messages.push(message);
267        }
268
269        Ok(Self {
270            title: thread.summary,
271            messages,
272            updated_at: thread.updated_at,
273            detailed_summary: match thread.detailed_summary_state {
274                crate::legacy_thread::DetailedSummaryState::NotGenerated
275                | crate::legacy_thread::DetailedSummaryState::Generating => None,
276                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
277            },
278            initial_project_snapshot: thread.initial_project_snapshot,
279            cumulative_token_usage: thread.cumulative_token_usage,
280            request_token_usage,
281            model: thread.model,
282            profile: thread.profile,
283            imported: false,
284            subagent_context: None,
285            speed: None,
286            thinking_enabled: false,
287            thinking_effort: None,
288            draft_prompt: None,
289        })
290    }
291}
292
293#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
294pub enum DataType {
295    #[serde(rename = "json")]
296    Json,
297    #[serde(rename = "zstd")]
298    Zstd,
299}
300
301impl Bind for DataType {
302    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
303        let value = match self {
304            DataType::Json => "json",
305            DataType::Zstd => "zstd",
306        };
307        value.bind(statement, start_index)
308    }
309}
310
311impl Column for DataType {
312    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
313        let (value, next_index) = String::column(statement, start_index)?;
314        let data_type = match value.as_str() {
315            "json" => DataType::Json,
316            "zstd" => DataType::Zstd,
317            _ => anyhow::bail!("Unknown data type: {}", value),
318        };
319        Ok((data_type, next_index))
320    }
321}
322
/// SQLite-backed storage for agent threads.
///
/// Every query is spawned on `executor`; the single connection sits behind a mutex so
/// that only one task uses it at a time.
pub(crate) struct ThreadsDatabase {
    /// Executor that database work is spawned on.
    executor: BackgroundExecutor,
    /// The shared SQLite connection.
    connection: Arc<Mutex<Connection>>,
}
327
/// GPUI global caching the shared task that opens the threads database, so the
/// database is opened once per app.
///
/// The error is wrapped in `Arc` because `Shared` requires a cloneable output to hand
/// the same result to every awaiter.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
331
332impl ThreadsDatabase {
333    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
334        if cx.has_global::<GlobalThreadsDatabase>() {
335            return cx.global::<GlobalThreadsDatabase>().0.clone();
336        }
337        let executor = cx.background_executor().clone();
338        let task = executor
339            .spawn({
340                let executor = executor.clone();
341                async move {
342                    match ThreadsDatabase::new(executor) {
343                        Ok(db) => Ok(Arc::new(db)),
344                        Err(err) => Err(Arc::new(err)),
345                    }
346                }
347            })
348            .shared();
349
350        cx.set_global(GlobalThreadsDatabase(task.clone()));
351        task
352    }
353
354    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
355        let connection = if *ZED_STATELESS {
356            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
357        } else if cfg!(any(feature = "test-support", test)) {
358            // rust stores the name of the test on the current thread.
359            // We use this to automatically create a database that will
360            // be shared within the test (for the test_retrieve_old_thread)
361            // but not with concurrent tests.
362            let thread = std::thread::current();
363            let test_name = thread.name();
364            Connection::open_memory(Some(&format!(
365                "THREAD_FALLBACK_{}",
366                test_name.unwrap_or_default()
367            )))
368        } else {
369            let threads_dir = paths::data_dir().join("threads");
370            std::fs::create_dir_all(&threads_dir)?;
371            let sqlite_path = threads_dir.join("threads.db");
372            Connection::open_file(&sqlite_path.to_string_lossy())
373        };
374
375        connection.exec(indoc! {"
376            CREATE TABLE IF NOT EXISTS threads (
377                id TEXT PRIMARY KEY,
378                summary TEXT NOT NULL,
379                updated_at TEXT NOT NULL,
380                data_type TEXT NOT NULL,
381                data BLOB NOT NULL
382            )
383        "})?()
384        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
385
386        if let Ok(mut s) = connection.exec(indoc! {"
387            ALTER TABLE threads ADD COLUMN parent_id TEXT
388        "})
389        {
390            s().ok();
391        }
392
393        if let Ok(mut s) = connection.exec(indoc! {"
394            ALTER TABLE threads ADD COLUMN folder_paths TEXT;
395            ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
396        "})
397        {
398            s().ok();
399        }
400
401        let db = Self {
402            executor,
403            connection: Arc::new(Mutex::new(connection)),
404        };
405
406        Ok(db)
407    }
408
409    fn save_thread_sync(
410        connection: &Arc<Mutex<Connection>>,
411        id: acp::SessionId,
412        thread: DbThread,
413        folder_paths: &PathList,
414    ) -> Result<()> {
415        const COMPRESSION_LEVEL: i32 = 3;
416
417        #[derive(Serialize)]
418        struct SerializedThread {
419            #[serde(flatten)]
420            thread: DbThread,
421            version: &'static str,
422        }
423
424        let title = thread.title.to_string();
425        let updated_at = thread.updated_at.to_rfc3339();
426        let parent_id = thread
427            .subagent_context
428            .as_ref()
429            .map(|ctx| ctx.parent_thread_id.0.clone());
430        let serialized_folder_paths = folder_paths.serialize();
431        let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
432            if folder_paths.is_empty() {
433                (None, None)
434            } else {
435                (
436                    Some(serialized_folder_paths.paths),
437                    Some(serialized_folder_paths.order),
438                )
439            };
440        let json_data = serde_json::to_string(&SerializedThread {
441            thread,
442            version: DbThread::VERSION,
443        })?;
444
445        let connection = connection.lock();
446
447        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
448        let data_type = DataType::Zstd;
449        let data = compressed;
450
451        let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>)>(indoc! {"
452            INSERT OR REPLACE INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
453        "})?;
454
455        insert((
456            id.0,
457            parent_id,
458            folder_paths_str,
459            folder_paths_order_str,
460            title,
461            updated_at,
462            data_type,
463            data,
464        ))?;
465
466        Ok(())
467    }
468
469    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
470        let connection = self.connection.clone();
471
472        self.executor.spawn(async move {
473            let connection = connection.lock();
474
475            let mut select = connection
476                .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String)>(indoc! {"
477                SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at FROM threads ORDER BY updated_at DESC
478            "})?;
479
480            let rows = select(())?;
481            let mut threads = Vec::new();
482
483            for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at) in rows {
484                let folder_paths = folder_paths
485                    .map(|paths| {
486                        PathList::deserialize(&util::path_list::SerializedPathList {
487                            paths,
488                            order: folder_paths_order.unwrap_or_default(),
489                        })
490                    })
491                    .unwrap_or_default();
492                threads.push(DbThreadMetadata {
493                    id: acp::SessionId::new(id),
494                    parent_session_id: parent_id.map(acp::SessionId::new),
495                    title: summary.into(),
496                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
497                    folder_paths,
498                });
499            }
500
501            Ok(threads)
502        })
503    }
504
505    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
506        let connection = self.connection.clone();
507
508        self.executor.spawn(async move {
509            let connection = connection.lock();
510            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
511                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
512            "})?;
513
514            let rows = select(id.0)?;
515            if let Some((data_type, data)) = rows.into_iter().next() {
516                let json_data = match data_type {
517                    DataType::Zstd => {
518                        let decompressed = zstd::decode_all(&data[..])?;
519                        String::from_utf8(decompressed)?
520                    }
521                    DataType::Json => String::from_utf8(data)?,
522                };
523                let thread = DbThread::from_json(json_data.as_bytes())?;
524                Ok(Some(thread))
525            } else {
526                Ok(None)
527            }
528        })
529    }
530
531    pub fn save_thread(
532        &self,
533        id: acp::SessionId,
534        thread: DbThread,
535        folder_paths: PathList,
536    ) -> Task<Result<()>> {
537        let connection = self.connection.clone();
538
539        self.executor
540            .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
541    }
542
543    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
544        let connection = self.connection.clone();
545
546        self.executor.spawn(async move {
547            let connection = connection.lock();
548
549            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
550                DELETE FROM threads WHERE id = ?
551            "})?;
552
553            delete(id.0)?;
554
555            Ok(())
556        })
557    }
558
559    pub fn delete_threads(&self) -> Task<Result<()>> {
560        let connection = self.connection.clone();
561
562        self.executor.spawn(async move {
563            let connection = connection.lock();
564
565            let mut delete = connection.exec_bound::<()>(indoc! {"
566                DELETE FROM threads
567            "})?;
568
569            delete(())?;
570
571            Ok(())
572        })
573    }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// A thread document with only the required fields, simulating data saved before
    /// the optional fields existed (backwards compatibility).
    const LEGACY_THREAD_JSON: &str = r#"{
        "title": "Old Thread",
        "messages": [],
        "updated_at": "2024-01-01T00:00:00Z"
    }"#;

    /// UTC timestamp at `hour`:00:00 on the given date.
    fn utc(year: i32, month: u32, day: u32, hour: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(year, month, day, hour, 0, 0).unwrap()
    }

    fn parse_legacy_thread() -> DbThread {
        serde_json::from_str(LEGACY_THREAD_JSON).expect("Failed to deserialize")
    }

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
        }
    }

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: Vec::new(),
            updated_at: utc(2024, 1, 1, 0),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        assert!(
            !parse_legacy_thread().imported,
            "Legacy threads without imported field should default to false"
        );
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let saves = [
            (older_id.clone(), make_thread("Thread A", utc(2024, 1, 1, 0))),
            (newer_id.clone(), make_thread("Thread B", utc(2024, 1, 2, 0))),
        ];
        for (id, thread) in saves {
            database
                .save_thread(id, thread, PathList::default())
                .await
                .unwrap();
        }

        let listed_ids: Vec<acp::SessionId> = database
            .list_threads()
            .await
            .unwrap()
            .into_iter()
            .map(|entry| entry.id)
            .collect();
        assert_eq!(listed_ids, [newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();
        let thread_id = session_id("thread-a");

        for thread in [
            make_thread("Thread A", utc(2024, 1, 1, 0)),
            make_thread("Thread B", utc(2024, 1, 2, 0)),
        ] {
            database
                .save_thread(thread_id.clone(), thread, PathList::default())
                .await
                .unwrap();
        }

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        let entry = &entries[0];
        assert_eq!(entry.id, thread_id);
        assert_eq!(entry.title.as_ref(), "Thread B");
        assert_eq!(entry.updated_at, utc(2024, 1, 2, 0));
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        assert!(
            parse_legacy_thread().subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[test]
    fn test_draft_prompt_defaults_to_none() {
        assert!(
            parse_legacy_thread().draft_prompt.is_none(),
            "Legacy threads without draft_prompt field should default to None"
        );
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();
        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let child_thread = DbThread {
            subagent_context: Some(crate::SubagentContext {
                parent_thread_id: parent_id.clone(),
                depth: 2,
            }),
            ..make_thread("Subagent Thread", utc(2024, 1, 1, 0))
        };

        database
            .save_thread(child_id.clone(), child_thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");
        let context = loaded
            .subagent_context
            .expect("subagent_context should be restored");

        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();
        let thread_id = session_id("regular-thread");

        database
            .save_thread(
                thread_id.clone(),
                make_thread("Regular Thread", utc(2024, 1, 1, 0)),
                PathList::default(),
            )
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }

    #[gpui::test]
    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();
        let folder_paths = PathList::new(&[
            std::path::PathBuf::from("/home/user/project-a"),
            std::path::PathBuf::from("/home/user/project-b"),
        ]);

        database
            .save_thread(
                session_id("folder-thread"),
                make_thread("Folder Thread", utc(2024, 6, 15, 12)),
                folder_paths.clone(),
            )
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert_eq!(threads[0].folder_paths, folder_paths);
    }

    #[gpui::test]
    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        database
            .save_thread(
                session_id("no-folder-thread"),
                make_thread("No Folder Thread", utc(2024, 6, 15, 12)),
                PathList::default(),
            )
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert!(threads[0].folder_paths.is_empty());
    }
}