db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use language_model::Speed;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21use util::path_list::PathList;
 22use zed_env_vars::ZED_STATELESS;
 23
/// A thread message as persisted in the database; alias for [`crate::Message`].
pub type DbMessage = crate::Message;
/// Detailed-summary state; alias for the legacy thread format's type.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Persisted language-model selection; alias for the legacy thread format's type.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 27
/// Per-thread metadata read from the columns of the `threads` table, without
/// decoding the thread payload itself (built by `ThreadsDatabase::list_threads`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session identifier (the `id` column).
    pub id: acp::SessionId,
    /// Session id of the parent thread for subagent threads (the `parent_id` column).
    pub parent_session_id: Option<acp::SessionId>,
    /// Display title (the `summary` column). Deserialization also accepts the
    /// older `summary` key.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// The thread's `updated_at` timestamp as stored in the database.
    pub updated_at: DateTime<Utc>,
    /// Creation timestamp; `None` when the row has no `created_at` value.
    pub created_at: Option<DateTime<Utc>>,
    /// The workspace folder paths this thread was created against, sorted
    /// lexicographically. Used for grouping threads by project in the sidebar.
    pub folder_paths: PathList,
}
 40
 41impl From<&DbThreadMetadata> for acp_thread::AgentSessionInfo {
 42    fn from(meta: &DbThreadMetadata) -> Self {
 43        Self {
 44            session_id: meta.id.clone(),
 45            cwd: None,
 46            title: Some(meta.title.clone()),
 47            updated_at: Some(meta.updated_at),
 48            created_at: meta.created_at,
 49            meta: None,
 50        }
 51    }
 52}
 53
/// Full persisted state of an agent thread.
///
/// `ThreadsDatabase` stores this as JSON (with an extra `version` field, see
/// [`DbThread::VERSION`]) in the `data` column. Every `#[serde(default)]` field
/// may be absent from payloads written by older versions and then takes its
/// default value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also written to the `summary` column when saved.
    pub title: SharedString,
    /// The thread's messages.
    pub messages: Vec<DbMessage>,
    /// Last-modified timestamp; saved as RFC 3339 text in the `updated_at`
    /// column and used as `created_at` on the first insert.
    pub updated_at: DateTime<Utc>,
    /// Detailed summary text, if one was generated.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Project snapshot associated with the thread, if recorded
    /// (presumably taken when the thread started — confirm).
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage accumulated across the thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by the id of the user message the request belongs to.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Selected language model, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Selected agent profile, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// `true` for threads created from a [`SharedThread`] (see `SharedThread::to_db_thread`).
    #[serde(default)]
    pub imported: bool,
    /// Set for subagent threads; its `parent_thread_id` is also persisted in the
    /// `parent_id` column.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
    /// Model speed setting, if one was chosen.
    #[serde(default)]
    pub speed: Option<Speed>,
    /// Whether thinking was enabled for this thread.
    #[serde(default)]
    pub thinking_enabled: bool,
    /// Thinking effort setting, stored as a free-form string.
    #[serde(default)]
    pub thinking_effort: Option<String>,
    /// Unsent prompt content (presumably the composer draft — confirm).
    #[serde(default)]
    pub draft_prompt: Option<Vec<acp::ContentBlock>>,
    /// Saved scroll position of the thread's UI, if any.
    #[serde(default)]
    pub ui_scroll_position: Option<SerializedScrollPosition>,
}
 86
/// Persisted scroll position of a thread view (see `DbThread::ui_scroll_position`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SerializedScrollPosition {
    /// Index of the list item the position is anchored to
    /// (presumably the topmost visible item — confirm against the UI code).
    pub item_ix: usize,
    /// Offset within that item (units defined by the UI — confirm).
    pub offset_in_item: f32,
}
 92
