db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use language_model::Speed;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21use util::path_list::PathList;
 22use zed_env_vars::ZED_STATELESS;
 23
 24pub type DbMessage = crate::Message;
 25pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
 26pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 27
 28#[derive(Debug, Clone)]
 29pub struct DbThreadMetadata {
 30    pub id: acp::SessionId,
 31    pub parent_session_id: Option<acp::SessionId>,
 32    pub title: SharedString,
 33    pub updated_at: DateTime<Utc>,
 34    pub created_at: Option<DateTime<Utc>>,
 35    /// The workspace folder paths this thread was created against, sorted
 36    /// lexicographically. Used for grouping threads by project in the sidebar.
 37    pub folder_paths: PathList,
 38}
 39
 40impl From<&DbThreadMetadata> for acp_thread::AgentSessionInfo {
 41    fn from(meta: &DbThreadMetadata) -> Self {
 42        Self {
 43            session_id: meta.id.clone(),
 44            work_dirs: Some(meta.folder_paths.clone()),
 45            title: Some(meta.title.clone()),
 46            updated_at: Some(meta.updated_at),
 47            created_at: meta.created_at,
 48            meta: None,
 49        }
 50    }
 51}
 52
 53#[derive(Debug, Serialize, Deserialize)]
 54pub struct DbThread {
 55    pub title: SharedString,
 56    pub messages: Vec<DbMessage>,
 57    pub updated_at: DateTime<Utc>,
 58    #[serde(default)]
 59    pub detailed_summary: Option<SharedString>,
 60    #[serde(default)]
 61    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
 62    #[serde(default)]
 63    pub cumulative_token_usage: language_model::TokenUsage,
 64    #[serde(default)]
 65    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
 66    #[serde(default)]
 67    pub model: Option<DbLanguageModel>,
 68    #[serde(default)]
 69    pub profile: Option<AgentProfileId>,
 70    #[serde(default)]
 71    pub imported: bool,
 72    #[serde(default)]
 73    pub subagent_context: Option<crate::SubagentContext>,
 74    #[serde(default)]
 75    pub speed: Option<Speed>,
 76    #[serde(default)]
 77    pub thinking_enabled: bool,
 78    #[serde(default)]
 79    pub thinking_effort: Option<String>,
 80    #[serde(default)]
 81    pub draft_prompt: Option<Vec<acp::ContentBlock>>,
 82    #[serde(default)]
 83    pub ui_scroll_position: Option<SerializedScrollPosition>,
 84}
 85
 86#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
 87pub struct SerializedScrollPosition {
 88    pub item_ix: usize,
 89    pub offset_in_item: f32,
 90}
 91
 92#[derive(Debug, Clone, Serialize, Deserialize)]
 93pub struct SharedThread {
 94    pub title: SharedString,
 95    pub messages: Vec<DbMessage>,
 96    pub updated_at: DateTime<Utc>,
 97    #[serde(default)]
 98    pub model: Option<DbLanguageModel>,
 99    pub version: String,
100}
101
102impl SharedThread {
103    pub const VERSION: &'static str = "1.0.0";
104
105    pub fn from_db_thread(thread: &DbThread) -> Self {
106        Self {
107            title: thread.title.clone(),
108            messages: thread.messages.clone(),
109            updated_at: thread.updated_at,
110            model: thread.model.clone(),
111            version: Self::VERSION.to_string(),
112        }
113    }
114
115    pub fn to_db_thread(self) -> DbThread {
116        DbThread {
117            title: format!("🔗 {}", self.title).into(),
118            messages: self.messages,
119            updated_at: self.updated_at,
120            detailed_summary: None,
121            initial_project_snapshot: None,
122            cumulative_token_usage: Default::default(),
123            request_token_usage: Default::default(),
124            model: self.model,
125            profile: None,
126            imported: true,
127            subagent_context: None,
128            speed: None,
129            thinking_enabled: false,
130            thinking_effort: None,
131            draft_prompt: None,
132            ui_scroll_position: None,
133        }
134    }
135
136    pub fn to_bytes(&self) -> Result<Vec<u8>> {
137        const COMPRESSION_LEVEL: i32 = 3;
138        let json = serde_json::to_vec(self)?;
139        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
140        Ok(compressed)
141    }
142
143    pub fn from_bytes(data: &[u8]) -> Result<Self> {
144        let decompressed = zstd::decode_all(data)?;
145        Ok(serde_json::from_slice(&decompressed)?)
146    }
147}
148
149impl DbThread {
150    pub const VERSION: &'static str = "0.3.0";
151
152    pub fn from_json(json: &[u8]) -> Result<Self> {
153        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
154        match saved_thread_json.get("version") {
155            Some(serde_json::Value::String(version)) => match version.as_str() {
156                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
157                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
158                    json,
159                )?),
160            },
161            _ => {
162                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
163            }
164        }
165    }
166
167    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
168        let mut messages = Vec::new();
169        let mut request_token_usage = HashMap::default();
170
171        let mut last_user_message_id = None;
172        for (ix, msg) in thread.messages.into_iter().enumerate() {
173            let message = match msg.role {
174                language_model::Role::User => {
175                    let mut content = Vec::new();
176
177                    // Convert segments to content
178                    for segment in msg.segments {
179                        match segment {
180                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
181                                content.push(UserMessageContent::Text(text));
182                            }
183                            crate::legacy_thread::SerializedMessageSegment::Thinking {
184                                text,
185                                ..
186                            } => {
187                                // User messages don't have thinking segments, but handle gracefully
188                                content.push(UserMessageContent::Text(text));
189                            }
190                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
191                                ..
192                            } => {
193                                // User messages don't have redacted thinking, skip.
194                            }
195                        }
196                    }
197
198                    // If no content was added, add context as text if available
199                    if content.is_empty() && !msg.context.is_empty() {
200                        content.push(UserMessageContent::Text(msg.context));
201                    }
202
203                    let id = UserMessageId::new();
204                    last_user_message_id = Some(id.clone());
205
206                    crate::Message::User(UserMessage {
207                        // MessageId from old format can't be meaningfully converted, so generate a new one
208                        id,
209                        content,
210                    })
211                }
212                language_model::Role::Assistant => {
213                    let mut content = Vec::new();
214
215                    // Convert segments to content
216                    for segment in msg.segments {
217                        match segment {
218                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
219                                content.push(AgentMessageContent::Text(text));
220                            }
221                            crate::legacy_thread::SerializedMessageSegment::Thinking {
222                                text,
223                                signature,
224                            } => {
225                                content.push(AgentMessageContent::Thinking { text, signature });
226                            }
227                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
228                                data,
229                            } => {
230                                content.push(AgentMessageContent::RedactedThinking(data));
231                            }
232                        }
233                    }
234
235                    // Convert tool uses
236                    let mut tool_names_by_id = HashMap::default();
237                    for tool_use in msg.tool_uses {
238                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
239                        content.push(AgentMessageContent::ToolUse(
240                            language_model::LanguageModelToolUse {
241                                id: tool_use.id,
242                                name: tool_use.name.into(),
243                                raw_input: serde_json::to_string(&tool_use.input)
244                                    .unwrap_or_default(),
245                                input: tool_use.input,
246                                is_input_complete: true,
247                                thought_signature: None,
248                            },
249                        ));
250                    }
251
252                    // Convert tool results
253                    let mut tool_results = IndexMap::default();
254                    for tool_result in msg.tool_results {
255                        let name = tool_names_by_id
256                            .remove(&tool_result.tool_use_id)
257                            .unwrap_or_else(|| SharedString::from("unknown"));
258                        tool_results.insert(
259                            tool_result.tool_use_id.clone(),
260                            language_model::LanguageModelToolResult {
261                                tool_use_id: tool_result.tool_use_id,
262                                tool_name: name.into(),
263                                is_error: tool_result.is_error,
264                                content: tool_result.content,
265                                output: tool_result.output,
266                            },
267                        );
268                    }
269
270                    if let Some(last_user_message_id) = &last_user_message_id
271                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
272                    {
273                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
274                    }
275
276                    crate::Message::Agent(AgentMessage {
277                        content,
278                        tool_results,
279                        reasoning_details: None,
280                    })
281                }
282                language_model::Role::System => {
283                    // Skip system messages as they're not supported in the new format
284                    continue;
285                }
286            };
287
288            messages.push(message);
289        }
290
291        Ok(Self {
292            title: thread.summary,
293            messages,
294            updated_at: thread.updated_at,
295            detailed_summary: match thread.detailed_summary_state {
296                crate::legacy_thread::DetailedSummaryState::NotGenerated
297                | crate::legacy_thread::DetailedSummaryState::Generating => None,
298                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
299            },
300            initial_project_snapshot: thread.initial_project_snapshot,
301            cumulative_token_usage: thread.cumulative_token_usage,
302            request_token_usage,
303            model: thread.model,
304            profile: thread.profile,
305            imported: false,
306            subagent_context: None,
307            speed: None,
308            thinking_enabled: false,
309            thinking_effort: None,
310            draft_prompt: None,
311            ui_scroll_position: None,
312        })
313    }
314}
315
316#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
317pub enum DataType {
318    #[serde(rename = "json")]
319    Json,
320    #[serde(rename = "zstd")]
321    Zstd,
322}
323
324impl Bind for DataType {
325    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
326        let value = match self {
327            DataType::Json => "json",
328            DataType::Zstd => "zstd",
329        };
330        value.bind(statement, start_index)
331    }
332}
333
334impl Column for DataType {
335    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
336        let (value, next_index) = String::column(statement, start_index)?;
337        let data_type = match value.as_str() {
338            "json" => DataType::Json,
339            "zstd" => DataType::Zstd,
340            _ => anyhow::bail!("Unknown data type: {}", value),
341        };
342        Ok((data_type, next_index))
343    }
344}
345
346pub(crate) struct ThreadsDatabase {
347    executor: BackgroundExecutor,
348    connection: Arc<Mutex<Connection>>,
349}
350
351struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
352
353impl Global for GlobalThreadsDatabase {}
354
355impl ThreadsDatabase {
356    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
357        if cx.has_global::<GlobalThreadsDatabase>() {
358            return cx.global::<GlobalThreadsDatabase>().0.clone();
359        }
360        let executor = cx.background_executor().clone();
361        let task = executor
362            .spawn({
363                let executor = executor.clone();
364                async move {
365                    match ThreadsDatabase::new(executor) {
366                        Ok(db) => Ok(Arc::new(db)),
367                        Err(err) => Err(Arc::new(err)),
368                    }
369                }
370            })
371            .shared();
372
373        cx.set_global(GlobalThreadsDatabase(task.clone()));
374        task
375    }
376
377    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
378        let connection = if *ZED_STATELESS {
379            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
380        } else if cfg!(any(feature = "test-support", test)) {
381            // rust stores the name of the test on the current thread.
382            // We use this to automatically create a database that will
383            // be shared within the test (for the test_retrieve_old_thread)
384            // but not with concurrent tests.
385            let thread = std::thread::current();
386            let test_name = thread.name();
387            Connection::open_memory(Some(&format!(
388                "THREAD_FALLBACK_{}",
389                test_name.unwrap_or_default()
390            )))
391        } else {
392            let threads_dir = paths::data_dir().join("threads");
393            std::fs::create_dir_all(&threads_dir)?;
394            let sqlite_path = threads_dir.join("threads.db");
395            Connection::open_file(&sqlite_path.to_string_lossy())
396        };
397
398        connection.exec(indoc! {"
399            CREATE TABLE IF NOT EXISTS threads (
400                id TEXT PRIMARY KEY,
401                summary TEXT NOT NULL,
402                updated_at TEXT NOT NULL,
403                data_type TEXT NOT NULL,
404                data BLOB NOT NULL
405            )
406        "})?()
407        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
408
409        if let Ok(mut s) = connection.exec(indoc! {"
410            ALTER TABLE threads ADD COLUMN parent_id TEXT
411        "})
412        {
413            s().ok();
414        }
415
416        if let Ok(mut s) = connection.exec(indoc! {"
417            ALTER TABLE threads ADD COLUMN folder_paths TEXT;
418            ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
419        "})
420        {
421            s().ok();
422        }
423
424        if let Ok(mut s) = connection.exec(indoc! {"
425            ALTER TABLE threads ADD COLUMN created_at TEXT;
426        "})
427        {
428            if s().is_ok() {
429                connection.exec(indoc! {"
430                    UPDATE threads SET created_at = updated_at WHERE created_at IS NULL
431                "})?()?;
432            }
433        }
434
435        let db = Self {
436            executor,
437            connection: Arc::new(Mutex::new(connection)),
438        };
439
440        Ok(db)
441    }
442
443    fn save_thread_sync(
444        connection: &Arc<Mutex<Connection>>,
445        id: acp::SessionId,
446        thread: DbThread,
447        folder_paths: &PathList,
448    ) -> Result<()> {
449        const COMPRESSION_LEVEL: i32 = 3;
450
451        #[derive(Serialize)]
452        struct SerializedThread {
453            #[serde(flatten)]
454            thread: DbThread,
455            version: &'static str,
456        }
457
458        let title = thread.title.to_string();
459        let updated_at = thread.updated_at.to_rfc3339();
460        let parent_id = thread
461            .subagent_context
462            .as_ref()
463            .map(|ctx| ctx.parent_thread_id.0.clone());
464        let serialized_folder_paths = folder_paths.serialize();
465        let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
466            if folder_paths.is_empty() {
467                (None, None)
468            } else {
469                (
470                    Some(serialized_folder_paths.paths),
471                    Some(serialized_folder_paths.order),
472                )
473            };
474        let json_data = serde_json::to_string(&SerializedThread {
475            thread,
476            version: DbThread::VERSION,
477        })?;
478
479        let connection = connection.lock();
480
481        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
482        let data_type = DataType::Zstd;
483        let data = compressed;
484
485        // Use the thread's updated_at as created_at for new threads.
486        // This ensures the creation time reflects when the thread was conceptually
487        // created, not when it was saved to the database.
488        let created_at = updated_at.clone();
489
490        let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>, String)>(indoc! {"
491            INSERT INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data, created_at)
492            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
493            ON CONFLICT(id) DO UPDATE SET
494                parent_id = excluded.parent_id,
495                folder_paths = excluded.folder_paths,
496                folder_paths_order = excluded.folder_paths_order,
497                summary = excluded.summary,
498                updated_at = excluded.updated_at,
499                data_type = excluded.data_type,
500                data = excluded.data
501        "})?;
502
503        insert((
504            id.0,
505            parent_id,
506            folder_paths_str,
507            folder_paths_order_str,
508            title,
509            updated_at,
510            data_type,
511            data,
512            created_at,
513        ))?;
514
515        Ok(())
516    }
517
518    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
519        let connection = self.connection.clone();
520
521        self.executor.spawn(async move {
522            let connection = connection.lock();
523
524            let mut select = connection
525                .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, Option<String>)>(indoc! {"
526                SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at FROM threads ORDER BY updated_at DESC, created_at DESC
527            "})?;
528
529            let rows = select(())?;
530            let mut threads = Vec::new();
531
532            for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at) in rows {
533                let folder_paths = folder_paths
534                    .map(|paths| {
535                        PathList::deserialize(&util::path_list::SerializedPathList {
536                            paths,
537                            order: folder_paths_order.unwrap_or_default(),
538                        })
539                    })
540                    .unwrap_or_default();
541                let created_at = created_at
542                    .as_deref()
543                    .map(DateTime::parse_from_rfc3339)
544                    .transpose()?
545                    .map(|dt| dt.with_timezone(&Utc));
546
547                threads.push(DbThreadMetadata {
548                    id: acp::SessionId::new(id),
549                    parent_session_id: parent_id.map(acp::SessionId::new),
550                    title: summary.into(),
551                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
552                    created_at,
553                    folder_paths,
554                });
555            }
556
557            Ok(threads)
558        })
559    }
560
561    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
562        let connection = self.connection.clone();
563
564        self.executor.spawn(async move {
565            let connection = connection.lock();
566            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
567                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
568            "})?;
569
570            let rows = select(id.0)?;
571            if let Some((data_type, data)) = rows.into_iter().next() {
572                let json_data = match data_type {
573                    DataType::Zstd => {
574                        let decompressed = zstd::decode_all(&data[..])?;
575                        String::from_utf8(decompressed)?
576                    }
577                    DataType::Json => String::from_utf8(data)?,
578                };
579                let thread = DbThread::from_json(json_data.as_bytes())?;
580                Ok(Some(thread))
581            } else {
582                Ok(None)
583            }
584        })
585    }
586
587    pub fn save_thread(
588        &self,
589        id: acp::SessionId,
590        thread: DbThread,
591        folder_paths: PathList,
592    ) -> Task<Result<()>> {
593        let connection = self.connection.clone();
594
595        self.executor
596            .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
597    }
598
599    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
600        let connection = self.connection.clone();
601
602        self.executor.spawn(async move {
603            let connection = connection.lock();
604
605            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
606                DELETE FROM threads WHERE id = ?
607            "})?;
608
609            delete(id.0)?;
610
611            Ok(())
612        })
613    }
614
615    pub fn delete_threads(&self) -> Task<Result<()>> {
616        let connection = self.connection.clone();
617
618        self.executor.spawn(async move {
619            let connection = connection.lock();
620
621            let mut delete = connection.exec_bound::<()>(indoc! {"
622                DELETE FROM threads
623            "})?;
624
625            delete(())?;
626
627            Ok(())
628        })
629    }
630}
631
632#[cfg(test)]
633mod tests {
634    use super::*;
635    use chrono::{DateTime, TimeZone, Utc};
636    use collections::HashMap;
637    use gpui::TestAppContext;
638    use std::sync::Arc;
639
640    #[test]
641    fn test_shared_thread_roundtrip() {
642        let original = SharedThread {
643            title: "Test Thread".into(),
644            messages: vec![],
645            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
646            model: None,
647            version: SharedThread::VERSION.to_string(),
648        };
649
650        let bytes = original.to_bytes().expect("Failed to serialize");
651        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");
652
653        assert_eq!(restored.title, original.title);
654        assert_eq!(restored.version, original.version);
655        assert_eq!(restored.updated_at, original.updated_at);
656    }
657
658    #[test]
659    fn test_imported_flag_defaults_to_false() {
660        // Simulate deserializing a thread without the imported field (backwards compatibility).
661        let json = r#"{
662            "title": "Old Thread",
663            "messages": [],
664            "updated_at": "2024-01-01T00:00:00Z"
665        }"#;
666
667        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
668
669        assert!(
670            !db_thread.imported,
671            "Legacy threads without imported field should default to false"
672        );
673    }
674
675    fn session_id(value: &str) -> acp::SessionId {
676        acp::SessionId::new(Arc::<str>::from(value))
677    }
678
679    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
680        DbThread {
681            title: title.to_string().into(),
682            messages: Vec::new(),
683            updated_at,
684            detailed_summary: None,
685            initial_project_snapshot: None,
686            cumulative_token_usage: Default::default(),
687            request_token_usage: HashMap::default(),
688            model: None,
689            profile: None,
690            imported: false,
691            subagent_context: None,
692            speed: None,
693            thinking_enabled: false,
694            thinking_effort: None,
695            draft_prompt: None,
696            ui_scroll_position: None,
697        }
698    }
699
700    #[gpui::test]
701    async fn test_list_threads_orders_by_created_at(cx: &mut TestAppContext) {
702        let database = ThreadsDatabase::new(cx.executor()).unwrap();
703
704        let older_id = session_id("thread-a");
705        let newer_id = session_id("thread-b");
706
707        let older_thread = make_thread(
708            "Thread A",
709            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
710        );
711        let newer_thread = make_thread(
712            "Thread B",
713            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
714        );
715
716        database
717            .save_thread(older_id.clone(), older_thread, PathList::default())
718            .await
719            .unwrap();
720        database
721            .save_thread(newer_id.clone(), newer_thread, PathList::default())
722            .await
723            .unwrap();
724
725        let entries = database.list_threads().await.unwrap();
726        assert_eq!(entries.len(), 2);
727        assert_eq!(entries[0].id, newer_id);
728        assert_eq!(entries[1].id, older_id);
729    }
730
731    #[gpui::test]
732    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
733        let database = ThreadsDatabase::new(cx.executor()).unwrap();
734
735        let thread_id = session_id("thread-a");
736        let original_thread = make_thread(
737            "Thread A",
738            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
739        );
740        let updated_thread = make_thread(
741            "Thread B",
742            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
743        );
744
745        database
746            .save_thread(thread_id.clone(), original_thread, PathList::default())
747            .await
748            .unwrap();
749        database
750            .save_thread(thread_id.clone(), updated_thread, PathList::default())
751            .await
752            .unwrap();
753
754        let entries = database.list_threads().await.unwrap();
755        assert_eq!(entries.len(), 1);
756        assert_eq!(entries[0].id, thread_id);
757        assert_eq!(entries[0].title.as_ref(), "Thread B");
758        assert_eq!(
759            entries[0].updated_at,
760            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
761        );
762        assert!(
763            entries[0].created_at.is_some(),
764            "created_at should be populated"
765        );
766    }
767
768    #[test]
769    fn test_subagent_context_defaults_to_none() {
770        let json = r#"{
771            "title": "Old Thread",
772            "messages": [],
773            "updated_at": "2024-01-01T00:00:00Z"
774        }"#;
775
776        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
777
778        assert!(
779            db_thread.subagent_context.is_none(),
780            "Legacy threads without subagent_context should default to None"
781        );
782    }
783
784    #[test]
785    fn test_draft_prompt_defaults_to_none() {
786        let json = r#"{
787            "title": "Old Thread",
788            "messages": [],
789            "updated_at": "2024-01-01T00:00:00Z"
790        }"#;
791
792        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
793
794        assert!(
795            db_thread.draft_prompt.is_none(),
796            "Legacy threads without draft_prompt field should default to None"
797        );
798    }
799
800    #[gpui::test]
801    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
802        let database = ThreadsDatabase::new(cx.executor()).unwrap();
803
804        let parent_id = session_id("parent-thread");
805        let child_id = session_id("child-thread");
806
807        let mut child_thread = make_thread(
808            "Subagent Thread",
809            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
810        );
811        child_thread.subagent_context = Some(crate::SubagentContext {
812            parent_thread_id: parent_id.clone(),
813            depth: 2,
814        });
815
816        database
817            .save_thread(child_id.clone(), child_thread, PathList::default())
818            .await
819            .unwrap();
820
821        let loaded = database
822            .load_thread(child_id)
823            .await
824            .unwrap()
825            .expect("thread should exist");
826
827        let context = loaded
828            .subagent_context
829            .expect("subagent_context should be restored");
830        assert_eq!(context.parent_thread_id, parent_id);
831        assert_eq!(context.depth, 2);
832    }
833
834    #[gpui::test]
835    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
836        let database = ThreadsDatabase::new(cx.executor()).unwrap();
837
838        let thread_id = session_id("regular-thread");
839        let thread = make_thread(
840            "Regular Thread",
841            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
842        );
843
844        database
845            .save_thread(thread_id.clone(), thread, PathList::default())
846            .await
847            .unwrap();
848
849        let loaded = database
850            .load_thread(thread_id)
851            .await
852            .unwrap()
853            .expect("thread should exist");
854
855        assert!(
856            loaded.subagent_context.is_none(),
857            "Regular threads should have no subagent_context"
858        );
859    }
860
861    #[gpui::test]
862    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
863        let database = ThreadsDatabase::new(cx.executor()).unwrap();
864
865        let thread_id = session_id("folder-thread");
866        let thread = make_thread(
867            "Folder Thread",
868            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
869        );
870
871        let folder_paths = PathList::new(&[
872            std::path::PathBuf::from("/home/user/project-a"),
873            std::path::PathBuf::from("/home/user/project-b"),
874        ]);
875
876        database
877            .save_thread(thread_id.clone(), thread, folder_paths.clone())
878            .await
879            .unwrap();
880
881        let threads = database.list_threads().await.unwrap();
882        assert_eq!(threads.len(), 1);
883    }
884
885    #[gpui::test]
886    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
887        let database = ThreadsDatabase::new(cx.executor()).unwrap();
888
889        let thread_id = session_id("no-folder-thread");
890        let thread = make_thread(
891            "No Folder Thread",
892            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
893        );
894
895        database
896            .save_thread(thread_id.clone(), thread, PathList::default())
897            .await
898            .unwrap();
899
900        let threads = database.list_threads().await.unwrap();
901        assert_eq!(threads.len(), 1);
902    }
903
904    #[test]
905    fn test_scroll_position_defaults_to_none() {
906        let json = r#"{
907            "title": "Old Thread",
908            "messages": [],
909            "updated_at": "2024-01-01T00:00:00Z"
910        }"#;
911
912        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
913
914        assert!(
915            db_thread.ui_scroll_position.is_none(),
916            "Legacy threads without scroll_position field should default to None"
917        );
918    }
919
920    #[gpui::test]
921    async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
922        let database = ThreadsDatabase::new(cx.executor()).unwrap();
923
924        let thread_id = session_id("thread-with-scroll");
925
926        let mut thread = make_thread(
927            "Thread With Scroll",
928            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
929        );
930        thread.ui_scroll_position = Some(SerializedScrollPosition {
931            item_ix: 42,
932            offset_in_item: 13.5,
933        });
934
935        database
936            .save_thread(thread_id.clone(), thread, PathList::default())
937            .await
938            .unwrap();
939
940        let loaded = database
941            .load_thread(thread_id)
942            .await
943            .unwrap()
944            .expect("thread should exist");
945
946        let scroll = loaded
947            .ui_scroll_position
948            .expect("scroll_position should be restored");
949        assert_eq!(scroll.item_ix, 42);
950        assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
951    }
952}