thread_metadata_store.rs

  1use std::{path::Path, sync::Arc};
  2
  3use agent::{ThreadStore, ZED_AGENT_ID};
  4use agent_client_protocol as acp;
  5use anyhow::Result;
  6use chrono::{DateTime, Utc};
  7use collections::HashMap;
  8use db::{
  9    sqlez::{
 10        bindable::Column, domain::Domain, statement::Statement,
 11        thread_safe_connection::ThreadSafeConnection,
 12    },
 13    sqlez_macros::sql,
 14};
 15use feature_flags::{AgentV2FeatureFlag, FeatureFlagAppExt};
 16use gpui::{AppContext as _, Entity, Global, Subscription, Task};
 17use project::AgentId;
 18use ui::{App, Context, SharedString};
 19use workspace::PathList;
 20
 21pub fn init(cx: &mut App) {
 22    ThreadMetadataStore::init_global(cx);
 23
 24    if cx.has_flag::<AgentV2FeatureFlag>() {
 25        migrate_thread_metadata(cx);
 26    }
 27    cx.observe_flag::<AgentV2FeatureFlag, _>(|has_flag, cx| {
 28        if has_flag {
 29            migrate_thread_metadata(cx);
 30        }
 31    })
 32    .detach();
 33}
 34
 35/// Migrate existing thread metadata from native agent thread store to the new metadata storage.
 36///
 37/// TODO: Remove this after N weeks of shipping the sidebar
 38fn migrate_thread_metadata(cx: &mut App) {
 39    ThreadMetadataStore::global(cx).update(cx, |store, cx| {
 40        let list = store.list(cx);
 41        cx.spawn(async move |this, cx| {
 42            let Ok(list) = list.await else {
 43                return;
 44            };
 45            if list.is_empty() {
 46                this.update(cx, |this, cx| {
 47                    let metadata = ThreadStore::global(cx)
 48                        .read(cx)
 49                        .entries()
 50                        .map(|entry| ThreadMetadata {
 51                            session_id: entry.id,
 52                            agent_id: None,
 53                            title: entry.title,
 54                            updated_at: entry.updated_at,
 55                            created_at: entry.created_at,
 56                            folder_paths: entry.folder_paths,
 57                        })
 58                        .collect::<Vec<_>>();
 59                    for entry in metadata {
 60                        this.save(entry, cx).detach_and_log_err(cx);
 61                    }
 62                })
 63                .ok();
 64            }
 65        })
 66        .detach();
 67    });
 68}
 69
/// Newtype wrapper under which the shared [`ThreadMetadataStore`] entity is
/// registered as a gpui global (see `ThreadMetadataStore::init_global`,
/// `try_global` and `global`).
struct GlobalThreadMetadataStore(Entity<ThreadMetadataStore>);
impl Global for GlobalThreadMetadataStore {}
 72
/// Lightweight metadata for any thread (native or ACP), enough to populate
/// the sidebar list and route to the correct load path when clicked.
#[derive(Debug, Clone)]
pub struct ThreadMetadata {
    /// Session identifier of the thread; the primary key of the persisted row.
    pub session_id: acp::SessionId,
    /// `None` for native Zed threads, `Some("claude-code")` etc. for ACP agents.
    pub agent_id: Option<AgentId>,
    /// Title of the thread.
    pub title: SharedString,
    /// Time of the last recorded change; persisted rows are listed newest first
    /// by this value.
    pub updated_at: DateTime<Utc>,
    /// Creation time, if known. Once a row exists, later upserts keep the
    /// stored value rather than overwriting it.
    pub created_at: Option<DateTime<Utc>>,
    /// Folder paths associated with the thread (for live ACP threads, the
    /// project's visible worktree roots).
    pub folder_paths: PathList,
}
 85