/// Reduced, shareable form of a thread: only title, messages, timestamp and
/// model. Serialized as zstd-compressed JSON by `SharedThread::to_bytes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    /// Title of the original thread.
    pub title: SharedString,
    /// Messages of the original thread.
    pub messages: Vec<DbMessage>,
    /// `updated_at` of the original thread.
    pub updated_at: DateTime<Utc>,
    /// Selected model of the original thread; defaults to `None` when absent.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Format version; `from_db_thread` sets it to `SharedThread::VERSION`.
    pub version: String,
}
102
103impl SharedThread {
104    pub const VERSION: &'static str = "1.0.0";
105
106    pub fn from_db_thread(thread: &DbThread) -> Self {
107        Self {
108            title: thread.title.clone(),
109            messages: thread.messages.clone(),
110            updated_at: thread.updated_at,
111            model: thread.model.clone(),
112            version: Self::VERSION.to_string(),
113        }
114    }
115
116    pub fn to_db_thread(self) -> DbThread {
117        DbThread {
118            title: format!("🔗 {}", self.title).into(),
119            messages: self.messages,
120            updated_at: self.updated_at,
121            detailed_summary: None,
122            initial_project_snapshot: None,
123            cumulative_token_usage: Default::default(),
124            request_token_usage: Default::default(),
125            model: self.model,
126            profile: None,
127            imported: true,
128            subagent_context: None,
129            speed: None,
130            thinking_enabled: false,
131            thinking_effort: None,
132            draft_prompt: None,
133            ui_scroll_position: None,
134        }
135    }
136
137    pub fn to_bytes(&self) -> Result<Vec<u8>> {
138        const COMPRESSION_LEVEL: i32 = 3;
139        let json = serde_json::to_vec(self)?;
140        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
141        Ok(compressed)
142    }
143
144    pub fn from_bytes(data: &[u8]) -> Result<Self> {
145        let decompressed = zstd::decode_all(data)?;
146        Ok(serde_json::from_slice(&decompressed)?)
147    }
148}
149
impl DbThread {
    /// Current payload format version. `ThreadsDatabase::save_thread_sync`
    /// writes it as a top-level `version` field next to the thread's fields.
    pub const VERSION: &'static str = "0.3.0";

