db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use language_model::Speed;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21use util::path_list::PathList;
 22use zed_env_vars::ZED_STATELESS;
 23
 /// Message type persisted in a [`DbThread`]; an alias for `crate::Message`.
pub type DbMessage = crate::Message;
/// Alias for the legacy thread module's `DetailedSummaryState`.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Alias for the legacy thread module's serialized language-model selection.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 27
 28#[derive(Debug, Clone, Serialize, Deserialize)]
 29pub struct DbThreadMetadata {
 30    pub id: acp::SessionId,
 31    pub parent_session_id: Option<acp::SessionId>,
 32    #[serde(alias = "summary")]
 33    pub title: SharedString,
 34    pub updated_at: DateTime<Utc>,
 35    /// The workspace folder paths this thread was created against, sorted
 36    /// lexicographically. Used for grouping threads by project in the sidebar.
 37    pub folder_paths: PathList,
 38}
 39
 40#[derive(Debug, Serialize, Deserialize)]
 41pub struct DbThread {
 42    pub title: SharedString,
 43    pub messages: Vec<DbMessage>,
 44    pub updated_at: DateTime<Utc>,
 45    #[serde(default)]
 46    pub detailed_summary: Option<SharedString>,
 47    #[serde(default)]
 48    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
 49    #[serde(default)]
 50    pub cumulative_token_usage: language_model::TokenUsage,
 51    #[serde(default)]
 52    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
 53    #[serde(default)]
 54    pub model: Option<DbLanguageModel>,
 55    #[serde(default)]
 56    pub profile: Option<AgentProfileId>,
 57    #[serde(default)]
 58    pub imported: bool,
 59    #[serde(default)]
 60    pub subagent_context: Option<crate::SubagentContext>,
 61    #[serde(default)]
 62    pub speed: Option<Speed>,
 63    #[serde(default)]
 64    pub thinking_enabled: bool,
 65    #[serde(default)]
 66    pub thinking_effort: Option<String>,
 67    #[serde(default)]
 68    pub draft_prompt: Option<Vec<acp::ContentBlock>>,
 69    #[serde(default)]
 70    pub ui_scroll_position: Option<SerializedScrollPosition>,
 71}
 72
 73#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
 74pub struct SerializedScrollPosition {
 75    pub item_ix: usize,
 76    pub offset_in_item: f32,
 77}
 78
 79#[derive(Debug, Clone, Serialize, Deserialize)]
 80pub struct SharedThread {
 81    pub title: SharedString,
 82    pub messages: Vec<DbMessage>,
 83    pub updated_at: DateTime<Utc>,
 84    #[serde(default)]
 85    pub model: Option<DbLanguageModel>,
 86    pub version: String,
 87}
 88
 89impl SharedThread {
 90    pub const VERSION: &'static str = "1.0.0";
 91
 92    pub fn from_db_thread(thread: &DbThread) -> Self {
 93        Self {
 94            title: thread.title.clone(),
 95            messages: thread.messages.clone(),
 96            updated_at: thread.updated_at,
 97            model: thread.model.clone(),
 98            version: Self::VERSION.to_string(),
 99        }
100    }
101
102    pub fn to_db_thread(self) -> DbThread {
103        DbThread {
104            title: format!("🔗 {}", self.title).into(),
105            messages: self.messages,
106            updated_at: self.updated_at,
107            detailed_summary: None,
108            initial_project_snapshot: None,
109            cumulative_token_usage: Default::default(),
110            request_token_usage: Default::default(),
111            model: self.model,
112            profile: None,
113            imported: true,
114            subagent_context: None,
115            speed: None,
116            thinking_enabled: false,
117            thinking_effort: None,
118            draft_prompt: None,
119            ui_scroll_position: None,
120        }
121    }
122
123    pub fn to_bytes(&self) -> Result<Vec<u8>> {
124        const COMPRESSION_LEVEL: i32 = 3;
125        let json = serde_json::to_vec(self)?;
126        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
127        Ok(compressed)
128    }
129
130    pub fn from_bytes(data: &[u8]) -> Result<Self> {
131        let decompressed = zstd::decode_all(data)?;
132        Ok(serde_json::from_slice(&decompressed)?)
133    }
134}
135
impl DbThread {
    /// Format version written next to every thread saved by `ThreadsDatabase`.
    /// Stored JSON with any other `version` is decoded as the legacy format.
    pub const VERSION: &'static str = "0.3.0";

