db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use language_model::Speed;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21use util::path_list::PathList;
 22use zed_env_vars::ZED_STATELESS;
 23
/// The message type persisted inside a [`DbThread`]; an alias for the crate's
/// `Message` type.
pub type DbMessage = crate::Message;
/// Alias for the legacy thread module's `DetailedSummaryState`.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// The serialized language-model selection stored with a thread; an alias for
/// the legacy thread module's `SerializedLanguageModel`.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 27
/// Lightweight per-thread metadata: the columns `ThreadsDatabase::list_threads`
/// reads without decoding the full thread payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// The thread's session id (the `id` column).
    pub id: acp::SessionId,
    /// Session id of the parent thread for subagent threads (the `parent_id`
    /// column); `None` for top-level threads.
    pub parent_session_id: Option<acp::SessionId>,
    /// Display title (stored in the `summary` column). The serde alias keeps
    /// payloads that used the older `summary` key deserializable.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// The thread's `updated_at` timestamp, as stored in the `updated_at` column.
    pub updated_at: DateTime<Utc>,
    /// Value of the `created_at` column; `None` when that column is NULL.
    pub created_at: Option<DateTime<Utc>>,
    /// The workspace folder paths this thread was created against, sorted
    /// lexicographically. Used for grouping threads by project in the sidebar.
    pub folder_paths: PathList,
}
 40
 41impl From<&DbThreadMetadata> for acp_thread::AgentSessionInfo {
 42    fn from(meta: &DbThreadMetadata) -> Self {
 43        Self {
 44            session_id: meta.id.clone(),
 45            cwd: None,
 46            title: Some(meta.title.clone()),
 47            updated_at: Some(meta.updated_at),
 48            created_at: meta.created_at,
 49            meta: None,
 50        }
 51    }
 52}
 53
/// A fully deserialized agent thread, as persisted in the `threads` table.
///
/// Every field after `updated_at` is `#[serde(default)]`. Payloads written
/// before such a field existed therefore still deserialize, and the field takes
/// its default value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also written to the `summary` column on save.
    pub title: SharedString,
    /// The conversation messages, in order.
    pub messages: Vec<DbMessage>,
    /// Last-modified time; also written to the `updated_at` column on save.
    pub updated_at: DateTime<Utc>,
    /// Generated detailed summary text, if one exists.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Snapshot of the project taken when the thread started, if recorded.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage accumulated over the whole thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by user message id. The legacy upgrade keys each
    /// entry by the user message preceding the request.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// The language model selected for this thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// The agent profile selected for this thread, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// True for threads created from a shared thread (`SharedThread::to_db_thread`).
    #[serde(default)]
    pub imported: bool,
    /// Set for subagent threads; its `parent_thread_id` is also stored in the
    /// `parent_id` column.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
    /// Per-thread model speed preference; semantics are defined by `Speed`
    /// elsewhere.
    #[serde(default)]
    pub speed: Option<Speed>,
    /// Whether model "thinking" was enabled for this thread.
    #[serde(default)]
    pub thinking_enabled: bool,
    /// Thinking-effort setting stored as a free-form string. The valid values
    /// are not visible here; confirm with the consumer.
    #[serde(default)]
    pub thinking_effort: Option<String>,
    /// Unsent prompt content, presumably restored into the message editor.
    /// Confirm against the UI.
    #[serde(default)]
    pub draft_prompt: Option<Vec<acp::ContentBlock>>,
    /// A title to show before `title` is finalized (presumably; confirm
    /// against the title-generation code).
    #[serde(default)]
    pub provisional_title: Option<SharedString>,
    /// Saved scroll position of the thread view, if any.
    #[serde(default)]
    pub ui_scroll_position: Option<SerializedScrollPosition>,
}
 88
/// A persisted scroll position within a thread's list of rendered items.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SerializedScrollPosition {
    /// Index of the item the position is anchored to (presumably the top
    /// visible item; confirm against the list view).
    pub item_ix: usize,
    /// Offset within that item. The units are not visible here (presumably
    /// pixels; confirm).
    pub offset_in_item: f32,
}
 94
