db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use parking_lot::Mutex;
 12use serde::{Deserialize, Serialize};
 13use sqlez::{
 14    bindable::{Bind, Column},
 15    connection::Connection,
 16    statement::Statement,
 17};
 18use std::sync::Arc;
 19use ui::{App, SharedString};
 20use zed_env_vars::ZED_STATELESS;
 21
 22pub type DbMessage = crate::Message;
 23pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
 24pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 25
 26/// Metadata about the git worktree associated with an agent thread.
 27#[derive(Debug, Clone, Serialize, Deserialize)]
 28pub struct AgentGitWorktreeInfo {
 29    /// The branch name in the git worktree.
 30    pub branch: String,
 31    /// Absolute path to the git worktree on disk.
 32    pub worktree_path: std::path::PathBuf,
 33    /// The base branch/commit the worktree was created from.
 34    pub base_ref: String,
 35}
 36
 37#[derive(Debug, Clone, Serialize, Deserialize)]
 38pub struct DbThreadMetadata {
 39    pub id: acp::SessionId,
 40    pub parent_session_id: Option<acp::SessionId>,
 41    #[serde(alias = "summary")]
 42    pub title: SharedString,
 43    pub updated_at: DateTime<Utc>,
 44    /// Denormalized from `DbThread::git_worktree_info.branch` for efficient
 45    /// listing without decompressing thread data. The blob is the source of
 46    /// truth; this column is populated on save for query convenience.
 47    pub worktree_branch: Option<String>,
 48}
 49
 50#[derive(Debug, Serialize, Deserialize)]
 51pub struct DbThread {
 52    pub title: SharedString,
 53    pub messages: Vec<DbMessage>,
 54    pub updated_at: DateTime<Utc>,
 55    #[serde(default)]
 56    pub detailed_summary: Option<SharedString>,
 57    #[serde(default)]
 58    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
 59    #[serde(default)]
 60    pub cumulative_token_usage: language_model::TokenUsage,
 61    #[serde(default)]
 62    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
 63    #[serde(default)]
 64    pub model: Option<DbLanguageModel>,
 65    #[serde(default)]
 66    pub profile: Option<AgentProfileId>,
 67    #[serde(default)]
 68    pub imported: bool,
 69    #[serde(default)]
 70    pub subagent_context: Option<crate::SubagentContext>,
 71    #[serde(default)]
 72    pub git_worktree_info: Option<AgentGitWorktreeInfo>,
 73}
 74
 75#[derive(Debug, Clone, Serialize, Deserialize)]
 76pub struct SharedThread {
 77    pub title: SharedString,
 78    pub messages: Vec<DbMessage>,
 79    pub updated_at: DateTime<Utc>,
 80    #[serde(default)]
 81    pub model: Option<DbLanguageModel>,
 82    pub version: String,
 83}
 84
 85impl SharedThread {
 86    pub const VERSION: &'static str = "1.0.0";
 87
 88    pub fn from_db_thread(thread: &DbThread) -> Self {
 89        Self {
 90            title: thread.title.clone(),
 91            messages: thread.messages.clone(),
 92            updated_at: thread.updated_at,
 93            model: thread.model.clone(),
 94            version: Self::VERSION.to_string(),
 95        }
 96    }
 97
 98    pub fn to_db_thread(self) -> DbThread {
 99        DbThread {
100            title: format!("🔗 {}", self.title).into(),
101            messages: self.messages,
102            updated_at: self.updated_at,
103            detailed_summary: None,
104            initial_project_snapshot: None,
105            cumulative_token_usage: Default::default(),
106            request_token_usage: Default::default(),
107            model: self.model,
108            profile: None,
109            imported: true,
110            subagent_context: None,
111            git_worktree_info: None,
112        }
113    }
114
115    pub fn to_bytes(&self) -> Result<Vec<u8>> {
116        const COMPRESSION_LEVEL: i32 = 3;
117        let json = serde_json::to_vec(self)?;
118        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
119        Ok(compressed)
120    }
121
122    pub fn from_bytes(data: &[u8]) -> Result<Self> {
123        let decompressed = zstd::decode_all(data)?;
124        Ok(serde_json::from_slice(&decompressed)?)
125    }
126}
127
128impl DbThread {
129    pub const VERSION: &'static str = "0.3.0";
130
131    pub fn from_json(json: &[u8]) -> Result<Self> {
132        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
133        match saved_thread_json.get("version") {
134            Some(serde_json::Value::String(version)) => match version.as_str() {
135                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
136                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
137                    json,
138                )?),
139            },
140            _ => {
141                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
142            }
143        }
144    }
145
146    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
147        let mut messages = Vec::new();
148        let mut request_token_usage = HashMap::default();
149
150        let mut last_user_message_id = None;
151        for (ix, msg) in thread.messages.into_iter().enumerate() {
152            let message = match msg.role {
153                language_model::Role::User => {
154                    let mut content = Vec::new();
155
156                    // Convert segments to content
157                    for segment in msg.segments {
158                        match segment {
159                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
160                                content.push(UserMessageContent::Text(text));
161                            }
162                            crate::legacy_thread::SerializedMessageSegment::Thinking {
163                                text,
164                                ..
165                            } => {
166                                // User messages don't have thinking segments, but handle gracefully
167                                content.push(UserMessageContent::Text(text));
168                            }
169                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
170                                ..
171                            } => {
172                                // User messages don't have redacted thinking, skip.
173                            }
174                        }
175                    }
176
177                    // If no content was added, add context as text if available
178                    if content.is_empty() && !msg.context.is_empty() {
179                        content.push(UserMessageContent::Text(msg.context));
180                    }
181
182                    let id = UserMessageId::new();
183                    last_user_message_id = Some(id.clone());
184
185                    crate::Message::User(UserMessage {
186                        // MessageId from old format can't be meaningfully converted, so generate a new one
187                        id,
188                        content,
189                    })
190                }
191                language_model::Role::Assistant => {
192                    let mut content = Vec::new();
193
194                    // Convert segments to content
195                    for segment in msg.segments {
196                        match segment {
197                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
198                                content.push(AgentMessageContent::Text(text));
199                            }
200                            crate::legacy_thread::SerializedMessageSegment::Thinking {
201                                text,
202                                signature,
203                            } => {
204                                content.push(AgentMessageContent::Thinking { text, signature });
205                            }
206                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
207                                data,
208                            } => {
209                                content.push(AgentMessageContent::RedactedThinking(data));
210                            }
211                        }
212                    }
213
214                    // Convert tool uses
215                    let mut tool_names_by_id = HashMap::default();
216                    for tool_use in msg.tool_uses {
217                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
218                        content.push(AgentMessageContent::ToolUse(
219                            language_model::LanguageModelToolUse {
220                                id: tool_use.id,
221                                name: tool_use.name.into(),
222                                raw_input: serde_json::to_string(&tool_use.input)
223                                    .unwrap_or_default(),
224                                input: tool_use.input,
225                                is_input_complete: true,
226                                thought_signature: None,
227                            },
228                        ));
229                    }
230
231                    // Convert tool results
232                    let mut tool_results = IndexMap::default();
233                    for tool_result in msg.tool_results {
234                        let name = tool_names_by_id
235                            .remove(&tool_result.tool_use_id)
236                            .unwrap_or_else(|| SharedString::from("unknown"));
237                        tool_results.insert(
238                            tool_result.tool_use_id.clone(),
239                            language_model::LanguageModelToolResult {
240                                tool_use_id: tool_result.tool_use_id,
241                                tool_name: name.into(),
242                                is_error: tool_result.is_error,
243                                content: tool_result.content,
244                                output: tool_result.output,
245                            },
246                        );
247                    }
248
249                    if let Some(last_user_message_id) = &last_user_message_id
250                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
251                    {
252                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
253                    }
254
255                    crate::Message::Agent(AgentMessage {
256                        content,
257                        tool_results,
258                        reasoning_details: None,
259                    })
260                }
261                language_model::Role::System => {
262                    // Skip system messages as they're not supported in the new format
263                    continue;
264                }
265            };
266
267            messages.push(message);
268        }
269
270        Ok(Self {
271            title: thread.summary,
272            messages,
273            updated_at: thread.updated_at,
274            detailed_summary: match thread.detailed_summary_state {
275                crate::legacy_thread::DetailedSummaryState::NotGenerated
276                | crate::legacy_thread::DetailedSummaryState::Generating => None,
277                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
278            },
279            initial_project_snapshot: thread.initial_project_snapshot,
280            cumulative_token_usage: thread.cumulative_token_usage,
281            request_token_usage,
282            model: thread.model,
283            profile: thread.profile,
284            imported: false,
285            subagent_context: None,
286            git_worktree_info: None,
287        })
288    }
289}
290
291#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
292pub enum DataType {
293    #[serde(rename = "json")]
294    Json,
295    #[serde(rename = "zstd")]
296    Zstd,
297}
298
299impl Bind for DataType {
300    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
301        let value = match self {
302            DataType::Json => "json",
303            DataType::Zstd => "zstd",
304        };
305        value.bind(statement, start_index)
306    }
307}
308
309impl Column for DataType {
310    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
311        let (value, next_index) = String::column(statement, start_index)?;
312        let data_type = match value.as_str() {
313            "json" => DataType::Json,
314            "zstd" => DataType::Zstd,
315            _ => anyhow::bail!("Unknown data type: {}", value),
316        };
317        Ok((data_type, next_index))
318    }
319}
320
321pub(crate) struct ThreadsDatabase {
322    executor: BackgroundExecutor,
323    connection: Arc<Mutex<Connection>>,
324}
325
326struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
327
328impl Global for GlobalThreadsDatabase {}
329
330impl ThreadsDatabase {
331    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
332        if cx.has_global::<GlobalThreadsDatabase>() {
333            return cx.global::<GlobalThreadsDatabase>().0.clone();
334        }
335        let executor = cx.background_executor().clone();
336        let task = executor
337            .spawn({
338                let executor = executor.clone();
339                async move {
340                    match ThreadsDatabase::new(executor) {
341                        Ok(db) => Ok(Arc::new(db)),
342                        Err(err) => Err(Arc::new(err)),
343                    }
344                }
345            })
346            .shared();
347
348        cx.set_global(GlobalThreadsDatabase(task.clone()));
349        task
350    }
351
352    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
353        let connection = if *ZED_STATELESS {
354            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
355        } else if cfg!(any(feature = "test-support", test)) {
356            // rust stores the name of the test on the current thread.
357            // We use this to automatically create a database that will
358            // be shared within the test (for the test_retrieve_old_thread)
359            // but not with concurrent tests.
360            let thread = std::thread::current();
361            let test_name = thread.name();
362            Connection::open_memory(Some(&format!(
363                "THREAD_FALLBACK_{}",
364                test_name.unwrap_or_default()
365            )))
366        } else {
367            let threads_dir = paths::data_dir().join("threads");
368            std::fs::create_dir_all(&threads_dir)?;
369            let sqlite_path = threads_dir.join("threads.db");
370            Connection::open_file(&sqlite_path.to_string_lossy())
371        };
372
373        connection.exec(indoc! {"
374            CREATE TABLE IF NOT EXISTS threads (
375                id TEXT PRIMARY KEY,
376                summary TEXT NOT NULL,
377                updated_at TEXT NOT NULL,
378                data_type TEXT NOT NULL,
379                data BLOB NOT NULL
380            )
381        "})?()
382        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
383
384        if let Ok(mut s) = connection.exec(indoc! {"
385            ALTER TABLE threads ADD COLUMN parent_id TEXT
386        "})
387        {
388            s().ok();
389        }
390
391        if let Ok(mut s) = connection.exec(indoc! {"
392            ALTER TABLE threads ADD COLUMN worktree_branch TEXT
393        "})
394        {
395            s().ok();
396        }
397
398        let db = Self {
399            executor,
400            connection: Arc::new(Mutex::new(connection)),
401        };
402
403        Ok(db)
404    }
405
406    fn save_thread_sync(
407        connection: &Arc<Mutex<Connection>>,
408        id: acp::SessionId,
409        thread: DbThread,
410    ) -> Result<()> {
411        const COMPRESSION_LEVEL: i32 = 3;
412
413        #[derive(Serialize)]
414        struct SerializedThread {
415            #[serde(flatten)]
416            thread: DbThread,
417            version: &'static str,
418        }
419
420        let title = thread.title.to_string();
421        let updated_at = thread.updated_at.to_rfc3339();
422        let parent_id = thread
423            .subagent_context
424            .as_ref()
425            .map(|ctx| ctx.parent_thread_id.0.clone());
426        let worktree_branch = thread
427            .git_worktree_info
428            .as_ref()
429            .map(|info| info.branch.clone());
430        let json_data = serde_json::to_string(&SerializedThread {
431            thread,
432            version: DbThread::VERSION,
433        })?;
434
435        let connection = connection.lock();
436
437        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
438        let data_type = DataType::Zstd;
439        let data = compressed;
440
441        let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, String, String, DataType, Vec<u8>)>(indoc! {"
442            INSERT OR REPLACE INTO threads (id, parent_id, worktree_branch, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?, ?)
443        "})?;
444
445        insert((
446            id.0,
447            parent_id,
448            worktree_branch,
449            title,
450            updated_at,
451            data_type,
452            data,
453        ))?;
454
455        Ok(())
456    }
457
458    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
459        let connection = self.connection.clone();
460
461        self.executor.spawn(async move {
462            let connection = connection.lock();
463
464            let mut select = connection
465                .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, String, String)>(indoc! {"
466                SELECT id, parent_id, worktree_branch, summary, updated_at FROM threads ORDER BY updated_at DESC
467            "})?;
468
469            let rows = select(())?;
470            let mut threads = Vec::new();
471
472            for (id, parent_id, worktree_branch, summary, updated_at) in rows {
473                threads.push(DbThreadMetadata {
474                    id: acp::SessionId::new(id),
475                    parent_session_id: parent_id.map(acp::SessionId::new),
476                    title: summary.into(),
477                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
478                    worktree_branch,
479                });
480            }
481
482            Ok(threads)
483        })
484    }
485
486    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
487        let connection = self.connection.clone();
488
489        self.executor.spawn(async move {
490            let connection = connection.lock();
491            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
492                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
493            "})?;
494
495            let rows = select(id.0)?;
496            if let Some((data_type, data)) = rows.into_iter().next() {
497                let json_data = match data_type {
498                    DataType::Zstd => {
499                        let decompressed = zstd::decode_all(&data[..])?;
500                        String::from_utf8(decompressed)?
501                    }
502                    DataType::Json => String::from_utf8(data)?,
503                };
504                let thread = DbThread::from_json(json_data.as_bytes())?;
505                Ok(Some(thread))
506            } else {
507                Ok(None)
508            }
509        })
510    }
511
512    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
513        let connection = self.connection.clone();
514
515        self.executor
516            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
517    }
518
519    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
520        let connection = self.connection.clone();
521
522        self.executor.spawn(async move {
523            let connection = connection.lock();
524
525            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
526                DELETE FROM threads WHERE id = ?
527            "})?;
528
529            delete(id.0)?;
530
531            Ok(())
532        })
533    }
534
535    pub fn delete_threads(&self) -> Task<Result<()>> {
536        let connection = self.connection.clone();
537
538        self.executor.spawn(async move {
539            let connection = connection.lock();
540
541            let mut delete = connection.exec_bound::<()>(indoc! {"
542                DELETE FROM threads
543            "})?;
544
545            delete(())?;
546
547            Ok(())
548        })
549    }
550}
551
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_shared_thread_to_db_thread_marks_imported() {
        let shared = SharedThread {
            title: "Shared".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let db_thread = shared.to_db_thread();

        assert_eq!(db_thread.title.as_ref(), "🔗 Shared");
        assert!(db_thread.imported, "imported threads must be flagged");
        assert!(db_thread.subagent_context.is_none());
        assert!(db_thread.git_worktree_info.is_none());
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    #[test]
    fn test_from_json_reads_current_version() {
        let json = format!(
            r#"{{"title":"Current","messages":[],"updated_at":"2024-01-01T00:00:00Z","version":"{}"}}"#,
            DbThread::VERSION
        );

        let thread = DbThread::from_json(json.as_bytes()).expect("current version should parse");

        assert_eq!(thread.title.as_ref(), "Current");
        assert!(thread.messages.is_empty());
    }

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            git_worktree_info: None,
        }
    }

    fn worktree_info(branch: &str, worktree_path: &str, base_ref: &str) -> AgentGitWorktreeInfo {
        AgentGitWorktreeInfo {
            branch: branch.to_string(),
            worktree_path: std::path::PathBuf::from(worktree_path),
            base_ref: base_ref.to_string(),
        }
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread)
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread)
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_that_thread(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let kept_id = session_id("kept-thread");
        let deleted_id = session_id("deleted-thread");

        database
            .save_thread(
                kept_id.clone(),
                make_thread("Kept", Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap()),
            )
            .await
            .unwrap();
        database
            .save_thread(
                deleted_id.clone(),
                make_thread("Deleted", Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()),
            )
            .await
            .unwrap();

        database.delete_thread(deleted_id.clone()).await.unwrap();

        assert!(
            database.load_thread(deleted_id).await.unwrap().is_none(),
            "deleted thread should no longer load"
        );
        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, kept_id);
    }

    #[gpui::test]
    async fn test_delete_threads_removes_everything(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        for (name, day) in [("thread-1", 1), ("thread-2", 2)] {
            database
                .save_thread(
                    session_id(name),
                    make_thread(name, Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()),
                )
                .await
                .unwrap();
        }

        database.delete_threads().await.unwrap();

        assert!(database.list_threads().await.unwrap().is_empty());
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let mut child_thread = make_thread(
            "Subagent Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 2,
        });

        database
            .save_thread(child_id.clone(), child_thread)
            .await
            .unwrap();

        let loaded = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let context = loaded
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("regular-thread");
        let thread = make_thread(
            "Regular Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread)
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }

    #[gpui::test]
    async fn test_git_worktree_info_roundtrip(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("worktree-thread");
        let mut thread = make_thread(
            "Worktree Thread",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );
        thread.git_worktree_info = Some(worktree_info(
            "zed/agent/a4Xiu",
            "/repo/worktrees/zed/agent/a4Xiu",
            "main",
        ));

        database
            .save_thread(thread_id.clone(), thread)
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let info = loaded
            .git_worktree_info
            .expect("git_worktree_info should be restored");
        assert_eq!(info.branch, "zed/agent/a4Xiu");
        assert_eq!(
            info.worktree_path,
            std::path::PathBuf::from("/repo/worktrees/zed/agent/a4Xiu")
        );
        assert_eq!(info.base_ref, "main");
    }

    #[gpui::test]
    async fn test_session_list_includes_worktree_meta(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        // Save a thread with worktree info
        let worktree_id = session_id("wt-thread");
        let mut worktree_thread = make_thread(
            "With Worktree",
            Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
        );
        worktree_thread.git_worktree_info = Some(worktree_info(
            "zed/agent/bR9kz",
            "/repo/worktrees/zed/agent/bR9kz",
            "develop",
        ));

        database
            .save_thread(worktree_id.clone(), worktree_thread)
            .await
            .unwrap();

        // Save a thread without worktree info
        let plain_id = session_id("plain-thread");
        let plain_thread = make_thread(
            "Without Worktree",
            Utc.with_ymd_and_hms(2024, 6, 15, 11, 0, 0).unwrap(),
        );

        database
            .save_thread(plain_id.clone(), plain_thread)
            .await
            .unwrap();

        // List threads and verify worktree_branch is populated correctly
        let threads = database.list_threads().await.unwrap();
        assert_eq!(threads.len(), 2);

        let wt_entry = threads
            .iter()
            .find(|t| t.id == worktree_id)
            .expect("should find worktree thread");
        assert_eq!(wt_entry.worktree_branch.as_deref(), Some("zed/agent/bR9kz"));

        let plain_entry = threads
            .iter()
            .find(|t| t.id == plain_id)
            .expect("should find plain thread");
        assert!(
            plain_entry.worktree_branch.is_none(),
            "plain thread should have no worktree_branch"
        );
    }
}