    /// Decodes a stored thread from JSON bytes.
    ///
    /// If the top-level `"version"` field is a string equal to [`Self::VERSION`], the JSON is
    /// deserialized directly into a [`DbThread`]. In every other case (a different version
    /// string, a non-string `version`, or no `version` at all), the bytes are parsed as a
    /// legacy `SerializedThread` and converted with `upgrade_from_agent_1`.
    ///
    /// # Errors
    /// Returns an error if `json` is not valid JSON or does not deserialize into the selected
    /// format.
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
                    json,
                )?),
            },
            _ => {
                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
            }
        }
    }

    /// Converts a thread in the legacy (agent 1) format into the current [`DbThread`] layout.
    ///
    /// - User messages: text segments become text content. Thinking segments are kept as
    ///   plain text and redacted-thinking segments are dropped. If nothing remains, the
    ///   message's context string (when non-empty) is used as the text. Each user message
    ///   gets a freshly generated id.
    /// - Assistant messages: text/thinking/redacted-thinking segments map one-to-one.
    ///   Tool uses become complete `ToolUse` entries. Tool results are keyed by tool-use id
    ///   and named after the matching tool use in the same message ("unknown" otherwise).
    /// - System messages are dropped.
    ///
    /// Newer fields (subagent context, speed, thinking settings, draft, scroll position)
    /// get their empty values, and `imported` is `false`.
    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
        let mut messages = Vec::new();
        let mut request_token_usage = HashMap::default();

        // Id of the most recent converted user message; assistant token usage is attributed to it.
        let mut last_user_message_id = None;
        // `ix` counts every legacy message, including skipped system messages.
        // NOTE(review): assumes legacy `request_token_usage` is indexed by that same position.
        for (ix, msg) in thread.messages.into_iter().enumerate() {
            let message = match msg.role {
                language_model::Role::User => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                ..
                            } => {
                                // User messages don't have thinking segments, but handle gracefully
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                ..
                            } => {
                                // User messages don't have redacted thinking, skip.
                            }
                        }
                    }

                    // If no content was added, add context as text if available
                    if content.is_empty() && !msg.context.is_empty() {
                        content.push(UserMessageContent::Text(msg.context));
                    }

                    let id = UserMessageId::new();
                    last_user_message_id = Some(id.clone());

                    crate::Message::User(UserMessage {
                        // MessageId from old format can't be meaningfully converted, so generate a new one
                        id,
                        content,
                    })
                }
                language_model::Role::Assistant => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(AgentMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                signature,
                            } => {
                                content.push(AgentMessageContent::Thinking { text, signature });
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                data,
                            } => {
                                content.push(AgentMessageContent::RedactedThinking(data));
                            }
                        }
                    }

                    // Convert tool uses, remembering each tool's name so the matching
                    // result below can be labelled with it.
                    let mut tool_names_by_id = HashMap::default();
                    for tool_use in msg.tool_uses {
                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
                        content.push(AgentMessageContent::ToolUse(
                            language_model::LanguageModelToolUse {
                                id: tool_use.id,
                                name: tool_use.name.into(),
                                // Falls back to an empty string if the input can't be re-serialized.
                                raw_input: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                                input: tool_use.input,
                                is_input_complete: true,
                                thought_signature: None,
                            },
                        ));
                    }

                    // Convert tool results
                    let mut tool_results = IndexMap::default();
                    for tool_result in msg.tool_results {
                        // Results without a tool use in this same message are named "unknown".
                        let name = tool_names_by_id
                            .remove(&tool_result.tool_use_id)
                            .unwrap_or_else(|| SharedString::from("unknown"));
                        tool_results.insert(
                            tool_result.tool_use_id.clone(),
                            language_model::LanguageModelToolResult {
                                tool_use_id: tool_result.tool_use_id,
                                tool_name: name.into(),
                                is_error: tool_result.is_error,
                                content: tool_result.content,
                                output: tool_result.output,
                            },
                        );
                    }

                    // Attribute this message's token usage to the preceding user message.
                    // A later assistant message for the same user message overwrites it.
                    if let Some(last_user_message_id) = &last_user_message_id
                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
                    {
                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
                    }

                    crate::Message::Agent(AgentMessage {
                        content,
                        tool_results,
                        reasoning_details: None,
                    })
                }
                language_model::Role::System => {
                    // Skip system messages as they're not supported in the new format
                    continue;
                }
            };

            messages.push(message);
        }

        Ok(Self {
            title: thread.summary,
            messages,
            updated_at: thread.updated_at,
            // Only a fully generated summary carries over; pending states become `None`.
            detailed_summary: match thread.detailed_summary_state {
                crate::legacy_thread::DetailedSummaryState::NotGenerated
                | crate::legacy_thread::DetailedSummaryState::Generating => None,
                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
            },
            initial_project_snapshot: thread.initial_project_snapshot,
            cumulative_token_usage: thread.cumulative_token_usage,
            request_token_usage,
            model: thread.model,
            profile: thread.profile,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        })
    }
}
302
303#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
304pub enum DataType {
305    #[serde(rename = "json")]
306    Json,
307    #[serde(rename = "zstd")]
308    Zstd,
309}
310
311impl Bind for DataType {
312    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
313        let value = match self {
314            DataType::Json => "json",
315            DataType::Zstd => "zstd",
316        };
317        value.bind(statement, start_index)
318    }
319}
320
321impl Column for DataType {
322    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
323        let (value, next_index) = String::column(statement, start_index)?;
324        let data_type = match value.as_str() {
325            "json" => DataType::Json,
326            "zstd" => DataType::Zstd,
327            _ => anyhow::bail!("Unknown data type: {}", value),
328        };
329        Ok((data_type, next_index))
330    }
331}
332
333pub(crate) struct ThreadsDatabase {
334    executor: BackgroundExecutor,
335    connection: Arc<Mutex<Connection>>,
336}
337
338struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
339
340impl Global for GlobalThreadsDatabase {}
341
342impl ThreadsDatabase {
343    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
344        if cx.has_global::<GlobalThreadsDatabase>() {
345            return cx.global::<GlobalThreadsDatabase>().0.clone();
346        }
347        let executor = cx.background_executor().clone();
348        let task = executor
349            .spawn({
350                let executor = executor.clone();
351                async move {
352                    match ThreadsDatabase::new(executor) {
353                        Ok(db) => Ok(Arc::new(db)),
354                        Err(err) => Err(Arc::new(err)),
355                    }
356                }
357            })
358            .shared();
359
360        cx.set_global(GlobalThreadsDatabase(task.clone()));
361        task
362    }
363
364    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
365        let connection = if *ZED_STATELESS {
366            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
367        } else if cfg!(any(feature = "test-support", test)) {
368            // rust stores the name of the test on the current thread.
369            // We use this to automatically create a database that will
370            // be shared within the test (for the test_retrieve_old_thread)
371            // but not with concurrent tests.
372            let thread = std::thread::current();
373            let test_name = thread.name();
374            Connection::open_memory(Some(&format!(
375                "THREAD_FALLBACK_{}",
376                test_name.unwrap_or_default()
377            )))
378        } else {
379            let threads_dir = paths::data_dir().join("threads");
380            std::fs::create_dir_all(&threads_dir)?;
381            let sqlite_path = threads_dir.join("threads.db");
382            Connection::open_file(&sqlite_path.to_string_lossy())
383        };
384
385        connection.exec(indoc! {"
386            CREATE TABLE IF NOT EXISTS threads (
387                id TEXT PRIMARY KEY,
388                summary TEXT NOT NULL,
389                updated_at TEXT NOT NULL,
390                data_type TEXT NOT NULL,
391                data BLOB NOT NULL
392            )
393        "})?()
394        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
395
396        if let Ok(mut s) = connection.exec(indoc! {"
397            ALTER TABLE threads ADD COLUMN parent_id TEXT
398        "})
399        {
400            s().ok();
401        }
402
403        if let Ok(mut s) = connection.exec(indoc! {"
404            ALTER TABLE threads ADD COLUMN folder_paths TEXT;
405            ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
406        "})
407        {
408            s().ok();
409        }
410
411        let db = Self {
412            executor,
413            connection: Arc::new(Mutex::new(connection)),
414        };
415
416        Ok(db)
417    }
418
419    fn save_thread_sync(
420        connection: &Arc<Mutex<Connection>>,
421        id: acp::SessionId,
422        thread: DbThread,
423        folder_paths: &PathList,
424    ) -> Result<()> {
425        const COMPRESSION_LEVEL: i32 = 3;
426
427        #[derive(Serialize)]
428        struct SerializedThread {
429            #[serde(flatten)]
430            thread: DbThread,
431            version: &'static str,
432        }
433
434        let title = thread.title.to_string();
435        let updated_at = thread.updated_at.to_rfc3339();
436        let parent_id = thread
437            .subagent_context
438            .as_ref()
439            .map(|ctx| ctx.parent_thread_id.0.clone());
440        let serialized_folder_paths = folder_paths.serialize();
441        let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
442            if folder_paths.is_empty() {
443                (None, None)
444            } else {
445                (
446                    Some(serialized_folder_paths.paths),
447                    Some(serialized_folder_paths.order),
448                )
449            };
450        let json_data = serde_json::to_string(&SerializedThread {
451            thread,
452            version: DbThread::VERSION,
453        })?;
454
455        let connection = connection.lock();
456
457        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
458        let data_type = DataType::Zstd;
459        let data = compressed;
460
461        let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>)>(indoc! {"
462            INSERT OR REPLACE INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
463        "})?;
464
465        insert((
466            id.0,
467            parent_id,
468            folder_paths_str,
469            folder_paths_order_str,
470            title,
471            updated_at,
472            data_type,
473            data,
474        ))?;
475
476        Ok(())
477    }
478
479    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
480        let connection = self.connection.clone();
481
482        self.executor.spawn(async move {
483            let connection = connection.lock();
484
485            let mut select = connection
486                .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String)>(indoc! {"
487                SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at FROM threads ORDER BY updated_at DESC
488            "})?;
489
490            let rows = select(())?;
491            let mut threads = Vec::new();
492
493            for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at) in rows {
494                let folder_paths = folder_paths
495                    .map(|paths| {
496                        PathList::deserialize(&util::path_list::SerializedPathList {
497                            paths,
498                            order: folder_paths_order.unwrap_or_default(),
499                        })
500                    })
501                    .unwrap_or_default();
502                threads.push(DbThreadMetadata {
503                    id: acp::SessionId::new(id),
504                    parent_session_id: parent_id.map(acp::SessionId::new),
505                    title: summary.into(),
506                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
507                    folder_paths,
508                });
509            }
510
511            Ok(threads)
512        })
513    }
514
515    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
516        let connection = self.connection.clone();
517
518        self.executor.spawn(async move {
519            let connection = connection.lock();
520            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
521                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
522            "})?;
523
524            let rows = select(id.0)?;
525            if let Some((data_type, data)) = rows.into_iter().next() {
526                let json_data = match data_type {
527                    DataType::Zstd => {
528                        let decompressed = zstd::decode_all(&data[..])?;
529                        String::from_utf8(decompressed)?
530                    }
531                    DataType::Json => String::from_utf8(data)?,
532                };
533                let thread = DbThread::from_json(json_data.as_bytes())?;
534                Ok(Some(thread))
535            } else {
536                Ok(None)
537            }
538        })
539    }
540
541    pub fn save_thread(
542        &self,
543        id: acp::SessionId,
544        thread: DbThread,
545        folder_paths: PathList,
546    ) -> Task<Result<()>> {
547        let connection = self.connection.clone();
548
549        self.executor
550            .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
551    }
552
553    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
554        let connection = self.connection.clone();
555
556        self.executor.spawn(async move {
557            let connection = connection.lock();
558
559            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
560                DELETE FROM threads WHERE id = ?
561            "})?;
562
563            delete(id.0)?;
564
565            Ok(())
566        })
567    }
568
569    pub fn delete_threads(&self) -> Task<Result<()>> {
570        let connection = self.connection.clone();
571
572        self.executor.spawn(async move {
573            let connection = connection.lock();
574
575            let mut delete = connection.exec_bound::<()>(indoc! {"
576                DELETE FROM threads
577            "})?;
578
579            delete(())?;
580
581            Ok(())
582        })
583    }
584}
585
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Minimal stored thread containing only the required fields, as written before the
    /// optional fields existed. Used to check that every optional field has a serde default.
    const LEGACY_THREAD_JSON: &str = r#"{
        "title": "Old Thread",
        "messages": [],
        "updated_at": "2024-01-01T00:00:00Z"
    }"#;

    /// Deserializes [`LEGACY_THREAD_JSON`] directly into a [`DbThread`].
    fn parse_legacy_thread() -> DbThread {
        serde_json::from_str(LEGACY_THREAD_JSON).expect("Failed to deserialize")
    }

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    /// Saves `thread` under `id` with no folder paths, then loads it back.
    async fn save_and_load(
        database: &ThreadsDatabase,
        id: acp::SessionId,
        thread: DbThread,
    ) -> DbThread {
        database
            .save_thread(id.clone(), thread, PathList::default())
            .await
            .unwrap();
        database
            .load_thread(id)
            .await
            .unwrap()
            .expect("thread should exist")
    }

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_shared_thread_to_db_thread_marks_imported() {
        let updated_at = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
        let shared = SharedThread {
            title: "Shared Thread".into(),
            messages: vec![],
            updated_at,
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let db_thread = shared.to_db_thread();

        assert!(db_thread.imported, "Imported threads should be flagged");
        assert_eq!(db_thread.title.as_ref(), "🔗 Shared Thread");
        assert_eq!(db_thread.updated_at, updated_at);
        assert!(db_thread.subagent_context.is_none());
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let db_thread = parse_legacy_thread();

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        let db_thread = parse_legacy_thread();

        assert!(
            db_thread.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[test]
    fn test_draft_prompt_defaults_to_none() {
        let db_thread = parse_legacy_thread();

        assert!(
            db_thread.draft_prompt.is_none(),
            "Legacy threads without draft_prompt field should default to None"
        );
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let mut child_thread = make_thread(
            "Subagent Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 2,
        });

        let loaded = save_and_load(&database, child_id, child_thread).await;

        let context = loaded
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("regular-thread");
        let thread = make_thread(
            "Regular Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );

        let loaded = save_and_load(&database, thread_id, thread).await;

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }

    #[gpui::test]
    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("folder-thread");
        let thread = make_thread(
            "Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        let folder_paths = PathList::new(&[
            std::path::PathBuf::from("/home/user/project-a"),
            std::path::PathBuf::from("/home/user/project-b"),
        ]);

        database
            .save_thread(thread_id.clone(), thread, folder_paths.clone())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert_eq!(threads[0].folder_paths, folder_paths);
    }

    #[gpui::test]
    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("no-folder-thread");
        let thread = make_thread(
            "No Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert!(threads[0].folder_paths.is_empty());
    }

    #[test]
    fn test_scroll_position_defaults_to_none() {
        let db_thread = parse_legacy_thread();

        assert!(
            db_thread.ui_scroll_position.is_none(),
            "Legacy threads without scroll_position field should default to None"
        );
    }

    #[gpui::test]
    async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-with-scroll");

        let mut thread = make_thread(
            "Thread With Scroll",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        thread.ui_scroll_position = Some(SerializedScrollPosition {
            item_ix: 42,
            offset_in_item: 13.5,
        });

        let loaded = save_and_load(&database, thread_id, thread).await;

        let scroll = loaded
            .ui_scroll_position
            .expect("scroll_position should be restored");
        assert_eq!(scroll.item_ix, 42);
        assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
    }
}