/// A portable, self-contained copy of a thread, used for sharing.
///
/// Only the title, messages, timestamp and model are carried. The wire format
/// is zstd-compressed JSON (see `SharedThread::to_bytes` / `from_bytes`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    pub title: SharedString,
    pub messages: Vec<DbMessage>,
    pub updated_at: DateTime<Utc>,
    /// Defaults to `None` when absent from the payload.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Format version of the payload (`SharedThread::VERSION` when produced by
    /// this code). Has no serde default, so it is required when deserializing.
    pub version: String,
}
104
105impl SharedThread {
106    pub const VERSION: &'static str = "1.0.0";
107
108    pub fn from_db_thread(thread: &DbThread) -> Self {
109        Self {
110            title: thread.title.clone(),
111            messages: thread.messages.clone(),
112            updated_at: thread.updated_at,
113            model: thread.model.clone(),
114            version: Self::VERSION.to_string(),
115        }
116    }
117
118    pub fn to_db_thread(self) -> DbThread {
119        DbThread {
120            title: format!("🔗 {}", self.title).into(),
121            messages: self.messages,
122            updated_at: self.updated_at,
123            detailed_summary: None,
124            initial_project_snapshot: None,
125            cumulative_token_usage: Default::default(),
126            request_token_usage: Default::default(),
127            model: self.model,
128            profile: None,
129            imported: true,
130            subagent_context: None,
131            speed: None,
132            thinking_enabled: false,
133            thinking_effort: None,
134            draft_prompt: None,
135            provisional_title: None,
136            ui_scroll_position: None,
137        }
138    }
139
140    pub fn to_bytes(&self) -> Result<Vec<u8>> {
141        const COMPRESSION_LEVEL: i32 = 3;
142        let json = serde_json::to_vec(self)?;
143        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
144        Ok(compressed)
145    }
146
147    pub fn from_bytes(data: &[u8]) -> Result<Self> {
148        let decompressed = zstd::decode_all(data)?;
149        Ok(serde_json::from_slice(&decompressed)?)
150    }
151}
152
impl DbThread {
    /// Version tag written as the top-level `"version"` field of every thread
    /// payload this module saves. [`DbThread::from_json`] deserializes a payload
    /// directly only when its version string equals this value exactly.
    pub const VERSION: &'static str = "0.3.0";

