db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use language_model::Speed;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21use util::path_list::PathList;
 22use zed_env_vars::ZED_STATELESS;
 23
 24pub type DbMessage = crate::Message;
 25pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
 26pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 27
 28#[derive(Debug, Clone, Serialize, Deserialize)]
 29pub struct DbThreadMetadata {
 30    pub id: acp::SessionId,
 31    pub parent_session_id: Option<acp::SessionId>,
 32    #[serde(alias = "summary")]
 33    pub title: SharedString,
 34    pub updated_at: DateTime<Utc>,
 35    pub created_at: Option<DateTime<Utc>>,
 36    /// The workspace folder paths this thread was created against, sorted
 37    /// lexicographically. Used for grouping threads by project in the sidebar.
 38    pub folder_paths: PathList,
 39}
 40
 41impl From<&DbThreadMetadata> for acp_thread::AgentSessionInfo {
 42    fn from(meta: &DbThreadMetadata) -> Self {
 43        Self {
 44            session_id: meta.id.clone(),
 45            cwd: None,
 46            title: Some(meta.title.clone()),
 47            updated_at: Some(meta.updated_at),
 48            meta: None,
 49        }
 50    }
 51}
 52
 53#[derive(Debug, Serialize, Deserialize)]
 54pub struct DbThread {
 55    pub title: SharedString,
 56    pub messages: Vec<DbMessage>,
 57    pub updated_at: DateTime<Utc>,
 58    #[serde(default)]
 59    pub detailed_summary: Option<SharedString>,
 60    #[serde(default)]
 61    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
 62    #[serde(default)]
 63    pub cumulative_token_usage: language_model::TokenUsage,
 64    #[serde(default)]
 65    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
 66    #[serde(default)]
 67    pub model: Option<DbLanguageModel>,
 68    #[serde(default)]
 69    pub profile: Option<AgentProfileId>,
 70    #[serde(default)]
 71    pub imported: bool,
 72    #[serde(default)]
 73    pub subagent_context: Option<crate::SubagentContext>,
 74    #[serde(default)]
 75    pub speed: Option<Speed>,
 76    #[serde(default)]
 77    pub thinking_enabled: bool,
 78    #[serde(default)]
 79    pub thinking_effort: Option<String>,
 80    #[serde(default)]
 81    pub draft_prompt: Option<Vec<acp::ContentBlock>>,
 82    #[serde(default)]
 83    pub ui_scroll_position: Option<SerializedScrollPosition>,
 84}
 85
 86#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
 87pub struct SerializedScrollPosition {
 88    pub item_ix: usize,
 89    pub offset_in_item: f32,
 90}
 91
 92#[derive(Debug, Clone, Serialize, Deserialize)]
 93pub struct SharedThread {
 94    pub title: SharedString,
 95    pub messages: Vec<DbMessage>,
 96    pub updated_at: DateTime<Utc>,
 97    #[serde(default)]
 98    pub model: Option<DbLanguageModel>,
 99    pub version: String,
100}
101
102impl SharedThread {
103    pub const VERSION: &'static str = "1.0.0";
104
105    pub fn from_db_thread(thread: &DbThread) -> Self {
106        Self {
107            title: thread.title.clone(),
108            messages: thread.messages.clone(),
109            updated_at: thread.updated_at,
110            model: thread.model.clone(),
111            version: Self::VERSION.to_string(),
112        }
113    }
114
115    pub fn to_db_thread(self) -> DbThread {
116        DbThread {
117            title: format!("🔗 {}", self.title).into(),
118            messages: self.messages,
119            updated_at: self.updated_at,
120            detailed_summary: None,
121            initial_project_snapshot: None,
122            cumulative_token_usage: Default::default(),
123            request_token_usage: Default::default(),
124            model: self.model,
125            profile: None,
126            imported: true,
127            subagent_context: None,
128            speed: None,
129            thinking_enabled: false,
130            thinking_effort: None,
131            draft_prompt: None,
132            ui_scroll_position: None,
133        }
134    }
135
136    pub fn to_bytes(&self) -> Result<Vec<u8>> {
137        const COMPRESSION_LEVEL: i32 = 3;
138        let json = serde_json::to_vec(self)?;
139        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
140        Ok(compressed)
141    }
142
143    pub fn from_bytes(data: &[u8]) -> Result<Self> {
144        let decompressed = zstd::decode_all(data)?;
145        Ok(serde_json::from_slice(&decompressed)?)
146    }
147}
148
149impl DbThread {
150    pub const VERSION: &'static str = "0.3.0";
151
152    pub fn from_json(json: &[u8]) -> Result<Self> {
153        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
154        match saved_thread_json.get("version") {
155            Some(serde_json::Value::String(version)) => match version.as_str() {
156                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
157                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
158                    json,
159                )?),
160            },
161            _ => {
162                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
163            }
164        }
165    }
166
167    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
168        let mut messages = Vec::new();
169        let mut request_token_usage = HashMap::default();
170
171        let mut last_user_message_id = None;
172        for (ix, msg) in thread.messages.into_iter().enumerate() {
173            let message = match msg.role {
174                language_model::Role::User => {
175                    let mut content = Vec::new();
176
177                    // Convert segments to content
178                    for segment in msg.segments {
179                        match segment {
180                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
181                                content.push(UserMessageContent::Text(text));
182                            }
183                            crate::legacy_thread::SerializedMessageSegment::Thinking {
184                                text,
185                                ..
186                            } => {
187                                // User messages don't have thinking segments, but handle gracefully
188                                content.push(UserMessageContent::Text(text));
189                            }
190                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
191                                ..
192                            } => {
193                                // User messages don't have redacted thinking, skip.
194                            }
195                        }
196                    }
197
198                    // If no content was added, add context as text if available
199                    if content.is_empty() && !msg.context.is_empty() {
200                        content.push(UserMessageContent::Text(msg.context));
201                    }
202
203                    let id = UserMessageId::new();
204                    last_user_message_id = Some(id.clone());
205
206                    crate::Message::User(UserMessage {
207                        // MessageId from old format can't be meaningfully converted, so generate a new one
208                        id,
209                        content,
210                    })
211                }
212                language_model::Role::Assistant => {
213                    let mut content = Vec::new();
214
215                    // Convert segments to content
216                    for segment in msg.segments {
217                        match segment {
218                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
219                                content.push(AgentMessageContent::Text(text));
220                            }
221                            crate::legacy_thread::SerializedMessageSegment::Thinking {
222                                text,
223                                signature,
224                            } => {
225                                content.push(AgentMessageContent::Thinking { text, signature });
226                            }
227                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
228                                data,
229                            } => {
230                                content.push(AgentMessageContent::RedactedThinking(data));
231                            }
232                        }
233                    }
234
235                    // Convert tool uses
236                    let mut tool_names_by_id = HashMap::default();
237                    for tool_use in msg.tool_uses {
238                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
239                        content.push(AgentMessageContent::ToolUse(
240                            language_model::LanguageModelToolUse {
241                                id: tool_use.id,
242                                name: tool_use.name.into(),
243                                raw_input: serde_json::to_string(&tool_use.input)
244                                    .unwrap_or_default(),
245                                input: tool_use.input,
246                                is_input_complete: true,
247                                thought_signature: None,
248                            },
249                        ));
250                    }
251
252                    // Convert tool results
253                    let mut tool_results = IndexMap::default();
254                    for tool_result in msg.tool_results {
255                        let name = tool_names_by_id
256                            .remove(&tool_result.tool_use_id)
257                            .unwrap_or_else(|| SharedString::from("unknown"));
258                        tool_results.insert(
259                            tool_result.tool_use_id.clone(),
260                            language_model::LanguageModelToolResult {
261                                tool_use_id: tool_result.tool_use_id,
262                                tool_name: name.into(),
263                                is_error: tool_result.is_error,
264                                content: tool_result.content,
265                                output: tool_result.output,
266                            },
267                        );
268                    }
269
270                    if let Some(last_user_message_id) = &last_user_message_id
271                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
272                    {
273                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
274                    }
275
276                    crate::Message::Agent(AgentMessage {
277                        content,
278                        tool_results,
279                        reasoning_details: None,
280                    })
281                }
282                language_model::Role::System => {
283                    // Skip system messages as they're not supported in the new format
284                    continue;
285                }
286            };
287
288            messages.push(message);
289        }
290
291        Ok(Self {
292            title: thread.summary,
293            messages,
294            updated_at: thread.updated_at,
295            detailed_summary: match thread.detailed_summary_state {
296                crate::legacy_thread::DetailedSummaryState::NotGenerated
297                | crate::legacy_thread::DetailedSummaryState::Generating => None,
298                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
299            },
300            initial_project_snapshot: thread.initial_project_snapshot,
301            cumulative_token_usage: thread.cumulative_token_usage,
302            request_token_usage,
303            model: thread.model,
304            profile: thread.profile,
305            imported: false,
306            subagent_context: None,
307            speed: None,
308            thinking_enabled: false,
309            thinking_effort: None,
310            draft_prompt: None,
311            ui_scroll_position: None,
312        })
313    }
314}
315
316#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
317pub enum DataType {
318    #[serde(rename = "json")]
319    Json,
320    #[serde(rename = "zstd")]
321    Zstd,
322}
323
324impl Bind for DataType {
325    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
326        let value = match self {
327            DataType::Json => "json",
328            DataType::Zstd => "zstd",
329        };
330        value.bind(statement, start_index)
331    }
332}
333
334impl Column for DataType {
335    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
336        let (value, next_index) = String::column(statement, start_index)?;
337        let data_type = match value.as_str() {
338            "json" => DataType::Json,
339            "zstd" => DataType::Zstd,
340            _ => anyhow::bail!("Unknown data type: {}", value),
341        };
342        Ok((data_type, next_index))
343    }
344}
345
346pub(crate) struct ThreadsDatabase {
347    executor: BackgroundExecutor,
348    connection: Arc<Mutex<Connection>>,
349}
350
351struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
352
353impl Global for GlobalThreadsDatabase {}
354
355impl ThreadsDatabase {
356    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
357        if cx.has_global::<GlobalThreadsDatabase>() {
358            return cx.global::<GlobalThreadsDatabase>().0.clone();
359        }
360        let executor = cx.background_executor().clone();
361        let task = executor
362            .spawn({
363                let executor = executor.clone();
364                async move {
365                    match ThreadsDatabase::new(executor) {
366                        Ok(db) => Ok(Arc::new(db)),
367                        Err(err) => Err(Arc::new(err)),
368                    }
369                }
370            })
371            .shared();
372
373        cx.set_global(GlobalThreadsDatabase(task.clone()));
374        task
375    }
376
377    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
378        let connection = if *ZED_STATELESS {
379            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
380        } else if cfg!(any(feature = "test-support", test)) {
381            // rust stores the name of the test on the current thread.
382            // We use this to automatically create a database that will
383            // be shared within the test (for the test_retrieve_old_thread)
384            // but not with concurrent tests.
385            let thread = std::thread::current();
386            let test_name = thread.name();
387            Connection::open_memory(Some(&format!(
388                "THREAD_FALLBACK_{}",
389                test_name.unwrap_or_default()
390            )))
391        } else {
392            let threads_dir = paths::data_dir().join("threads");
393            std::fs::create_dir_all(&threads_dir)?;
394            let sqlite_path = threads_dir.join("threads.db");
395            Connection::open_file(&sqlite_path.to_string_lossy())
396        };
397
398        connection.exec(indoc! {"
399            CREATE TABLE IF NOT EXISTS threads (
400                id TEXT PRIMARY KEY,
401                summary TEXT NOT NULL,
402                updated_at TEXT NOT NULL,
403                data_type TEXT NOT NULL,
404                data BLOB NOT NULL
405            )
406        "})?()
407        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
408
409        if let Ok(mut s) = connection.exec(indoc! {"
410            ALTER TABLE threads ADD COLUMN parent_id TEXT
411        "})
412        {
413            s().ok();
414        }
415
416        if let Ok(mut s) = connection.exec(indoc! {"
417            ALTER TABLE threads ADD COLUMN folder_paths TEXT;
418            ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
419        "})
420        {
421            s().ok();
422        }
423
424        if let Ok(mut s) = connection.exec(indoc! {"
425            ALTER TABLE threads ADD COLUMN created_at TEXT;
426        "})
427        {
428            if s().is_ok() {
429                connection.exec(indoc! {"
430                    UPDATE threads SET created_at = updated_at WHERE created_at IS NULL
431                "})?()?;
432            }
433        }
434
435        let db = Self {
436            executor,
437            connection: Arc::new(Mutex::new(connection)),
438        };
439
440        Ok(db)
441    }
442
443    fn save_thread_sync(
444        connection: &Arc<Mutex<Connection>>,
445        id: acp::SessionId,
446        thread: DbThread,
447        folder_paths: &PathList,
448    ) -> Result<()> {
449        const COMPRESSION_LEVEL: i32 = 3;
450
451        #[derive(Serialize)]
452        struct SerializedThread {
453            #[serde(flatten)]
454            thread: DbThread,
455            version: &'static str,
456        }
457
458        let title = thread.title.to_string();
459        let updated_at = thread.updated_at.to_rfc3339();
460        let parent_id = thread
461            .subagent_context
462            .as_ref()
463            .map(|ctx| ctx.parent_thread_id.0.clone());
464        let serialized_folder_paths = folder_paths.serialize();
465        let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
466            if folder_paths.is_empty() {
467                (None, None)
468            } else {
469                (
470                    Some(serialized_folder_paths.paths),
471                    Some(serialized_folder_paths.order),
472                )
473            };
474        let json_data = serde_json::to_string(&SerializedThread {
475            thread,
476            version: DbThread::VERSION,
477        })?;
478
479        let connection = connection.lock();
480
481        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
482        let data_type = DataType::Zstd;
483        let data = compressed;
484
485        let created_at = Utc::now().to_rfc3339();
486
487        let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>, String)>(indoc! {"
488            INSERT INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data, created_at)
489            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
490            ON CONFLICT(id) DO UPDATE SET
491                parent_id = excluded.parent_id,
492                folder_paths = excluded.folder_paths,
493                folder_paths_order = excluded.folder_paths_order,
494                summary = excluded.summary,
495                updated_at = excluded.updated_at,
496                data_type = excluded.data_type,
497                data = excluded.data
498        "})?;
499
500        insert((
501            id.0,
502            parent_id,
503            folder_paths_str,
504            folder_paths_order_str,
505            title,
506            updated_at,
507            data_type,
508            data,
509            created_at,
510        ))?;
511
512        Ok(())
513    }
514
515    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
516        let connection = self.connection.clone();
517
518        self.executor.spawn(async move {
519            let connection = connection.lock();
520
521            let mut select = connection
522                .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, Option<String>)>(indoc! {"
523                SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at FROM threads ORDER BY updated_at DESC, created_at DESC
524            "})?;
525
526            let rows = select(())?;
527            let mut threads = Vec::new();
528
529            for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at) in rows {
530                let folder_paths = folder_paths
531                    .map(|paths| {
532                        PathList::deserialize(&util::path_list::SerializedPathList {
533                            paths,
534                            order: folder_paths_order.unwrap_or_default(),
535                        })
536                    })
537                    .unwrap_or_default();
538                let created_at = created_at
539                    .as_deref()
540                    .map(DateTime::parse_from_rfc3339)
541                    .transpose()?
542                    .map(|dt| dt.with_timezone(&Utc));
543
544                threads.push(DbThreadMetadata {
545                    id: acp::SessionId::new(id),
546                    parent_session_id: parent_id.map(acp::SessionId::new),
547                    title: summary.into(),
548                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
549                    created_at,
550                    folder_paths,
551                });
552            }
553
554            Ok(threads)
555        })
556    }
557
558    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
559        let connection = self.connection.clone();
560
561        self.executor.spawn(async move {
562            let connection = connection.lock();
563            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
564                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
565            "})?;
566
567            let rows = select(id.0)?;
568            if let Some((data_type, data)) = rows.into_iter().next() {
569                let json_data = match data_type {
570                    DataType::Zstd => {
571                        let decompressed = zstd::decode_all(&data[..])?;
572                        String::from_utf8(decompressed)?
573                    }
574                    DataType::Json => String::from_utf8(data)?,
575                };
576                let thread = DbThread::from_json(json_data.as_bytes())?;
577                Ok(Some(thread))
578            } else {
579                Ok(None)
580            }
581        })
582    }
583
584    pub fn save_thread(
585        &self,
586        id: acp::SessionId,
587        thread: DbThread,
588        folder_paths: PathList,
589    ) -> Task<Result<()>> {
590        let connection = self.connection.clone();
591
592        self.executor
593            .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
594    }
595
596    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
597        let connection = self.connection.clone();
598
599        self.executor.spawn(async move {
600            let connection = connection.lock();
601
602            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
603                DELETE FROM threads WHERE id = ?
604            "})?;
605
606            delete(id.0)?;
607
608            Ok(())
609        })
610    }
611
612    pub fn delete_threads(&self) -> Task<Result<()>> {
613        let connection = self.connection.clone();
614
615        self.executor.spawn(async move {
616            let connection = connection.lock();
617
618            let mut delete = connection.exec_bound::<()>(indoc! {"
619                DELETE FROM threads
620            "})?;
621
622            delete(())?;
623
624            Ok(())
625        })
626    }
627}
628
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Minimal thread JSON with none of the `#[serde(default)]` fields, mimicking
    /// data written by older versions.
    const LEGACY_THREAD_JSON: &str = r#"{
        "title": "Old Thread",
        "messages": [],
        "updated_at": "2024-01-01T00:00:00Z"
    }"#;

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let db_thread: DbThread =
            serde_json::from_str(LEGACY_THREAD_JSON).expect("Failed to deserialize");

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    #[test]
    fn test_from_json_reads_current_version() {
        let json = serde_json::json!({
            "title": "Current Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z",
            "version": DbThread::VERSION
        });
        let bytes = serde_json::to_vec(&json).unwrap();

        let db_thread = DbThread::from_json(&bytes).expect("Failed to deserialize");

        assert_eq!(db_thread.title.as_ref(), "Current Thread");
        assert!(db_thread.messages.is_empty());
        assert!(!db_thread.imported);
    }

    // Both threads get `created_at = now` on insert and differ only in `updated_at`,
    // so this exercises the primary `updated_at DESC` ordering.
    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread, PathList::default())
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread, PathList::default())
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
        assert!(
            entries[0].created_at.is_some(),
            "created_at should be populated"
        );
    }

    #[gpui::test]
    async fn test_save_thread_preserves_created_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");

        database
            .save_thread(
                thread_id.clone(),
                make_thread("Thread A", Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap()),
                PathList::default(),
            )
            .await
            .unwrap();
        let first_created_at = database.list_threads().await.unwrap()[0].created_at;
        assert!(first_created_at.is_some(), "created_at should be populated");

        database
            .save_thread(
                thread_id.clone(),
                make_thread("Thread A", Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()),
                PathList::default(),
            )
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(
            entries[0].created_at, first_created_at,
            "re-saving a thread must not overwrite its created_at"
        );
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        let db_thread: DbThread =
            serde_json::from_str(LEGACY_THREAD_JSON).expect("Failed to deserialize");

        assert!(
            db_thread.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[test]
    fn test_draft_prompt_defaults_to_none() {
        let db_thread: DbThread =
            serde_json::from_str(LEGACY_THREAD_JSON).expect("Failed to deserialize");

        assert!(
            db_thread.draft_prompt.is_none(),
            "Legacy threads without draft_prompt field should default to None"
        );
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let mut child_thread = make_thread(
            "Subagent Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 2,
        });

        database
            .save_thread(child_id.clone(), child_thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let context = loaded
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("regular-thread");
        let thread = make_thread(
            "Regular Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }

    #[gpui::test]
    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("folder-thread");
        let thread = make_thread(
            "Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        let folder_paths = PathList::new(&[
            std::path::PathBuf::from("/home/user/project-a"),
            std::path::PathBuf::from("/home/user/project-b"),
        ]);

        database
            .save_thread(thread_id.clone(), thread, folder_paths.clone())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert_eq!(threads[0].folder_paths, folder_paths);
    }

    #[gpui::test]
    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("no-folder-thread");
        let thread = make_thread(
            "No Folder Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert!(threads[0].folder_paths.is_empty());
    }

    #[test]
    fn test_scroll_position_defaults_to_none() {
        let db_thread: DbThread =
            serde_json::from_str(LEGACY_THREAD_JSON).expect("Failed to deserialize");

        assert!(
            db_thread.ui_scroll_position.is_none(),
            "Legacy threads without scroll_position field should default to None"
        );
    }

    #[gpui::test]
    async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-with-scroll");

        let mut thread = make_thread(
            "Thread With Scroll",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        thread.ui_scroll_position = Some(SerializedScrollPosition {
            item_ix: 42,
            offset_in_item: 13.5,
        });

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let scroll = loaded
            .ui_scroll_position
            .expect("scroll_position should be restored");
        assert_eq!(scroll.item_ix, 42);
        assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
    }
}