db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use language_model::Speed;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21use util::path_list::PathList;
 22use zed_env_vars::ZED_STATELESS;
 23
 24pub type DbMessage = crate::Message;
 25pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
 26pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 27
 28#[derive(Debug, Clone, Serialize, Deserialize)]
 29pub struct DbThreadMetadata {
 30    pub id: acp::SessionId,
 31    pub parent_session_id: Option<acp::SessionId>,
 32    #[serde(alias = "summary")]
 33    pub title: SharedString,
 34    pub updated_at: DateTime<Utc>,
 35    /// The workspace folder paths this thread was created against, sorted
 36    /// lexicographically. Used for grouping threads by project in the sidebar.
 37    pub folder_paths: PathList,
 38}
 39
 40#[derive(Debug, Serialize, Deserialize)]
 41pub struct DbThread {
 42    pub title: SharedString,
 43    pub messages: Vec<DbMessage>,
 44    pub updated_at: DateTime<Utc>,
 45    #[serde(default)]
 46    pub detailed_summary: Option<SharedString>,
 47    #[serde(default)]
 48    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
 49    #[serde(default)]
 50    pub cumulative_token_usage: language_model::TokenUsage,
 51    #[serde(default)]
 52    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
 53    #[serde(default)]
 54    pub model: Option<DbLanguageModel>,
 55    #[serde(default)]
 56    pub profile: Option<AgentProfileId>,
 57    #[serde(default)]
 58    pub imported: bool,
 59    #[serde(default)]
 60    pub subagent_context: Option<crate::SubagentContext>,
 61    #[serde(default)]
 62    pub speed: Option<Speed>,
 63    #[serde(default)]
 64    pub thinking_enabled: bool,
 65    #[serde(default)]
 66    pub thinking_effort: Option<String>,
 67    #[serde(default)]
 68    pub draft_prompt: Option<Vec<acp::ContentBlock>>,
 69    #[serde(default)]
 70    pub ui_scroll_position: Option<SerializedScrollPosition>,
 71}
 72
 73#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
 74pub struct SerializedScrollPosition {
 75    pub item_ix: usize,
 76    pub offset_in_item: f32,
 77}
 78
 79#[derive(Debug, Clone, Serialize, Deserialize)]
 80pub struct SharedThread {
 81    pub title: SharedString,
 82    pub messages: Vec<DbMessage>,
 83    pub updated_at: DateTime<Utc>,
 84    #[serde(default)]
 85    pub model: Option<DbLanguageModel>,
 86    pub version: String,
 87}
 88
 89impl SharedThread {
 90    pub const VERSION: &'static str = "1.0.0";
 91
 92    pub fn from_db_thread(thread: &DbThread) -> Self {
 93        Self {
 94            title: thread.title.clone(),
 95            messages: thread.messages.clone(),
 96            updated_at: thread.updated_at,
 97            model: thread.model.clone(),
 98            version: Self::VERSION.to_string(),
 99        }
100    }
101
102    pub fn to_db_thread(self) -> DbThread {
103        DbThread {
104            title: format!("🔗 {}", self.title).into(),
105            messages: self.messages,
106            updated_at: self.updated_at,
107            detailed_summary: None,
108            initial_project_snapshot: None,
109            cumulative_token_usage: Default::default(),
110            request_token_usage: Default::default(),
111            model: self.model,
112            profile: None,
113            imported: true,
114            subagent_context: None,
115            speed: None,
116            thinking_enabled: false,
117            thinking_effort: None,
118            draft_prompt: None,
119            ui_scroll_position: None,
120        }
121    }
122
123    pub fn to_bytes(&self) -> Result<Vec<u8>> {
124        const COMPRESSION_LEVEL: i32 = 3;
125        let json = serde_json::to_vec(self)?;
126        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
127        Ok(compressed)
128    }
129
130    pub fn from_bytes(data: &[u8]) -> Result<Self> {
131        let decompressed = zstd::decode_all(data)?;
132        Ok(serde_json::from_slice(&decompressed)?)
133    }
134}
135
136impl DbThread {
137    pub const VERSION: &'static str = "0.3.0";
138
139    pub fn from_json(json: &[u8]) -> Result<Self> {
140        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
141        match saved_thread_json.get("version") {
142            Some(serde_json::Value::String(version)) => match version.as_str() {
143                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
144                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
145                    json,
146                )?),
147            },
148            _ => {
149                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
150            }
151        }
152    }
153
154    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
155        let mut messages = Vec::new();
156        let mut request_token_usage = HashMap::default();
157
158        let mut last_user_message_id = None;
159        for (ix, msg) in thread.messages.into_iter().enumerate() {
160            let message = match msg.role {
161                language_model::Role::User => {
162                    let mut content = Vec::new();
163
164                    // Convert segments to content
165                    for segment in msg.segments {
166                        match segment {
167                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
168                                content.push(UserMessageContent::Text(text));
169                            }
170                            crate::legacy_thread::SerializedMessageSegment::Thinking {
171                                text,
172                                ..
173                            } => {
174                                // User messages don't have thinking segments, but handle gracefully
175                                content.push(UserMessageContent::Text(text));
176                            }
177                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
178                                ..
179                            } => {
180                                // User messages don't have redacted thinking, skip.
181                            }
182                        }
183                    }
184
185                    // If no content was added, add context as text if available
186                    if content.is_empty() && !msg.context.is_empty() {
187                        content.push(UserMessageContent::Text(msg.context));
188                    }
189
190                    let id = UserMessageId::new();
191                    last_user_message_id = Some(id.clone());
192
193                    crate::Message::User(UserMessage {
194                        // MessageId from old format can't be meaningfully converted, so generate a new one
195                        id,
196                        content,
197                    })
198                }
199                language_model::Role::Assistant => {
200                    let mut content = Vec::new();
201
202                    // Convert segments to content
203                    for segment in msg.segments {
204                        match segment {
205                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
206                                content.push(AgentMessageContent::Text(text));
207                            }
208                            crate::legacy_thread::SerializedMessageSegment::Thinking {
209                                text,
210                                signature,
211                            } => {
212                                content.push(AgentMessageContent::Thinking { text, signature });
213                            }
214                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
215                                data,
216                            } => {
217                                content.push(AgentMessageContent::RedactedThinking(data));
218                            }
219                        }
220                    }
221
222                    // Convert tool uses
223                    let mut tool_names_by_id = HashMap::default();
224                    for tool_use in msg.tool_uses {
225                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
226                        content.push(AgentMessageContent::ToolUse(
227                            language_model::LanguageModelToolUse {
228                                id: tool_use.id,
229                                name: tool_use.name.into(),
230                                raw_input: serde_json::to_string(&tool_use.input)
231                                    .unwrap_or_default(),
232                                input: tool_use.input,
233                                is_input_complete: true,
234                                thought_signature: None,
235                            },
236                        ));
237                    }
238
239                    // Convert tool results
240                    let mut tool_results = IndexMap::default();
241                    for tool_result in msg.tool_results {
242                        let name = tool_names_by_id
243                            .remove(&tool_result.tool_use_id)
244                            .unwrap_or_else(|| SharedString::from("unknown"));
245                        tool_results.insert(
246                            tool_result.tool_use_id.clone(),
247                            language_model::LanguageModelToolResult {
248                                tool_use_id: tool_result.tool_use_id,
249                                tool_name: name.into(),
250                                is_error: tool_result.is_error,
251                                content: tool_result.content,
252                                output: tool_result.output,
253                            },
254                        );
255                    }
256
257                    if let Some(last_user_message_id) = &last_user_message_id
258                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
259                    {
260                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
261                    }
262
263                    crate::Message::Agent(AgentMessage {
264                        content,
265                        tool_results,
266                        reasoning_details: None,
267                    })
268                }
269                language_model::Role::System => {
270                    // Skip system messages as they're not supported in the new format
271                    continue;
272                }
273            };
274
275            messages.push(message);
276        }
277
278        Ok(Self {
279            title: thread.summary,
280            messages,
281            updated_at: thread.updated_at,
282            detailed_summary: match thread.detailed_summary_state {
283                crate::legacy_thread::DetailedSummaryState::NotGenerated
284                | crate::legacy_thread::DetailedSummaryState::Generating => None,
285                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
286            },
287            initial_project_snapshot: thread.initial_project_snapshot,
288            cumulative_token_usage: thread.cumulative_token_usage,
289            request_token_usage,
290            model: thread.model,
291            profile: thread.profile,
292            imported: false,
293            subagent_context: None,
294            speed: None,
295            thinking_enabled: false,
296            thinking_effort: None,
297            draft_prompt: None,
298            ui_scroll_position: None,
299        })
300    }
301}
302
303#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
304pub enum DataType {
305    #[serde(rename = "json")]
306    Json,
307    #[serde(rename = "zstd")]
308    Zstd,
309}
310
311impl Bind for DataType {
312    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
313        let value = match self {
314            DataType::Json => "json",
315            DataType::Zstd => "zstd",
316        };
317        value.bind(statement, start_index)
318    }
319}
320
321impl Column for DataType {
322    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
323        let (value, next_index) = String::column(statement, start_index)?;
324        let data_type = match value.as_str() {
325            "json" => DataType::Json,
326            "zstd" => DataType::Zstd,
327            _ => anyhow::bail!("Unknown data type: {}", value),
328        };
329        Ok((data_type, next_index))
330    }
331}
332
333pub(crate) struct ThreadsDatabase {
334    executor: BackgroundExecutor,
335    connection: Arc<Mutex<Connection>>,
336}
337
338struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
339
340impl Global for GlobalThreadsDatabase {}
341
342impl ThreadsDatabase {
343    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
344        if cx.has_global::<GlobalThreadsDatabase>() {
345            return cx.global::<GlobalThreadsDatabase>().0.clone();
346        }
347        let executor = cx.background_executor().clone();
348        let task = executor
349            .spawn({
350                let executor = executor.clone();
351                async move {
352                    match ThreadsDatabase::new(executor) {
353                        Ok(db) => Ok(Arc::new(db)),
354                        Err(err) => Err(Arc::new(err)),
355                    }
356                }
357            })
358            .shared();
359
360        cx.set_global(GlobalThreadsDatabase(task.clone()));
361        task
362    }
363
364    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
365        let connection = if *ZED_STATELESS {
366            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
367        } else if cfg!(any(feature = "test-support", test)) {
368            // rust stores the name of the test on the current thread.
369            // We use this to automatically create a database that will
370            // be shared within the test (for the test_retrieve_old_thread)
371            // but not with concurrent tests.
372            let thread = std::thread::current();
373            let test_name = thread.name();
374            Connection::open_memory(Some(&format!(
375                "THREAD_FALLBACK_{}",
376                test_name.unwrap_or_default()
377            )))
378        } else {
379            let threads_dir = paths::data_dir().join("threads");
380            std::fs::create_dir_all(&threads_dir)?;
381            let sqlite_path = threads_dir.join("threads.db");
382            Connection::open_file(&sqlite_path.to_string_lossy())
383        };
384
385        connection.exec("PRAGMA journal_mode=WAL;")?()?;
386        connection.exec("PRAGMA busy_timeout=1000;")?()?;
387
388        connection.exec(indoc! {"
389            CREATE TABLE IF NOT EXISTS threads (
390                id TEXT PRIMARY KEY,
391                summary TEXT NOT NULL,
392                updated_at TEXT NOT NULL,
393                data_type TEXT NOT NULL,
394                data BLOB NOT NULL
395            )
396        "})?()
397        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
398
399        if let Ok(mut s) = connection.exec(indoc! {"
400            ALTER TABLE threads ADD COLUMN parent_id TEXT
401        "})
402        {
403            s().ok();
404        }
405
406        if let Ok(mut s) = connection.exec(indoc! {"
407            ALTER TABLE threads ADD COLUMN folder_paths TEXT;
408            ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
409        "})
410        {
411            s().ok();
412        }
413
414        let db = Self {
415            executor,
416            connection: Arc::new(Mutex::new(connection)),
417        };
418
419        Ok(db)
420    }
421
422    fn save_thread_sync(
423        connection: &Arc<Mutex<Connection>>,
424        id: acp::SessionId,
425        thread: DbThread,
426        folder_paths: &PathList,
427    ) -> Result<()> {
428        const COMPRESSION_LEVEL: i32 = 3;
429
430        #[derive(Serialize)]
431        struct SerializedThread {
432            #[serde(flatten)]
433            thread: DbThread,
434            version: &'static str,
435        }
436
437        let title = thread.title.to_string();
438        let updated_at = thread.updated_at.to_rfc3339();
439        let parent_id = thread
440            .subagent_context
441            .as_ref()
442            .map(|ctx| ctx.parent_thread_id.0.clone());
443        let serialized_folder_paths = folder_paths.serialize();
444        let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
445            if folder_paths.is_empty() {
446                (None, None)
447            } else {
448                (
449                    Some(serialized_folder_paths.paths),
450                    Some(serialized_folder_paths.order),
451                )
452            };
453        let json_data = serde_json::to_string(&SerializedThread {
454            thread,
455            version: DbThread::VERSION,
456        })?;
457
458        let connection = connection.lock();
459
460        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
461        let data_type = DataType::Zstd;
462        let data = compressed;
463
464        let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>)>(indoc! {"
465            INSERT OR REPLACE INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
466        "})?;
467
468        insert((
469            id.0,
470            parent_id,
471            folder_paths_str,
472            folder_paths_order_str,
473            title,
474            updated_at,
475            data_type,
476            data,
477        ))?;
478
479        Ok(())
480    }
481
482    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
483        let connection = self.connection.clone();
484
485        self.executor.spawn(async move {
486            let connection = connection.lock();
487
488            let mut select = connection
489                .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String)>(indoc! {"
490                SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at FROM threads ORDER BY updated_at DESC
491            "})?;
492
493            let rows = select(())?;
494            let mut threads = Vec::new();
495
496            for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at) in rows {
497                let folder_paths = folder_paths
498                    .map(|paths| {
499                        PathList::deserialize(&util::path_list::SerializedPathList {
500                            paths,
501                            order: folder_paths_order.unwrap_or_default(),
502                        })
503                    })
504                    .unwrap_or_default();
505                threads.push(DbThreadMetadata {
506                    id: acp::SessionId::new(id),
507                    parent_session_id: parent_id.map(acp::SessionId::new),
508                    title: summary.into(),
509                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
510                    folder_paths,
511                });
512            }
513
514            Ok(threads)
515        })
516    }
517
518    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
519        let connection = self.connection.clone();
520
521        self.executor.spawn(async move {
522            let connection = connection.lock();
523            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
524                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
525            "})?;
526
527            let rows = select(id.0)?;
528            if let Some((data_type, data)) = rows.into_iter().next() {
529                let json_data = match data_type {
530                    DataType::Zstd => {
531                        let decompressed = zstd::decode_all(&data[..])?;
532                        String::from_utf8(decompressed)?
533                    }
534                    DataType::Json => String::from_utf8(data)?,
535                };
536                let thread = DbThread::from_json(json_data.as_bytes())?;
537                Ok(Some(thread))
538            } else {
539                Ok(None)
540            }
541        })
542    }
543
544    pub fn save_thread(
545        &self,
546        id: acp::SessionId,
547        thread: DbThread,
548        folder_paths: PathList,
549    ) -> Task<Result<()>> {
550        let connection = self.connection.clone();
551
552        self.executor
553            .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
554    }
555
556    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
557        let connection = self.connection.clone();
558
559        self.executor.spawn(async move {
560            let connection = connection.lock();
561
562            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
563                DELETE FROM threads WHERE id = ?
564            "})?;
565
566            delete(id.0)?;
567
568            Ok(())
569        })
570    }
571
572    pub fn delete_threads(&self) -> Task<Result<()>> {
573        let connection = self.connection.clone();
574
575        self.executor.spawn(async move {
576            let connection = connection.lock();
577
578            let mut delete = connection.exec_bound::<()>(indoc! {"
579                DELETE FROM threads
580            "})?;
581
582            delete(())?;
583
584            Ok(())
585        })
586    }
587}
588
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// A thread serialized before any of the optional fields existed.
    const LEGACY_THREAD_JSON: &str = r#"{
        "title": "Old Thread",
        "messages": [],
        "updated_at": "2024-01-01T00:00:00Z"
    }"#;

    /// Midnight-aligned UTC timestamp helper (minutes and seconds are zero).
    fn utc(year: i32, month: u32, day: u32, hour: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(year, month, day, hour, 0, 0).unwrap()
    }

    fn parse_legacy_thread() -> DbThread {
        serde_json::from_str(LEGACY_THREAD_JSON).expect("Failed to deserialize")
    }

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: SharedString::from(title.to_string()),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: Vec::new(),
            updated_at: utc(2024, 1, 1, 0),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let encoded = original.to_bytes().expect("Failed to serialize");
        let decoded = SharedThread::from_bytes(&encoded).expect("Failed to deserialize");

        assert_eq!(decoded.title, original.title);
        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Backwards compatibility: older threads were written without `imported`.
        let legacy = parse_legacy_thread();
        assert!(
            !legacy.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        for (id, title, day) in [
            (older_id.clone(), "Thread A", 1),
            (newer_id.clone(), "Thread B", 2),
        ] {
            database
                .save_thread(id, make_thread(title, utc(2024, 1, day, 0)), PathList::default())
                .await
                .unwrap();
        }

        let listed = database.list_threads().await.unwrap();
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0].id, newer_id);
        assert_eq!(listed[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();
        let thread_id = session_id("thread-a");

        for (title, day) in [("Thread A", 1), ("Thread B", 2)] {
            database
                .save_thread(
                    thread_id.clone(),
                    make_thread(title, utc(2024, 1, day, 0)),
                    PathList::default(),
                )
                .await
                .unwrap();
        }

        let listed = database.list_threads().await.unwrap();
        assert_eq!(listed.len(), 1);
        let entry = &listed[0];
        assert_eq!(entry.id, thread_id);
        assert_eq!(entry.title.as_ref(), "Thread B");
        assert_eq!(entry.updated_at, utc(2024, 1, 2, 0));
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        let legacy = parse_legacy_thread();
        assert!(
            legacy.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[test]
    fn test_draft_prompt_defaults_to_none() {
        let legacy = parse_legacy_thread();
        assert!(
            legacy.draft_prompt.is_none(),
            "Legacy threads without draft_prompt field should default to None"
        );
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let child_thread = DbThread {
            subagent_context: Some(crate::SubagentContext {
                parent_thread_id: parent_id.clone(),
                depth: 2,
            }),
            ..make_thread("Subagent Thread", utc(2024, 1, 1, 0))
        };

        database
            .save_thread(child_id.clone(), child_thread, PathList::default())
            .await
            .unwrap();

        let restored = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let context = restored
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();
        let thread_id = session_id("regular-thread");

        database
            .save_thread(
                thread_id.clone(),
                make_thread("Regular Thread", utc(2024, 1, 1, 0)),
                PathList::default(),
            )
            .await
            .unwrap();

        let restored = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            restored.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }

    #[gpui::test]
    async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();
        let thread_id = session_id("folder-thread");

        let folder_paths = PathList::new(&[
            std::path::PathBuf::from("/home/user/project-a"),
            std::path::PathBuf::from("/home/user/project-b"),
        ]);

        database
            .save_thread(
                thread_id.clone(),
                make_thread("Folder Thread", utc(2024, 6, 15, 12)),
                folder_paths.clone(),
            )
            .await
            .unwrap();

        let listed = database.list_threads().await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].folder_paths, folder_paths);
    }

    #[gpui::test]
    async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();
        let thread_id = session_id("no-folder-thread");

        database
            .save_thread(
                thread_id.clone(),
                make_thread("No Folder Thread", utc(2024, 6, 15, 12)),
                PathList::default(),
            )
            .await
            .unwrap();

        let listed = database.list_threads().await.unwrap();
        assert_eq!(listed.len(), 1);
        assert!(listed[0].folder_paths.is_empty());
    }

    #[test]
    fn test_scroll_position_defaults_to_none() {
        let legacy = parse_legacy_thread();
        assert!(
            legacy.ui_scroll_position.is_none(),
            "Legacy threads without scroll_position field should default to None"
        );
    }

    #[gpui::test]
    async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();
        let thread_id = session_id("thread-with-scroll");

        let thread = DbThread {
            ui_scroll_position: Some(SerializedScrollPosition {
                item_ix: 42,
                offset_in_item: 13.5,
            }),
            ..make_thread("Thread With Scroll", utc(2024, 1, 1, 0))
        };

        database
            .save_thread(thread_id.clone(), thread, PathList::default())
            .await
            .unwrap();

        let restored = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let scroll = restored
            .ui_scroll_position
            .expect("scroll_position should be restored");
        assert_eq!(scroll.item_ix, 42);
        assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
    }
}