    /// Decodes a thread from its stored JSON bytes, upgrading older formats.
    ///
    /// The bytes are first parsed as a generic JSON value so the top-level
    /// `"version"` field can be inspected:
    /// - If it is a string equal to [`Self::VERSION`], the value is
    ///   deserialized directly into a [`DbThread`].
    /// - If it is any other string, a non-string, or missing, the original bytes
    ///   are re-parsed as a legacy `SerializedThread` and converted with
    ///   `upgrade_from_agent_1`.
    ///
    /// NOTE(review): a version string *newer* than [`Self::VERSION`] also takes
    /// the legacy path. Confirm this is the intended handling of payloads
    /// written by a future build.
    ///
    /// # Errors
    /// Fails if `json` is not valid JSON, or if the chosen deserialization
    /// (current or legacy) fails.
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
                    json,
                )?),
            },
            _ => {
                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
            }
        }
    }

    /// Converts a legacy ("agent 1") `SerializedThread` into a [`DbThread`].
    ///
    /// Each message is converted according to its role:
    /// - **User**: text and thinking segments become `UserMessageContent::Text`,
    ///   and redacted-thinking segments are dropped. If that yields no content,
    ///   the message's non-empty `context` string is used as the text instead.
    ///   A fresh `UserMessageId` is generated.
    /// - **Assistant**: text, thinking (with signature) and redacted-thinking
    ///   segments map to the matching `AgentMessageContent` variants, followed
    ///   by one `ToolUse` per legacy tool use. Legacy tool results are gathered
    ///   into `tool_results`, keyed by tool-use id. Each result's tool name comes
    ///   from this message's matching tool use, or is `"unknown"`.
    /// - **System**: skipped.
    ///
    /// The legacy `request_token_usage` entry at an assistant message's index
    /// is re-keyed to the most recent preceding user message. The index counts
    /// *all* legacy messages, system ones included. If several assistant
    /// messages follow one user message, the last one's usage wins.
    ///
    /// The title, timestamp, project snapshot, cumulative usage, model and
    /// profile are carried over. A generated detailed summary is kept;
    /// not-generated or in-progress ones become `None`. Fields the legacy
    /// format lacks get empty defaults.
    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
        let mut messages = Vec::new();
        let mut request_token_usage = HashMap::default();

        // Most recent user message id seen so far; used to key token usage of
        // the assistant replies that follow it.
        let mut last_user_message_id = None;
        for (ix, msg) in thread.messages.into_iter().enumerate() {
            let message = match msg.role {
                language_model::Role::User => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                ..
                            } => {
                                // User messages don't have thinking segments, but handle gracefully
                                content.push(UserMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                ..
                            } => {
                                // User messages don't have redacted thinking, skip.
                            }
                        }
                    }

                    // If no content was added, add context as text if available
                    // NOTE(review): when segments produced content, `msg.context`
                    // is discarded entirely. Confirm this loss is intended.
                    if content.is_empty() && !msg.context.is_empty() {
                        content.push(UserMessageContent::Text(msg.context));
                    }

                    let id = UserMessageId::new();
                    last_user_message_id = Some(id.clone());

                    crate::Message::User(UserMessage {
                        // MessageId from old format can't be meaningfully converted, so generate a new one
                        id,
                        content,
                    })
                }
                language_model::Role::Assistant => {
                    let mut content = Vec::new();

                    // Convert segments to content
                    for segment in msg.segments {
                        match segment {
                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
                                content.push(AgentMessageContent::Text(text));
                            }
                            crate::legacy_thread::SerializedMessageSegment::Thinking {
                                text,
                                signature,
                            } => {
                                content.push(AgentMessageContent::Thinking { text, signature });
                            }
                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
                                data,
                            } => {
                                content.push(AgentMessageContent::RedactedThinking(data));
                            }
                        }
                    }

                    // Convert tool uses
                    // Remember each tool use's name so the results below can be
                    // labelled; only tool uses from this same message are known.
                    let mut tool_names_by_id = HashMap::default();
                    for tool_use in msg.tool_uses {
                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
                        content.push(AgentMessageContent::ToolUse(
                            language_model::LanguageModelToolUse {
                                id: tool_use.id,
                                name: tool_use.name.into(),
                                // Rebuild the raw JSON text from the parsed input;
                                // falls back to an empty string if that fails.
                                raw_input: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                                input: tool_use.input,
                                is_input_complete: true,
                                thought_signature: None,
                            },
                        ));
                    }

                    // Convert tool results
                    // `IndexMap` keeps results in their legacy order.
                    let mut tool_results = IndexMap::default();
                    for tool_result in msg.tool_results {
                        let name = tool_names_by_id
                            .remove(&tool_result.tool_use_id)
                            .unwrap_or_else(|| SharedString::from("unknown"));
                        tool_results.insert(
                            tool_result.tool_use_id.clone(),
                            language_model::LanguageModelToolResult {
                                tool_use_id: tool_result.tool_use_id,
                                tool_name: name.into(),
                                is_error: tool_result.is_error,
                                content: tool_result.content,
                                output: tool_result.output,
                            },
                        );
                    }

                    // Legacy usage is indexed by message position; re-key it to
                    // the preceding user message (a later insert overwrites).
                    if let Some(last_user_message_id) = &last_user_message_id
                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
                    {
                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
                    }

                    crate::Message::Agent(AgentMessage {
                        content,
                        tool_results,
                        reasoning_details: None,
                    })
                }
                language_model::Role::System => {
                    // Skip system messages as they're not supported in the new format
                    continue;
                }
            };

            messages.push(message);
        }

        Ok(Self {
            title: thread.summary,
            messages,
            updated_at: thread.updated_at,
            detailed_summary: match thread.detailed_summary_state {
                crate::legacy_thread::DetailedSummaryState::NotGenerated
                | crate::legacy_thread::DetailedSummaryState::Generating => None,
                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
            },
            initial_project_snapshot: thread.initial_project_snapshot,
            cumulative_token_usage: thread.cumulative_token_usage,
            request_token_usage,
            model: thread.model,
            profile: thread.profile,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            provisional_title: None,
            ui_scroll_position: None,
        })
    }
}
320
/// Encoding of the `data` column in the `threads` table.
///
/// In SQLite it is stored as the text tags `"json"` / `"zstd"`. The serde
/// renames below use the same tags.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    /// The payload is plain (uncompressed) JSON bytes.
    #[serde(rename = "json")]
    Json,
    /// The payload is zstd-compressed JSON bytes.
    #[serde(rename = "zstd")]
    Zstd,
}
328
329impl Bind for DataType {
330    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
331        let value = match self {
332            DataType::Json => "json",
333            DataType::Zstd => "zstd",
334        };
335        value.bind(statement, start_index)
336    }
337}
338
339impl Column for DataType {
340    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
341        let (value, next_index) = String::column(statement, start_index)?;
342        let data_type = match value.as_str() {
343            "json" => DataType::Json,
344            "zstd" => DataType::Zstd,
345            _ => anyhow::bail!("Unknown data type: {}", value),
346        };
347        Ok((data_type, next_index))
348    }
349}
350
/// SQLite-backed store for agent threads (the `threads` table).
///
/// All access goes through one connection guarded by a mutex. The public
/// methods spawn their work on `executor` and return a [`Task`].
pub(crate) struct ThreadsDatabase {
    /// Executor on which database work is spawned.
    executor: BackgroundExecutor,
    /// The single shared connection; each operation locks it.
    connection: Arc<Mutex<Connection>>,
}
355
/// GPUI global that caches the shared task opening the [`ThreadsDatabase`].
/// Every `ThreadsDatabase::connect` call after the first reuses the same open
/// attempt, including a failed one.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
359
360impl ThreadsDatabase {
361    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
362        if cx.has_global::<GlobalThreadsDatabase>() {
363            return cx.global::<GlobalThreadsDatabase>().0.clone();
364        }
365        let executor = cx.background_executor().clone();
366        let task = executor
367            .spawn({
368                let executor = executor.clone();
369                async move {
370                    match ThreadsDatabase::new(executor) {
371                        Ok(db) => Ok(Arc::new(db)),
372                        Err(err) => Err(Arc::new(err)),
373                    }
374                }
375            })
376            .shared();
377
378        cx.set_global(GlobalThreadsDatabase(task.clone()));
379        task
380    }
381
382    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
383        let connection = if *ZED_STATELESS {
384            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
385        } else if cfg!(any(feature = "test-support", test)) {
386            // rust stores the name of the test on the current thread.
387            // We use this to automatically create a database that will
388            // be shared within the test (for the test_retrieve_old_thread)
389            // but not with concurrent tests.
390            let thread = std::thread::current();
391            let test_name = thread.name();
392            Connection::open_memory(Some(&format!(
393                "THREAD_FALLBACK_{}",
394                test_name.unwrap_or_default()
395            )))
396        } else {
397            let threads_dir = paths::data_dir().join("threads");
398            std::fs::create_dir_all(&threads_dir)?;
399            let sqlite_path = threads_dir.join("threads.db");
400            Connection::open_file(&sqlite_path.to_string_lossy())
401        };
402
403        connection.exec(indoc! {"
404            CREATE TABLE IF NOT EXISTS threads (
405                id TEXT PRIMARY KEY,
406                summary TEXT NOT NULL,
407                updated_at TEXT NOT NULL,
408                data_type TEXT NOT NULL,
409                data BLOB NOT NULL
410            )
411        "})?()
412        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
413
414        if let Ok(mut s) = connection.exec(indoc! {"
415            ALTER TABLE threads ADD COLUMN parent_id TEXT
416        "})
417        {
418            s().ok();
419        }
420
421        if let Ok(mut s) = connection.exec(indoc! {"
422            ALTER TABLE threads ADD COLUMN folder_paths TEXT;
423            ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
424        "})
425        {
426            s().ok();
427        }
428
429        if let Ok(mut s) = connection.exec(indoc! {"
430            ALTER TABLE threads ADD COLUMN created_at TEXT;
431        "})
432        {
433            if s().is_ok() {
434                connection.exec(indoc! {"
435                    UPDATE threads SET created_at = updated_at WHERE created_at IS NULL
436                "})?()?;
437            }
438        }
439
440        let db = Self {
441            executor,
442            connection: Arc::new(Mutex::new(connection)),
443        };
444
445        Ok(db)
446    }
447
448    fn save_thread_sync(
449        connection: &Arc<Mutex<Connection>>,
450        id: acp::SessionId,
451        thread: DbThread,
452        folder_paths: &PathList,
453    ) -> Result<()> {
454        const COMPRESSION_LEVEL: i32 = 3;
455
456        #[derive(Serialize)]
457        struct SerializedThread {
458            #[serde(flatten)]
459            thread: DbThread,
460            version: &'static str,
461        }
462
463        let title = thread.title.to_string();
464        let updated_at = thread.updated_at.to_rfc3339();
465        let parent_id = thread
466            .subagent_context
467            .as_ref()
468            .map(|ctx| ctx.parent_thread_id.0.clone());
469        let serialized_folder_paths = folder_paths.serialize();
470        let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
471            if folder_paths.is_empty() {
472                (None, None)
473            } else {
474                (
475                    Some(serialized_folder_paths.paths),
476                    Some(serialized_folder_paths.order),
477                )
478            };
479        let json_data = serde_json::to_string(&SerializedThread {
480            thread,
481            version: DbThread::VERSION,
482        })?;
483
484        let connection = connection.lock();
485
486        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
487        let data_type = DataType::Zstd;
488        let data = compressed;
489
490        // Use the thread's updated_at as created_at for new threads.
491        // This ensures the creation time reflects when the thread was conceptually
492        // created, not when it was saved to the database.
493        let created_at = updated_at.clone();
494
495        let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>, String)>(indoc! {"
496            INSERT INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data, created_at)
497            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
498            ON CONFLICT(id) DO UPDATE SET
499                parent_id = excluded.parent_id,
500                folder_paths = excluded.folder_paths,
501                folder_paths_order = excluded.folder_paths_order,
502                summary = excluded.summary,
503                updated_at = excluded.updated_at,
504                data_type = excluded.data_type,
505                data = excluded.data
506        "})?;
507
508        insert((
509            id.0,
510            parent_id,
511            folder_paths_str,
512            folder_paths_order_str,
513            title,
514            updated_at,
515            data_type,
516            data,
517            created_at,
518        ))?;
519
520        Ok(())
521    }
522
523    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
524        let connection = self.connection.clone();
525
526        self.executor.spawn(async move {
527            let connection = connection.lock();
528
529            let mut select = connection
530                .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, Option<String>)>(indoc! {"
531                SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at FROM threads ORDER BY updated_at DESC, created_at DESC
532            "})?;
533
534            let rows = select(())?;
535            let mut threads = Vec::new();
536
537            for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at) in rows {
538                let folder_paths = folder_paths
539                    .map(|paths| {
540                        PathList::deserialize(&util::path_list::SerializedPathList {
541                            paths,
542                            order: folder_paths_order.unwrap_or_default(),
543                        })
544                    })
545                    .unwrap_or_default();
546                let created_at = created_at
547                    .as_deref()
548                    .map(DateTime::parse_from_rfc3339)
549                    .transpose()?
550                    .map(|dt| dt.with_timezone(&Utc));
551
552                threads.push(DbThreadMetadata {
553                    id: acp::SessionId::new(id),
554                    parent_session_id: parent_id.map(acp::SessionId::new),
555                    title: summary.into(),
556                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
557                    created_at,
558                    folder_paths,
559                });
560            }
561
562            Ok(threads)
563        })
564    }
565
566    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
567        let connection = self.connection.clone();
568
569        self.executor.spawn(async move {
570            let connection = connection.lock();
571            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
572                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
573            "})?;
574
575            let rows = select(id.0)?;
576            if let Some((data_type, data)) = rows.into_iter().next() {
577                let json_data = match data_type {
578                    DataType::Zstd => {
579                        let decompressed = zstd::decode_all(&data[..])?;
580                        String::from_utf8(decompressed)?
581                    }
582                    DataType::Json => String::from_utf8(data)?,
583                };
584                let thread = DbThread::from_json(json_data.as_bytes())?;
585                Ok(Some(thread))
586            } else {
587                Ok(None)
588            }
589        })
590    }
591
592    pub fn save_thread(
593        &self,
594        id: acp::SessionId,
595        thread: DbThread,
596        folder_paths: PathList,
597    ) -> Task<Result<()>> {
598        let connection = self.connection.clone();
599
600        self.executor
601            .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
602    }
603
604    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
605        let connection = self.connection.clone();
606
607        self.executor.spawn(async move {
608            let connection = connection.lock();
609
610            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
611                DELETE FROM threads WHERE id = ?
612            "})?;
613
614            delete(id.0)?;
615
616            Ok(())
617        })
618    }
619
620    pub fn delete_threads(&self) -> Task<Result<()>> {
621        let connection = self.connection.clone();
622
623        self.executor.spawn(async move {
624            let connection = connection.lock();
625
626            let mut delete = connection.exec_bound::<()>(indoc! {"
627                DELETE FROM threads
628            "})?;
629
630            delete(())?;
631
632            Ok(())
633        })
634    }
635}
636
637#[cfg(test)]
638mod tests {
639    use super::*;
640    use chrono::{DateTime, TimeZone, Utc};
641    use collections::HashMap;
642    use gpui::TestAppContext;
643    use std::sync::Arc;
644
645    #[test]
646    fn test_shared_thread_roundtrip() {
647        let original = SharedThread {
648            title: "Test Thread".into(),
649            messages: vec![],
650            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
651            model: None,
652            version: SharedThread::VERSION.to_string(),
653        };
654
655        let bytes = original.to_bytes().expect("Failed to serialize");
656        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");
657
658        assert_eq!(restored.title, original.title);
659        assert_eq!(restored.version, original.version);
660        assert_eq!(restored.updated_at, original.updated_at);
661    }
662
663    #[test]
664    fn test_imported_flag_defaults_to_false() {
665        // Simulate deserializing a thread without the imported field (backwards compatibility).
666        let json = r#"{
667            "title": "Old Thread",
668            "messages": [],
669            "updated_at": "2024-01-01T00:00:00Z"
670        }"#;
671
672        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
673
674        assert!(
675            !db_thread.imported,
676            "Legacy threads without imported field should default to false"
677        );
678    }
679
680    fn session_id(value: &str) -> acp::SessionId {
681        acp::SessionId::new(Arc::<str>::from(value))
682    }
683
684    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
685        DbThread {
686            title: title.to_string().into(),
687            messages: Vec::new(),
688            updated_at,
689            detailed_summary: None,
690            initial_project_snapshot: None,
691            cumulative_token_usage: Default::default(),
692            request_token_usage: HashMap::default(),
693            model: None,
694            profile: None,
695            imported: false,
696            subagent_context: None,
697            speed: None,
698            thinking_enabled: false,
699            thinking_effort: None,
700            draft_prompt: None,
701            provisional_title: None,
702            ui_scroll_position: None,
703        }
704    }
705
706    #[gpui::test]
707    async fn test_list_threads_orders_by_created_at(cx: &mut TestAppContext) {
708        let database = ThreadsDatabase::new(cx.executor()).unwrap();
709
710        let older_id = session_id("thread-a");
711        let newer_id = session_id("thread-b");
712
713        let older_thread = make_thread(
714            "Thread A",
715            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
716        );
717        let newer_thread = make_thread(
718            "Thread B",
719            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
720        );
721
722        database
723            .save_thread(older_id.clone(), older_thread, PathList::default())
724            .await
725            .unwrap();
726        database
727            .save_thread(newer_id.clone(), newer_thread, PathList::default())
728            .await
729            .unwrap();
730
731        let entries = database.list_threads().await.unwrap();
732        assert_eq!(entries.len(), 2);
733        assert_eq!(entries[0].id, newer_id);
734        assert_eq!(entries[1].id, older_id);
735    }
736
737    #[gpui::test]
738    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
739        let database = ThreadsDatabase::new(cx.executor()).unwrap();
740
741        let thread_id = session_id("thread-a");
742        let original_thread = make_thread(
743            "Thread A",
744            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
745        );
746        let updated_thread = make_thread(
747            "Thread B",
748            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
749        );
750
751        database
752            .save_thread(thread_id.clone(), original_thread, PathList::default())
753            .await
754            .unwrap();
755        database
756            .save_thread(thread_id.clone(), updated_thread, PathList::default())
757            .await
758            .unwrap();
759
760        let entries = database.list_threads().await.unwrap();
761        assert_eq!(entries.len(), 1);
762        assert_eq!(entries[0].id, thread_id);
763        assert_eq!(entries[0].title.as_ref(), "Thread B");
764        assert_eq!(
765            entries[0].updated_at,
766            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
767        );
768        assert!(
769            entries[0].created_at.is_some(),
770            "created_at should be populated"
771        );
772    }
773
774    #[test]
775    fn test_subagent_context_defaults_to_none() {
776        let json = r#"{
777            "title": "Old Thread",
778            "messages": [],
779            "updated_at": "2024-01-01T00:00:00Z"
780        }"#;
781
782        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
783
784        assert!(
785            db_thread.subagent_context.is_none(),
786            "Legacy threads without subagent_context should default to None"
787        );
788    }
789
790    #[test]
791    fn test_draft_prompt_defaults_to_none() {
792        let json = r#"{
793            "title": "Old Thread",
794            "messages": [],
795            "updated_at": "2024-01-01T00:00:00Z"
796        }"#;
797
798        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
799
800        assert!(
801            db_thread.draft_prompt.is_none(),
802            "Legacy threads without draft_prompt field should default to None"
803        );
804    }
805
806    #[gpui::test]
807    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
808        let database = ThreadsDatabase::new(cx.executor()).unwrap();
809
810        let parent_id = session_id("parent-thread");
811        let child_id = session_id("child-thread");
812
813        let mut child_thread = make_thread(
814            "Subagent Thread",
815            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
816        );
817        child_thread.subagent_context = Some(crate::SubagentContext {
818            parent_thread_id: parent_id.clone(),
819            depth: 2,
820        });
821
822        database
823            .save_thread(child_id.clone(), child_thread, PathList::default())
824            .await
825            .unwrap();
826
827        let loaded = database
828            .load_thread(child_id)
829            .await
830            .unwrap()
831            .expect("thread should exist");
832
833        let context = loaded
834            .subagent_context
835            .expect("subagent_context should be restored");
836        assert_eq!(context.parent_thread_id, parent_id);
837        assert_eq!(context.depth, 2);
838    }
839
840    #[gpui::test]
841    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
842        let database = ThreadsDatabase::new(cx.executor()).unwrap();
843
844        let thread_id = session_id("regular-thread");
845        let thread = make_thread(
846            "Regular Thread",
847            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
848        );
849
850        database
851            .save_thread(thread_id.clone(), thread, PathList::default())
852            .await
853            .unwrap();
854
855        let loaded = database
856            .load_thread(thread_id)
857            .await
858            .unwrap()
859            .expect("thread should exist");
860
861        assert!(
862            loaded.subagent_context.is_none(),
863            "Regular threads should have no subagent_context"
864        );
865    }
866
867    #[gpui::test]
868    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
869        let database = ThreadsDatabase::new(cx.executor()).unwrap();
870
871        let thread_id = session_id("folder-thread");
872        let thread = make_thread(
873            "Folder Thread",
874            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
875        );
876
877        let folder_paths = PathList::new(&[
878            std::path::PathBuf::from("/home/user/project-a"),
879            std::path::PathBuf::from("/home/user/project-b"),
880        ]);
881
882        database
883            .save_thread(thread_id.clone(), thread, folder_paths.clone())
884            .await
885            .unwrap();
886
887        let threads = database.list_threads().await.unwrap();
888        assert_eq!(threads.len(), 1);
889        assert_eq!(threads[0].folder_paths, folder_paths);
890    }
891
892    #[gpui::test]
893    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
894        let database = ThreadsDatabase::new(cx.executor()).unwrap();
895
896        let thread_id = session_id("no-folder-thread");
897        let thread = make_thread(
898            "No Folder Thread",
899            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
900        );
901
902        database
903            .save_thread(thread_id.clone(), thread, PathList::default())
904            .await
905            .unwrap();
906
907        let threads = database.list_threads().await.unwrap();
908        assert_eq!(threads.len(), 1);
909        assert!(threads[0].folder_paths.is_empty());
910    }
911
912    #[test]
913    fn test_scroll_position_defaults_to_none() {
914        let json = r#"{
915            "title": "Old Thread",
916            "messages": [],
917            "updated_at": "2024-01-01T00:00:00Z"
918        }"#;
919
920        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
921
922        assert!(
923            db_thread.ui_scroll_position.is_none(),
924            "Legacy threads without scroll_position field should default to None"
925        );
926    }
927
928    #[gpui::test]
929    async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
930        let database = ThreadsDatabase::new(cx.executor()).unwrap();
931
932        let thread_id = session_id("thread-with-scroll");
933
934        let mut thread = make_thread(
935            "Thread With Scroll",
936            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
937        );
938        thread.ui_scroll_position = Some(SerializedScrollPosition {
939            item_ix: 42,
940            offset_in_item: 13.5,
941        });
942
943        database
944            .save_thread(thread_id.clone(), thread, PathList::default())
945            .await
946            .unwrap();
947
948        let loaded = database
949            .load_thread(thread_id)
950            .await
951            .unwrap()
952            .expect("thread should exist");
953
954        let scroll = loaded
955            .ui_scroll_position
956            .expect("scroll_position should be restored");
957        assert_eq!(scroll.item_ix, 42);
958        assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
959    }
960}