/// Persists [`ThreadMetadata`] for the sidebar and keeps it in sync with live,
/// top-level [`acp_thread::AcpThread`] entities.
pub struct ThreadMetadataStore {
    /// Backing database for the `sidebar_threads` table.
    db: ThreadMetadataDb,
    /// One event subscription per observed top-level thread, keyed by session
    /// id; an entry is removed when its thread entity is released.
    session_subscriptions: HashMap<acp::SessionId, Subscription>,
}
 90
 91impl ThreadMetadataStore {
 92    #[cfg(not(any(test, feature = "test-support")))]
 93    pub fn init_global(cx: &mut App) {
 94        if cx.has_global::<Self>() {
 95            return;
 96        }
 97
 98        let db = THREAD_METADATA_DB.clone();
 99        let thread_store = cx.new(|cx| Self::new(db, cx));
100        cx.set_global(GlobalThreadMetadataStore(thread_store));
101    }
102
103    #[cfg(any(test, feature = "test-support"))]
104    pub fn init_global(cx: &mut App) {
105        let thread = std::thread::current();
106        let test_name = thread.name().unwrap_or("unknown_test");
107        let db_name = format!("THREAD_METADATA_DB_{}", test_name);
108        let db = smol::block_on(db::open_test_db::<ThreadMetadataDb>(&db_name));
109        let thread_store = cx.new(|cx| Self::new(ThreadMetadataDb(db), cx));
110        cx.set_global(GlobalThreadMetadataStore(thread_store));
111    }
112
113    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
114        cx.try_global::<GlobalThreadMetadataStore>()
115            .map(|store| store.0.clone())
116    }
117
118    pub fn global(cx: &App) -> Entity<Self> {
119        cx.global::<GlobalThreadMetadataStore>().0.clone()
120    }
121
122    pub fn list(&self, cx: &App) -> Task<Result<Vec<ThreadMetadata>>> {
123        let db = self.db.clone();
124        cx.background_spawn(async move {
125            let s = db.list()?;
126            Ok(s)
127        })
128    }
129
130    pub fn save(&mut self, metadata: ThreadMetadata, cx: &mut Context<Self>) -> Task<Result<()>> {
131        if !cx.has_flag::<AgentV2FeatureFlag>() {
132            return Task::ready(Ok(()));
133        }
134
135        let db = self.db.clone();
136        cx.spawn(async move |this, cx| {
137            db.save(metadata).await?;
138            this.update(cx, |_this, cx| cx.notify())
139        })
140    }
141
142    pub fn delete(
143        &mut self,
144        session_id: acp::SessionId,
145        cx: &mut Context<Self>,
146    ) -> Task<Result<()>> {
147        if !cx.has_flag::<AgentV2FeatureFlag>() {
148            return Task::ready(Ok(()));
149        }
150
151        let db = self.db.clone();
152        cx.spawn(async move |this, cx| {
153            db.delete(session_id).await?;
154            this.update(cx, |_this, cx| cx.notify())
155        })
156    }
157
158    fn new(db: ThreadMetadataDb, cx: &mut Context<Self>) -> Self {
159        let weak_store = cx.weak_entity();
160
161        cx.observe_new::<acp_thread::AcpThread>(move |thread, _window, cx| {
162            // Don't track subagent threads in the sidebar.
163            if thread.parent_session_id().is_some() {
164                return;
165            }
166
167            let thread_entity = cx.entity();
168
169            cx.on_release({
170                let weak_store = weak_store.clone();
171                move |thread, cx| {
172                    weak_store
173                        .update(cx, |store, _cx| {
174                            store.session_subscriptions.remove(thread.session_id());
175                        })
176                        .ok();
177                }
178            })
179            .detach();
180
181            weak_store
182                .update(cx, |this, cx| {
183                    let subscription = cx.subscribe(&thread_entity, Self::handle_thread_update);
184                    this.session_subscriptions
185                        .insert(thread.session_id().clone(), subscription);
186                })
187                .ok();
188        })
189        .detach();
190
191        Self {
192            db,
193            session_subscriptions: HashMap::default(),
194        }
195    }
196
197    fn handle_thread_update(
198        &mut self,
199        thread: Entity<acp_thread::AcpThread>,
200        event: &acp_thread::AcpThreadEvent,
201        cx: &mut Context<Self>,
202    ) {
203        // Don't track subagent threads in the sidebar.
204        if thread.read(cx).parent_session_id().is_some() {
205            return;
206        }
207
208        match event {
209            acp_thread::AcpThreadEvent::NewEntry
210            | acp_thread::AcpThreadEvent::EntryUpdated(_)
211            | acp_thread::AcpThreadEvent::TitleUpdated => {
212                let metadata = Self::metadata_for_acp_thread(thread.read(cx), cx);
213                self.save(metadata, cx).detach_and_log_err(cx);
214            }
215            _ => {}
216        }
217    }
218
219    fn metadata_for_acp_thread(thread: &acp_thread::AcpThread, cx: &App) -> ThreadMetadata {
220        let session_id = thread.session_id().clone();
221        let title = thread.title();
222        let updated_at = Utc::now();
223
224        let agent_id = thread.connection().agent_id();
225
226        let agent_id = if agent_id.as_ref() == ZED_AGENT_ID.as_ref() {
227            None
228        } else {
229            Some(agent_id)
230        };
231
232        let folder_paths = {
233            let project = thread.project().read(cx);
234            let paths: Vec<Arc<Path>> = project
235                .visible_worktrees(cx)
236                .map(|worktree| worktree.read(cx).abs_path())
237                .collect();
238            PathList::new(&paths)
239        };
240
241        ThreadMetadata {
242            session_id,
243            agent_id,
244            title,
245            created_at: Some(updated_at), // handled by db `ON CONFLICT`
246            updated_at,
247            folder_paths,
248        }
249    }
250}
251
// NOTE(review): in this file the store is only ever registered as a global
// through the `GlobalThreadMetadataStore` wrapper, never as a bare
// `ThreadMetadataStore` value, so `has_global::<ThreadMetadataStore>()` /
// `global::<ThreadMetadataStore>()` will not find it. Confirm nothing relies
// on this impl before removing it.
impl Global for ThreadMetadataStore {}
253
/// SQLite-backed persistence for sidebar thread metadata, stored in the
/// `sidebar_threads` table.
#[derive(Clone)]
struct ThreadMetadataDb(ThreadSafeConnection);
256
/// Schema definition for the sidebar thread metadata database.
impl Domain for ThreadMetadataDb {
    const NAME: &str = stringify!(ThreadMetadataDb);