    /// Parses a stored thread payload.
    ///
    /// If the JSON has a string `version` equal to [`Self::VERSION`], it is
    /// decoded directly as a `DbThread`. Any other version, or a missing or
    /// non-string `version`, is parsed as a legacy
    /// `crate::legacy_thread::SerializedThread` and converted with
    /// `upgrade_from_agent_1`.
    ///
    /// # Errors
    /// Fails if `json` is not valid JSON, or if it does not match the format
    /// selected by its `version`.
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
                    json,
                )?),
            },
            _ => {
                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
            }
        }
    }

    /// Converts a legacy (agent 1) serialized thread into the current format.
    ///
    /// - User messages get fresh [`UserMessageId`]s. Their text and thinking
    ///   segments become text content, and redacted thinking is dropped.
    ///   `context` is used only when no segment produced content.
    /// - Assistant messages keep text, thinking and redacted-thinking segments,
    ///   plus their tool uses and tool results.
    /// - System messages are dropped.
    /// - Per-message token usage is re-keyed by the id of the most recent
    ///   preceding user message.
    /// - The detailed summary is kept only if it was fully generated. Fields
    ///   that did not exist in the legacy format get defaults, and `imported`
    ///   is `false`.
    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
        let mut messages = Vec::new();
        let mut request_token_usage = HashMap::default();

        // Id of the most recently converted user message; token usage recorded on
        // later assistant messages is attributed to it.
        let mut last_user_message_id = None;
        // `ix` indexes the full legacy message list (system messages included) and
        // is used to look up `thread.request_token_usage` (presumably indexed the
        // same way in the legacy format — confirm).
        for (ix, msg) in thread.messages.into_iter().enumerate() {
            let message = match msg.role {
                language_model::Role::User => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                ..
                            } => {
                                // User messages don't have thinking segments, but handle gracefully
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                ..
                            } => {
                                // User messages don't have redacted thinking, skip.
                            }
                        }
                    }

                    // If no content was added, add context as text if available
                    if content.is_empty() && !msg.context.is_empty() {
                        content.push(UserMessageContent::Text(msg.context));
                    }

                    let id = UserMessageId::new();
                    last_user_message_id = Some(id.clone());

                    crate::Message::User(UserMessage {
                        // MessageId from old format can't be meaningfully converted, so generate a new one
                        id,
                        content,
                    })
                }
                language_model::Role::Assistant => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(AgentMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                signature,
                            } => {
                                content.push(AgentMessageContent::Thinking { text, signature });
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                data,
                            } => {
                                content.push(AgentMessageContent::RedactedThinking(data));
                            }
                        }
                    }

                    // Convert tool uses, remembering each tool's name by id so the
                    // results in this same message can be labelled with it.
                    let mut tool_names_by_id = HashMap::default();
                    for tool_use in msg.tool_uses {
                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
                        content.push(AgentMessageContent::ToolUse(
                            language_model::LanguageModelToolUse {
                                id: tool_use.id,
                                name: tool_use.name.into(),
                                // Falls back to an empty string if the input can't be re-serialized.
                                raw_input: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                                input: tool_use.input,
                                is_input_complete: true,
                                thought_signature: None,
                            },
                        ));
                    }

                    // Convert tool results; a result whose tool use is not in this
                    // message is named "unknown".
                    let mut tool_results = IndexMap::default();
                    for tool_result in msg.tool_results {
                        let name = tool_names_by_id
                            .remove(&tool_result.tool_use_id)
                            .unwrap_or_else(|| SharedString::from("unknown"));
                        tool_results.insert(
                            tool_result.tool_use_id.clone(),
                            language_model::LanguageModelToolResult {
                                tool_use_id: tool_result.tool_use_id,
                                tool_name: name.into(),
                                is_error: tool_result.is_error,
                                content: tool_result.content,
                                output: tool_result.output,
                            },
                        );
                    }

                    // If several assistant messages follow one user message, the
                    // last one's usage overwrites the earlier entries.
                    if let Some(last_user_message_id) = &last_user_message_id
                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
                    {
                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
                    }

                    crate::Message::Agent(AgentMessage {
                        content,
                        tool_results,
                        reasoning_details: None,
                    })
                }
                language_model::Role::System => {
                    // Skip system messages as they're not supported in the new format
                    continue;
                }
            };

            messages.push(message);
        }

        Ok(Self {
            title: thread.summary,
            messages,
            updated_at: thread.updated_at,
            // Only a fully generated summary carries over.
            detailed_summary: match thread.detailed_summary_state {
                crate::legacy_thread::DetailedSummaryState::NotGenerated
                | crate::legacy_thread::DetailedSummaryState::Generating => None,
                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
            },
            initial_project_snapshot: thread.initial_project_snapshot,
            cumulative_token_usage: thread.cumulative_token_usage,
            request_token_usage,
            model: thread.model,
            profile: thread.profile,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        })
    }
}
316
/// Encoding of the `data` column in the `threads` table.
///
/// Stored as the text tag `"json"` or `"zstd"`, both in SQLite (via the
/// `Bind`/`Column` impls) and in serde.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    /// Uncompressed JSON bytes.
    #[serde(rename = "json")]
    Json,
    /// zstd-compressed JSON bytes; the encoding `save_thread_sync` writes.
    #[serde(rename = "zstd")]
    Zstd,
}
324
325impl Bind for DataType {
326    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
327        let value = match self {
328            DataType::Json => "json",
329            DataType::Zstd => "zstd",
330        };
331        value.bind(statement, start_index)
332    }
333}
334
335impl Column for DataType {
336    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
337        let (value, next_index) = String::column(statement, start_index)?;
338        let data_type = match value.as_str() {
339            "json" => DataType::Json,
340            "zstd" => DataType::Zstd,
341            _ => anyhow::bail!("Unknown data type: {}", value),
342        };
343        Ok((data_type, next_index))
344    }
345}
346
/// SQLite-backed store for agent threads.
///
/// All operations are spawned on `executor` and run against one connection
/// shared behind a mutex.
pub(crate) struct ThreadsDatabase {
    /// Background executor that database operations are spawned on.
    executor: BackgroundExecutor,
    /// The single shared connection; each operation locks it while it queries.
    connection: Arc<Mutex<Connection>>,
}
351
/// App-global handle to the (possibly still opening) threads database, letting
/// `ThreadsDatabase::connect` open it at most once per app. The error is wrapped
/// in `Arc` because `Shared` hands a clone of the task's output to every awaiter.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
355
356impl ThreadsDatabase {
357    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
358        if cx.has_global::<GlobalThreadsDatabase>() {
359            return cx.global::<GlobalThreadsDatabase>().0.clone();
360        }
361        let executor = cx.background_executor().clone();
362        let task = executor
363            .spawn({
364                let executor = executor.clone();
365                async move {
366                    match ThreadsDatabase::new(executor) {
367                        Ok(db) => Ok(Arc::new(db)),
368                        Err(err) => Err(Arc::new(err)),
369                    }
370                }
371            })
372            .shared();
373
374        cx.set_global(GlobalThreadsDatabase(task.clone()));
375        task
376    }
377
378    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
379        let connection = if *ZED_STATELESS {
380            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
381        } else if cfg!(any(feature = "test-support", test)) {
382            // rust stores the name of the test on the current thread.
383            // We use this to automatically create a database that will
384            // be shared within the test (for the test_retrieve_old_thread)
385            // but not with concurrent tests.
386            let thread = std::thread::current();
387            let test_name = thread.name();
388            Connection::open_memory(Some(&format!(
389                "THREAD_FALLBACK_{}",
390                test_name.unwrap_or_default()
391            )))
392        } else {
393            let threads_dir = paths::data_dir().join("threads");
394            std::fs::create_dir_all(&threads_dir)?;
395            let sqlite_path = threads_dir.join("threads.db");
396            Connection::open_file(&sqlite_path.to_string_lossy())
397        };
398
399        connection.exec(indoc! {"
400            CREATE TABLE IF NOT EXISTS threads (
401                id TEXT PRIMARY KEY,
402                summary TEXT NOT NULL,
403                updated_at TEXT NOT NULL,
404                data_type TEXT NOT NULL,
405                data BLOB NOT NULL
406            )
407        "})?()
408        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
409
410        if let Ok(mut s) = connection.exec(indoc! {"
411            ALTER TABLE threads ADD COLUMN parent_id TEXT
412        "})
413        {
414            s().ok();
415        }
416
417        if let Ok(mut s) = connection.exec(indoc! {"
418            ALTER TABLE threads ADD COLUMN folder_paths TEXT;
419            ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
420        "})
421        {
422            s().ok();
423        }
424
425        if let Ok(mut s) = connection.exec(indoc! {"
426            ALTER TABLE threads ADD COLUMN created_at TEXT;
427        "})
428        {
429            if s().is_ok() {
430                connection.exec(indoc! {"
431                    UPDATE threads SET created_at = updated_at WHERE created_at IS NULL
432                "})?()?;
433            }
434        }
435
436        let db = Self {
437            executor,
438            connection: Arc::new(Mutex::new(connection)),
439        };
440
441        Ok(db)
442    }
443
444    fn save_thread_sync(
445        connection: &Arc<Mutex<Connection>>,
446        id: acp::SessionId,
447        thread: DbThread,
448        folder_paths: &PathList,
449    ) -> Result<()> {
450        const COMPRESSION_LEVEL: i32 = 3;
451
452        #[derive(Serialize)]
453        struct SerializedThread {
454            #[serde(flatten)]
455            thread: DbThread,
456            version: &'static str,
457        }
458
459        let title = thread.title.to_string();
460        let updated_at = thread.updated_at.to_rfc3339();
461        let parent_id = thread
462            .subagent_context
463            .as_ref()
464            .map(|ctx| ctx.parent_thread_id.0.clone());
465        let serialized_folder_paths = folder_paths.serialize();
466        let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
467            if folder_paths.is_empty() {
468                (None, None)
469            } else {
470                (
471                    Some(serialized_folder_paths.paths),
472                    Some(serialized_folder_paths.order),
473                )
474            };
475        let json_data = serde_json::to_string(&SerializedThread {
476            thread,
477            version: DbThread::VERSION,
478        })?;
479
480        let connection = connection.lock();
481
482        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
483        let data_type = DataType::Zstd;
484        let data = compressed;
485
486        // Use the thread's updated_at as created_at for new threads.
487        // This ensures the creation time reflects when the thread was conceptually
488        // created, not when it was saved to the database.
489        let created_at = updated_at.clone();
490
491        let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>, String)>(indoc! {"
492            INSERT INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data, created_at)
493            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
494            ON CONFLICT(id) DO UPDATE SET
495                parent_id = excluded.parent_id,
496                folder_paths = excluded.folder_paths,
497                folder_paths_order = excluded.folder_paths_order,
498                summary = excluded.summary,
499                updated_at = excluded.updated_at,
500                data_type = excluded.data_type,
501                data = excluded.data
502        "})?;
503
504        insert((
505            id.0,
506            parent_id,
507            folder_paths_str,
508            folder_paths_order_str,
509            title,
510            updated_at,
511            data_type,
512            data,
513            created_at,
514        ))?;
515
516        Ok(())
517    }
518
519    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
520        let connection = self.connection.clone();
521
522        self.executor.spawn(async move {
523            let connection = connection.lock();
524
525            let mut select = connection
526                .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, Option<String>)>(indoc! {"
527                SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at FROM threads ORDER BY updated_at DESC, created_at DESC
528            "})?;
529
530            let rows = select(())?;
531            let mut threads = Vec::new();
532
533            for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at) in rows {
534                let folder_paths = folder_paths
535                    .map(|paths| {
536                        PathList::deserialize(&util::path_list::SerializedPathList {
537                            paths,
538                            order: folder_paths_order.unwrap_or_default(),
539                        })
540                    })
541                    .unwrap_or_default();
542                let created_at = created_at
543                    .as_deref()
544                    .map(DateTime::parse_from_rfc3339)
545                    .transpose()?
546                    .map(|dt| dt.with_timezone(&Utc));
547
548                threads.push(DbThreadMetadata {
549                    id: acp::SessionId::new(id),
550                    parent_session_id: parent_id.map(acp::SessionId::new),
551                    title: summary.into(),
552                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
553                    created_at,
554                    folder_paths,
555                });
556            }
557
558            Ok(threads)
559        })
560    }
561
562    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
563        let connection = self.connection.clone();
564
565        self.executor.spawn(async move {
566            let connection = connection.lock();
567            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
568                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
569            "})?;
570
571            let rows = select(id.0)?;
572            if let Some((data_type, data)) = rows.into_iter().next() {
573                let json_data = match data_type {
574                    DataType::Zstd => {
575                        let decompressed = zstd::decode_all(&data[..])?;
576                        String::from_utf8(decompressed)?
577                    }
578                    DataType::Json => String::from_utf8(data)?,
579                };
580                let thread = DbThread::from_json(json_data.as_bytes())?;
581                Ok(Some(thread))
582            } else {
583                Ok(None)
584            }
585        })
586    }
587
588    pub fn save_thread(
589        &self,
590        id: acp::SessionId,
591        thread: DbThread,
592        folder_paths: PathList,
593    ) -> Task<Result<()>> {
594        let connection = self.connection.clone();
595
596        self.executor
597            .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
598    }
599
600    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
601        let connection = self.connection.clone();
602
603        self.executor.spawn(async move {
604            let connection = connection.lock();
605
606            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
607                DELETE FROM threads WHERE id = ?
608            "})?;
609
610            delete(id.0)?;
611
612            Ok(())
613        })
614    }
615
616    pub fn delete_threads(&self) -> Task<Result<()>> {
617        let connection = self.connection.clone();
618
619        self.executor.spawn(async move {
620            let connection = connection.lock();
621
622            let mut delete = connection.exec_bound::<()>(indoc! {"
623                DELETE FROM threads
624            "})?;
625
626            delete(())?;
627
628            Ok(())
629        })
630    }
631}
632
633#[cfg(test)]
634mod tests {
635    use super::*;
636    use chrono::{DateTime, TimeZone, Utc};
637    use collections::HashMap;
638    use gpui::TestAppContext;
639    use std::sync::Arc;
640
641    #[test]
642    fn test_shared_thread_roundtrip() {
643        let original = SharedThread {
644            title: "Test Thread".into(),
645            messages: vec![],
646            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
647            model: None,
648            version: SharedThread::VERSION.to_string(),
649        };
650
651        let bytes = original.to_bytes().expect("Failed to serialize");
652        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");
653
654        assert_eq!(restored.title, original.title);
655        assert_eq!(restored.version, original.version);
656        assert_eq!(restored.updated_at, original.updated_at);
657    }
658
659    #[test]
660    fn test_imported_flag_defaults_to_false() {
661        // Simulate deserializing a thread without the imported field (backwards compatibility).
662        let json = r#"{
663            "title": "Old Thread",
664            "messages": [],
665            "updated_at": "2024-01-01T00:00:00Z"
666        }"#;
667
668        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
669
670        assert!(
671            !db_thread.imported,
672            "Legacy threads without imported field should default to false"
673        );
674    }
675
676    fn session_id(value: &str) -> acp::SessionId {
677        acp::SessionId::new(Arc::<str>::from(value))
678    }
679
680    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
681        DbThread {
682            title: title.to_string().into(),
683            messages: Vec::new(),
684            updated_at,
685            detailed_summary: None,
686            initial_project_snapshot: None,
687            cumulative_token_usage: Default::default(),
688            request_token_usage: HashMap::default(),
689            model: None,
690            profile: None,
691            imported: false,
692            subagent_context: None,
693            speed: None,
694            thinking_enabled: false,
695            thinking_effort: None,
696            draft_prompt: None,
697            ui_scroll_position: None,
698        }
699    }
700
701    #[gpui::test]
702    async fn test_list_threads_orders_by_created_at(cx: &mut TestAppContext) {
703        let database = ThreadsDatabase::new(cx.executor()).unwrap();
704
705        let older_id = session_id("thread-a");
706        let newer_id = session_id("thread-b");
707
708        let older_thread = make_thread(
709            "Thread A",
710            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
711        );
712        let newer_thread = make_thread(
713            "Thread B",
714            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
715        );
716
717        database
718            .save_thread(older_id.clone(), older_thread, PathList::default())
719            .await
720            .unwrap();
721        database
722            .save_thread(newer_id.clone(), newer_thread, PathList::default())
723            .await
724            .unwrap();
725
726        let entries = database.list_threads().await.unwrap();
727        assert_eq!(entries.len(), 2);
728        assert_eq!(entries[0].id, newer_id);
729        assert_eq!(entries[1].id, older_id);
730    }
731
732    #[gpui::test]
733    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
734        let database = ThreadsDatabase::new(cx.executor()).unwrap();
735
736        let thread_id = session_id("thread-a");
737        let original_thread = make_thread(
738            "Thread A",
739            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
740        );
741        let updated_thread = make_thread(
742            "Thread B",
743            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
744        );
745
746        database
747            .save_thread(thread_id.clone(), original_thread, PathList::default())
748            .await
749            .unwrap();
750        database
751            .save_thread(thread_id.clone(), updated_thread, PathList::default())
752            .await
753            .unwrap();
754
755        let entries = database.list_threads().await.unwrap();
756        assert_eq!(entries.len(), 1);
757        assert_eq!(entries[0].id, thread_id);
758        assert_eq!(entries[0].title.as_ref(), "Thread B");
759        assert_eq!(
760            entries[0].updated_at,
761            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
762        );
763        assert!(
764            entries[0].created_at.is_some(),
765            "created_at should be populated"
766        );
767    }
768
769    #[test]
770    fn test_subagent_context_defaults_to_none() {
771        let json = r#"{
772            "title": "Old Thread",
773            "messages": [],
774            "updated_at": "2024-01-01T00:00:00Z"
775        }"#;
776
777        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
778
779        assert!(
780            db_thread.subagent_context.is_none(),
781            "Legacy threads without subagent_context should default to None"
782        );
783    }
784
785    #[test]
786    fn test_draft_prompt_defaults_to_none() {
787        let json = r#"{
788            "title": "Old Thread",
789            "messages": [],
790            "updated_at": "2024-01-01T00:00:00Z"
791        }"#;
792
793        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
794
795        assert!(
796            db_thread.draft_prompt.is_none(),
797            "Legacy threads without draft_prompt field should default to None"
798        );
799    }
800
801    #[gpui::test]
802    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
803        let database = ThreadsDatabase::new(cx.executor()).unwrap();
804
805        let parent_id = session_id("parent-thread");
806        let child_id = session_id("child-thread");
807
808        let mut child_thread = make_thread(
809            "Subagent Thread",
810            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
811        );
812        child_thread.subagent_context = Some(crate::SubagentContext {
813            parent_thread_id: parent_id.clone(),
814            depth: 2,
815        });
816
817        database
818            .save_thread(child_id.clone(), child_thread, PathList::default())
819            .await
820            .unwrap();
821
822        let loaded = database
823            .load_thread(child_id)
824            .await
825            .unwrap()
826            .expect("thread should exist");
827
828        let context = loaded
829            .subagent_context
830            .expect("subagent_context should be restored");
831        assert_eq!(context.parent_thread_id, parent_id);
832        assert_eq!(context.depth, 2);
833    }
834
835    #[gpui::test]
836    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
837        let database = ThreadsDatabase::new(cx.executor()).unwrap();
838
839        let thread_id = session_id("regular-thread");
840        let thread = make_thread(
841            "Regular Thread",
842            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
843        );
844
845        database
846            .save_thread(thread_id.clone(), thread, PathList::default())
847            .await
848            .unwrap();
849
850        let loaded = database
851            .load_thread(thread_id)
852            .await
853            .unwrap()
854            .expect("thread should exist");
855
856        assert!(
857            loaded.subagent_context.is_none(),
858            "Regular threads should have no subagent_context"
859        );
860    }
861
862    #[gpui::test]
863    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
864        let database = ThreadsDatabase::new(cx.executor()).unwrap();
865
866        let thread_id = session_id("folder-thread");
867        let thread = make_thread(
868            "Folder Thread",
869            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
870        );
871
872        let folder_paths = PathList::new(&[
873            std::path::PathBuf::from("/home/user/project-a"),
874            std::path::PathBuf::from("/home/user/project-b"),
875        ]);
876
877        database
878            .save_thread(thread_id.clone(), thread, folder_paths.clone())
879            .await
880            .unwrap();
881
882        let threads = database.list_threads().await.unwrap();
883        assert_eq!(threads.len(), 1);
884        assert_eq!(threads[0].folder_paths, folder_paths);
885    }
886
887    #[gpui::test]
888    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
889        let database = ThreadsDatabase::new(cx.executor()).unwrap();
890
891        let thread_id = session_id("no-folder-thread");
892        let thread = make_thread(
893            "No Folder Thread",
894            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
895        );
896
897        database
898            .save_thread(thread_id.clone(), thread, PathList::default())
899            .await
900            .unwrap();
901
902        let threads = database.list_threads().await.unwrap();
903        assert_eq!(threads.len(), 1);
904        assert!(threads[0].folder_paths.is_empty());
905    }
906
907    #[test]
908    fn test_scroll_position_defaults_to_none() {
909        let json = r#"{
910            "title": "Old Thread",
911            "messages": [],
912            "updated_at": "2024-01-01T00:00:00Z"
913        }"#;
914
915        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
916
917        assert!(
918            db_thread.ui_scroll_position.is_none(),
919            "Legacy threads without scroll_position field should default to None"
920        );
921    }
922
923    #[gpui::test]
924    async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
925        let database = ThreadsDatabase::new(cx.executor()).unwrap();
926
927        let thread_id = session_id("thread-with-scroll");
928
929        let mut thread = make_thread(
930            "Thread With Scroll",
931            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
932        );
933        thread.ui_scroll_position = Some(SerializedScrollPosition {
934            item_ix: 42,
935            offset_in_item: 13.5,
936        });
937
938        database
939            .save_thread(thread_id.clone(), thread, PathList::default())
940            .await
941            .unwrap();
942
943        let loaded = database
944            .load_thread(thread_id)
945            .await
946            .unwrap()
947            .expect("thread should exist");
948
949        let scroll = loaded
950            .ui_scroll_position
951            .expect("scroll_position should be restored");
952        assert_eq!(scroll.item_ix, 42);
953        assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
954    }
955}