    // Timestamps are stored as RFC 3339 text. `folder_paths` holds the
    // serialized paths and `folder_paths_order` their ordering; both are NULL
    // when the thread has no folder paths.
    // NOTE(review): sqlez presumably tracks applied migrations by their text —
    // append new migrations instead of editing this one; confirm.
    const MIGRATIONS: &[&str] = &[sql!(
        CREATE TABLE IF NOT EXISTS sidebar_threads(
            session_id TEXT PRIMARY KEY,
            agent_id TEXT,
            title TEXT NOT NULL,
            updated_at TEXT NOT NULL,
            created_at TEXT,
            folder_paths TEXT,
            folder_paths_order TEXT
        ) STRICT;
    )];
}
272
// Static connection `THREAD_METADATA_DB` used by the non-test
// `ThreadMetadataStore::init_global`. The trailing `[]` presumably lists other
// domains this one depends on (none here) — confirm against the macro.
db::static_connection!(THREAD_METADATA_DB, ThreadMetadataDb, []);
274
275impl ThreadMetadataDb {
276    /// List all sidebar thread metadata, ordered by updated_at descending.
277    pub fn list(&self) -> anyhow::Result<Vec<ThreadMetadata>> {
278        self.select::<ThreadMetadata>(
279            "SELECT session_id, agent_id, title, updated_at, created_at, folder_paths, folder_paths_order \
280             FROM sidebar_threads \
281             ORDER BY updated_at DESC"
282        )?()
283    }
284
285    /// Upsert metadata for a thread.
286    pub async fn save(&self, row: ThreadMetadata) -> anyhow::Result<()> {
287        let id = row.session_id.0.clone();
288        let agent_id = row.agent_id.as_ref().map(|id| id.0.to_string());
289        let title = row.title.to_string();
290        let updated_at = row.updated_at.to_rfc3339();
291        let created_at = row.created_at.map(|dt| dt.to_rfc3339());
292        let serialized = row.folder_paths.serialize();
293        let (folder_paths, folder_paths_order) = if row.folder_paths.is_empty() {
294            (None, None)
295        } else {
296            (Some(serialized.paths), Some(serialized.order))
297        };
298
299        self.write(move |conn| {
300            let sql = "INSERT INTO sidebar_threads(session_id, agent_id, title, updated_at, created_at, folder_paths, folder_paths_order) \
301                       VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) \
302                       ON CONFLICT(session_id) DO UPDATE SET \
303                           agent_id = excluded.agent_id, \
304                           title = excluded.title, \
305                           updated_at = excluded.updated_at, \
306                           folder_paths = excluded.folder_paths, \
307                           folder_paths_order = excluded.folder_paths_order";
308            let mut stmt = Statement::prepare(conn, sql)?;
309            let mut i = stmt.bind(&id, 1)?;
310            i = stmt.bind(&agent_id, i)?;
311            i = stmt.bind(&title, i)?;
312            i = stmt.bind(&updated_at, i)?;
313            i = stmt.bind(&created_at, i)?;
314            i = stmt.bind(&folder_paths, i)?;
315            stmt.bind(&folder_paths_order, i)?;
316            stmt.exec()
317        })
318        .await
319    }
320
321    /// Delete metadata for a single thread.
322    pub async fn delete(&self, session_id: acp::SessionId) -> anyhow::Result<()> {
323        let id = session_id.0.clone();
324        self.write(move |conn| {
325            let mut stmt =
326                Statement::prepare(conn, "DELETE FROM sidebar_threads WHERE session_id = ?")?;
327            stmt.bind(&id, 1)?;
328            stmt.exec()
329        })
330        .await
331    }
332}
333
334impl Column for ThreadMetadata {
335    fn column(statement: &mut Statement, start_index: i32) -> anyhow::Result<(Self, i32)> {
336        let (id, next): (Arc<str>, i32) = Column::column(statement, start_index)?;
337        let (agent_id, next): (Option<String>, i32) = Column::column(statement, next)?;
338        let (title, next): (String, i32) = Column::column(statement, next)?;
339        let (updated_at_str, next): (String, i32) = Column::column(statement, next)?;
340        let (created_at_str, next): (Option<String>, i32) = Column::column(statement, next)?;
341        let (folder_paths_str, next): (Option<String>, i32) = Column::column(statement, next)?;
342        let (folder_paths_order_str, next): (Option<String>, i32) =
343            Column::column(statement, next)?;
344
345        let updated_at = DateTime::parse_from_rfc3339(&updated_at_str)?.with_timezone(&Utc);
346        let created_at = created_at_str
347            .as_deref()
348            .map(DateTime::parse_from_rfc3339)
349            .transpose()?
350            .map(|dt| dt.with_timezone(&Utc));
351
352        let folder_paths = folder_paths_str
353            .map(|paths| {
354                PathList::deserialize(&util::path_list::SerializedPathList {
355                    paths,
356                    order: folder_paths_order_str.unwrap_or_default(),
357                })
358            })
359            .unwrap_or_default();
360
361        Ok((
362            ThreadMetadata {
363                session_id: acp::SessionId::new(id),
364                agent_id: agent_id.map(|id| AgentId::new(id)),
365                title: title.into(),
366                updated_at,
367                created_at,
368                folder_paths,
369            },
370            next,
371        ))
372    }
373}
374
/// Tests for the native-store migration and for keeping subagent threads out
/// of the sidebar metadata.
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, StubAgentConnection};
    use action_log::ActionLog;
    use agent::DbThread;
    use agent_client_protocol as acp;
    use feature_flags::FeatureFlagAppExt;
    use gpui::TestAppContext;
    use project::FakeFs;
    use project::Project;
    use std::path::Path;
    use std::rc::Rc;
    use util::path_list::PathList;

    /// Builds a minimal native `DbThread` with the given title and update time;
    /// every other field is empty or default.
    fn make_db_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: Default::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    /// Native threads are copied into an empty metadata store as native
    /// (`agent_id: None`) entries with their titles intact.
    // NOTE(review): this test never enables `agent-v2`, yet
    // `ThreadMetadataStore::save` returns early when that flag is off; it
    // presumably relies on the flag being on by default in test builds — confirm.
    #[gpui::test]
    async fn test_migrate_thread_metadata(cx: &mut TestAppContext) {
        cx.update(|cx| {
            ThreadStore::init_global(cx);
            ThreadMetadataStore::init_global(cx);
        });

        // Verify the list is empty before migration
        let metadata_list = cx.update(|cx| {
            let store = ThreadMetadataStore::global(cx);
            store.read(cx).list(cx)
        });

        let list = metadata_list.await.unwrap();
        assert_eq!(list.len(), 0);

        let now = Utc::now();

        // Populate the native ThreadStore via save_thread
        let save1 = cx.update(|cx| {
            let thread_store = ThreadStore::global(cx);
            thread_store.update(cx, |store, cx| {
                store.save_thread(
                    acp::SessionId::new("session-1"),
                    make_db_thread("Thread 1", now),
                    PathList::default(),
                    cx,
                )
            })
        });
        save1.await.unwrap();
        cx.run_until_parked();

        let save2 = cx.update(|cx| {
            let thread_store = ThreadStore::global(cx);
            thread_store.update(cx, |store, cx| {
                store.save_thread(
                    acp::SessionId::new("session-2"),
                    make_db_thread("Thread 2", now),
                    PathList::default(),
                    cx,
                )
            })
        });
        save2.await.unwrap();
        cx.run_until_parked();

        // Run migration
        cx.update(|cx| {
            migrate_thread_metadata(cx);
        });

        // Let the spawned migration task and its per-entry saves finish.
        cx.run_until_parked();

        // Verify the metadata was migrated
        let metadata_list = cx.update(|cx| {
            let store = ThreadMetadataStore::global(cx);
            store.read(cx).list(cx)
        });

        let list = metadata_list.await.unwrap();
        assert_eq!(list.len(), 2);

        let metadata1 = list
            .iter()
            .find(|m| m.session_id.0.as_ref() == "session-1")
            .expect("session-1 should be in migrated metadata");
        assert_eq!(metadata1.title.as_ref(), "Thread 1");
        assert!(metadata1.agent_id.is_none());

        let metadata2 = list
            .iter()
            .find(|m| m.session_id.0.as_ref() == "session-2")
            .expect("session-2 should be in migrated metadata");
        assert_eq!(metadata2.title.as_ref(), "Thread 2");
        assert!(metadata2.agent_id.is_none());
    }

    /// Migration leaves the store untouched when it already holds metadata.
    // NOTE(review): like the test above, this relies on `save` not being
    // short-circuited by a disabled `agent-v2` flag — confirm.
    #[gpui::test]
    async fn test_migrate_thread_metadata_skips_when_data_exists(cx: &mut TestAppContext) {
        cx.update(|cx| {
            ThreadStore::init_global(cx);
            ThreadMetadataStore::init_global(cx);
        });

        // Pre-populate the metadata store with existing data
        let existing_metadata = ThreadMetadata {
            session_id: acp::SessionId::new("existing-session"),
            agent_id: None,
            title: "Existing Thread".into(),
            updated_at: Utc::now(),
            created_at: Some(Utc::now()),
            folder_paths: PathList::default(),
        };

        cx.update(|cx| {
            let store = ThreadMetadataStore::global(cx);
            store.update(cx, |store, cx| {
                store.save(existing_metadata, cx).detach();
            });
        });

        cx.run_until_parked();

        // Add an entry to native thread store that should NOT be migrated
        let save_task = cx.update(|cx| {
            let thread_store = ThreadStore::global(cx);
            thread_store.update(cx, |store, cx| {
                store.save_thread(
                    acp::SessionId::new("native-session"),
                    make_db_thread("Native Thread", Utc::now()),
                    PathList::default(),
                    cx,
                )
            })
        });
        save_task.await.unwrap();
        cx.run_until_parked();

        // Run migration - should skip because metadata store is not empty
        cx.update(|cx| {
            migrate_thread_metadata(cx);
        });

        cx.run_until_parked();

        // Verify only the existing metadata is present (migration was skipped)
        let metadata_list = cx.update(|cx| {
            let store = ThreadMetadataStore::global(cx);
            store.read(cx).list(cx)
        });

        let list = metadata_list.await.unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].session_id.0.as_ref(), "existing-session");
    }

    /// Threads created with a parent session id (subagents) must not produce
    /// sidebar metadata, while top-level threads do.
    #[gpui::test]
    async fn test_subagent_threads_excluded_from_sidebar_metadata(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = settings::SettingsStore::test(cx);
            cx.set_global(settings_store);
            // Enable the flag explicitly so `save` actually writes.
            cx.update_flags(true, vec!["agent-v2".to_string()]);
            ThreadStore::init_global(cx);
            ThreadMetadataStore::init_global(cx);
        });

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, None::<&Path>, cx).await;
        let connection = Rc::new(StubAgentConnection::new());

        // Create a regular (non-subagent) AcpThread.
        let regular_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::default(), cx)
            })
            .await
            .unwrap();

        let regular_session_id = cx.read(|cx| regular_thread.read(cx).session_id().clone());

        // Set a title on the regular thread to trigger a save via handle_thread_update.
        cx.update(|cx| {
            regular_thread.update(cx, |thread, cx| {
                thread.set_title("Regular Thread".into(), cx).detach();
            });
        });
        cx.run_until_parked();

        // Create a subagent AcpThread
        let subagent_session_id = acp::SessionId::new("subagent-session");
        let subagent_thread = cx.update(|cx| {
            let action_log = cx.new(|_| ActionLog::new(project.clone()));
            cx.new(|cx| {
                acp_thread::AcpThread::new(
                    Some(regular_session_id.clone()),
                    "Subagent Thread",
                    None,
                    connection.clone(),
                    project.clone(),
                    action_log,
                    subagent_session_id.clone(),
                    watch::Receiver::constant(acp::PromptCapabilities::new()),
                    cx,
                )
            })
        });

        // Set a title on the subagent thread to trigger handle_thread_update.
        cx.update(|cx| {
            subagent_thread.update(cx, |thread, cx| {
                thread
                    .set_title("Subagent Thread Title".into(), cx)
                    .detach();
            });
        });
        cx.run_until_parked();

        // List all metadata from the store.
        let metadata_list = cx.update(|cx| {
            let store = ThreadMetadataStore::global(cx);
            store.read(cx).list(cx)
        });

        let list = metadata_list.await.unwrap();

        // The subagent thread should NOT appear in the sidebar metadata.
        // Only the regular thread should be listed.
        assert_eq!(
            list.len(),
            1,
            "Expected only the regular thread in sidebar metadata, \
             but found {} entries (subagent threads are leaking into the sidebar)",
            list.len(),
        );
        assert_eq!(list[0].session_id, regular_session_id);
        assert_eq!(list[0].title.as_ref(), "Regular Thread");
    }
}
631}