thread_import.rs

   1use acp_thread::AgentSessionListRequest;
   2use agent::ThreadStore;
   3use agent_client_protocol as acp;
   4use chrono::Utc;
   5use collections::HashSet;
   6use db::kvp::Dismissable;
   7use db::sqlez;
   8use fs::Fs;
   9use futures::FutureExt as _;
  10use gpui::{
  11    App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent,
  12    Render, SharedString, Task, WeakEntity, Window,
  13};
  14use itertools::Itertools as _;
  15use notifications::status_toast::StatusToast;
  16use project::{AgentId, AgentRegistryStore, AgentServerStore};
  17use release_channel::ReleaseChannel;
  18use remote::RemoteConnectionOptions;
  19use ui::{
  20    Checkbox, KeyBinding, ListItem, ListItemSpacing, Modal, ModalFooter, ModalHeader, Section,
  21    prelude::*,
  22};
  23use util::ResultExt;
  24use workspace::{ModalView, MultiWorkspace, Workspace};
  25
  26use crate::{
  27    Agent, AgentPanel,
  28    agent_connection_store::AgentConnectionStore,
  29    thread_metadata_store::{ThreadId, ThreadMetadata, ThreadMetadataStore, WorktreePaths},
  30};
  31
/// Marker type for the dismissable onboarding that offers importing threads
/// from external ACP agents.
pub struct AcpThreadImportOnboarding;
/// Marker type for the dismissable onboarding that offers importing threads
/// from other Zed release channels.
pub struct CrossChannelImportOnboarding;
  34
impl AcpThreadImportOnboarding {
    /// Returns whether this onboarding has already been dismissed.
    pub fn dismissed(cx: &App) -> bool {
        // Must stay fully qualified: `Self::dismissed` would resolve to this
        // inherent function (inherent items win over trait items) and recurse.
        <Self as Dismissable>::dismissed(cx)
    }

    /// Marks this onboarding as dismissed.
    pub fn dismiss(cx: &mut App) {
        <Self as Dismissable>::set_dismissed(true, cx);
    }
}

impl Dismissable for AcpThreadImportOnboarding {
    // Key identifying this flag; presumably stored via `db::kvp`, where
    // `Dismissable` lives — confirm before renaming, as existing users'
    // dismissal state is tied to it.
    const KEY: &'static str = "dismissed-acp-thread-import";
}
  48
impl CrossChannelImportOnboarding {
    /// Returns whether this onboarding has already been dismissed.
    pub fn dismissed(cx: &App) -> bool {
        // Must stay fully qualified: `Self::dismissed` would resolve to this
        // inherent function (inherent items win over trait items) and recurse.
        <Self as Dismissable>::dismissed(cx)
    }

    /// Marks this onboarding as dismissed.
    pub fn dismiss(cx: &mut App) {
        <Self as Dismissable>::set_dismissed(true, cx);
    }
}

impl Dismissable for CrossChannelImportOnboarding {
    // Key identifying this flag; presumably stored via `db::kvp`, where
    // `Dismissable` lives — confirm before renaming, as existing users'
    // dismissal state is tied to it.
    const KEY: &'static str = "dismissed-cross-channel-thread-import";
}
  62
  63/// Returns the list of non-Dev, non-current release channels that have
  64/// at least one thread in their database.  The result is suitable for
  65/// building a user-facing message ("from Zed Preview and Nightly").
  66pub fn channels_with_threads(cx: &App) -> Vec<ReleaseChannel> {
  67    let Some(current_channel) = ReleaseChannel::try_global(cx) else {
  68        return Vec::new();
  69    };
  70    let database_dir = paths::database_dir();
  71
  72    ReleaseChannel::ALL
  73        .iter()
  74        .copied()
  75        .filter(|channel| {
  76            *channel != current_channel
  77                && *channel != ReleaseChannel::Dev
  78                && channel_has_threads(database_dir, *channel)
  79        })
  80        .collect()
  81}
  82
/// One row of the import modal's agent list.
#[derive(Clone)]
struct AgentEntry {
    /// Identifier of the external agent.
    agent_id: AgentId,
    /// Name shown in the list; falls back to the raw agent id when no store
    /// provides one.
    display_name: SharedString,
    /// Optional external SVG icon; a generic icon is rendered when `None`.
    icon_path: Option<SharedString>,
}
  89
/// Modal that lets the user pick external agents and import their existing
/// sessions as (archived) threads.
pub struct ThreadImportModal {
    /// Focus handle tracked by the modal's root element.
    focus_handle: FocusHandle,
    /// Workspace used to show the result toast.
    workspace: WeakEntity<Workspace>,
    /// Source of the workspaces whose agent connection stores are queried.
    multi_workspace: WeakEntity<MultiWorkspace>,
    /// Agents listed in the modal, sorted by display name.
    agent_entries: Vec<AgentEntry>,
    /// Agents the user opted out of; every other listed agent is checked.
    unchecked_agents: HashSet<AgentId>,
    /// Keyboard-highlighted row, `None` until the user navigates.
    selected_index: Option<usize>,
    /// True while an import task is running.
    is_importing: bool,
    /// Error from the last import attempt, shown in the footer.
    last_error: Option<SharedString>,
}
 100
 101impl ThreadImportModal {
 102    pub fn new(
 103        agent_server_store: Entity<AgentServerStore>,
 104        agent_registry_store: Entity<AgentRegistryStore>,
 105        workspace: WeakEntity<Workspace>,
 106        multi_workspace: WeakEntity<MultiWorkspace>,
 107        _window: &mut Window,
 108        cx: &mut Context<Self>,
 109    ) -> Self {
 110        AcpThreadImportOnboarding::dismiss(cx);
 111
 112        let agent_entries = agent_server_store
 113            .read(cx)
 114            .external_agents()
 115            .map(|agent_id| {
 116                let display_name = agent_server_store
 117                    .read(cx)
 118                    .agent_display_name(agent_id)
 119                    .or_else(|| {
 120                        agent_registry_store
 121                            .read(cx)
 122                            .agent(agent_id)
 123                            .map(|agent| agent.name().clone())
 124                    })
 125                    .unwrap_or_else(|| agent_id.0.clone());
 126                let icon_path = agent_server_store
 127                    .read(cx)
 128                    .agent_icon(agent_id)
 129                    .or_else(|| {
 130                        agent_registry_store
 131                            .read(cx)
 132                            .agent(agent_id)
 133                            .and_then(|agent| agent.icon_path().cloned())
 134                    });
 135
 136                AgentEntry {
 137                    agent_id: agent_id.clone(),
 138                    display_name,
 139                    icon_path,
 140                }
 141            })
 142            .sorted_unstable_by_key(|entry| entry.display_name.to_lowercase())
 143            .collect::<Vec<_>>();
 144
 145        Self {
 146            focus_handle: cx.focus_handle(),
 147            workspace,
 148            multi_workspace,
 149            agent_entries,
 150            unchecked_agents: HashSet::default(),
 151            selected_index: None,
 152            is_importing: false,
 153            last_error: None,
 154        }
 155    }
 156
 157    fn agent_ids(&self) -> Vec<AgentId> {
 158        self.agent_entries
 159            .iter()
 160            .map(|entry| entry.agent_id.clone())
 161            .collect()
 162    }
 163
 164    fn toggle_agent_checked(&mut self, agent_id: AgentId, cx: &mut Context<Self>) {
 165        if self.unchecked_agents.contains(&agent_id) {
 166            self.unchecked_agents.remove(&agent_id);
 167        } else {
 168            self.unchecked_agents.insert(agent_id);
 169        }
 170        cx.notify();
 171    }
 172
 173    fn select_next(&mut self, _: &menu::SelectNext, _window: &mut Window, cx: &mut Context<Self>) {
 174        if self.agent_entries.is_empty() {
 175            return;
 176        }
 177        self.selected_index = Some(match self.selected_index {
 178            Some(ix) if ix + 1 >= self.agent_entries.len() => 0,
 179            Some(ix) => ix + 1,
 180            None => 0,
 181        });
 182        cx.notify();
 183    }
 184
 185    fn select_previous(
 186        &mut self,
 187        _: &menu::SelectPrevious,
 188        _window: &mut Window,
 189        cx: &mut Context<Self>,
 190    ) {
 191        if self.agent_entries.is_empty() {
 192            return;
 193        }
 194        self.selected_index = Some(match self.selected_index {
 195            Some(0) => self.agent_entries.len() - 1,
 196            Some(ix) => ix - 1,
 197            None => self.agent_entries.len() - 1,
 198        });
 199        cx.notify();
 200    }
 201
 202    fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
 203        if let Some(ix) = self.selected_index {
 204            if let Some(entry) = self.agent_entries.get(ix) {
 205                self.toggle_agent_checked(entry.agent_id.clone(), cx);
 206            }
 207        }
 208    }
 209
 210    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
 211        cx.emit(DismissEvent);
 212    }
 213
 214    fn import_threads(
 215        &mut self,
 216        _: &menu::SecondaryConfirm,
 217        _: &mut Window,
 218        cx: &mut Context<Self>,
 219    ) {
 220        if self.is_importing {
 221            return;
 222        }
 223
 224        let Some(multi_workspace) = self.multi_workspace.upgrade() else {
 225            self.is_importing = false;
 226            cx.notify();
 227            return;
 228        };
 229
 230        let stores = resolve_agent_connection_stores(&multi_workspace, cx);
 231        if stores.is_empty() {
 232            log::error!("Did not find any workspaces to import from");
 233            self.is_importing = false;
 234            cx.notify();
 235            return;
 236        }
 237
 238        self.is_importing = true;
 239        self.last_error = None;
 240        cx.notify();
 241
 242        let agent_ids = self
 243            .agent_ids()
 244            .into_iter()
 245            .filter(|agent_id| !self.unchecked_agents.contains(agent_id))
 246            .collect::<Vec<_>>();
 247
 248        let existing_sessions: HashSet<acp::SessionId> = ThreadMetadataStore::global(cx)
 249            .read(cx)
 250            .entries()
 251            .filter_map(|m| m.session_id.clone())
 252            .collect();
 253
 254        let task = find_threads_to_import(agent_ids, existing_sessions, stores, cx);
 255        cx.spawn(async move |this, cx| {
 256            let result = task.await;
 257            this.update(cx, |this, cx| match result {
 258                Ok(threads) => {
 259                    let imported_count = threads.len();
 260                    ThreadMetadataStore::global(cx)
 261                        .update(cx, |store, cx| store.save_all(threads, cx));
 262                    this.is_importing = false;
 263                    this.last_error = None;
 264                    this.show_imported_threads_toast(imported_count, cx);
 265                    cx.emit(DismissEvent);
 266                }
 267                Err(error) => {
 268                    this.is_importing = false;
 269                    this.last_error = Some(error.to_string().into());
 270                    cx.notify();
 271                }
 272            })
 273        })
 274        .detach_and_log_err(cx);
 275    }
 276
 277    fn show_imported_threads_toast(&self, imported_count: usize, cx: &mut App) {
 278        let status_toast = if imported_count == 0 {
 279            StatusToast::new("No threads found to import.", cx, |this, _cx| {
 280                this.icon(
 281                    Icon::new(IconName::Info)
 282                        .size(IconSize::Small)
 283                        .color(Color::Muted),
 284                )
 285                .dismiss_button(true)
 286            })
 287        } else {
 288            let message = if imported_count == 1 {
 289                "Imported 1 thread.".to_string()
 290            } else {
 291                format!("Imported {imported_count} threads.")
 292            };
 293            StatusToast::new(message, cx, |this, _cx| {
 294                this.icon(
 295                    Icon::new(IconName::Check)
 296                        .size(IconSize::Small)
 297                        .color(Color::Success),
 298                )
 299                .dismiss_button(true)
 300            })
 301        };
 302
 303        self.workspace
 304            .update(cx, |workspace, cx| {
 305                workspace.toggle_status_toast(status_toast, cx);
 306            })
 307            .log_err();
 308    }
 309}
 310
// Lets the modal close itself by emitting `DismissEvent` (from `cancel` and
// after a successful import).
impl EventEmitter<DismissEvent> for ThreadImportModal {}

impl Focusable for ThreadImportModal {
    /// Returns the handle tracked by the modal's root element.
    fn focus_handle(&self, _cx: &App) -> FocusHandle {
        self.focus_handle.clone()
    }
}

// Marker impl; presumably required to present this view through the
// workspace's modal layer — confirm against `workspace::ModalView`.
impl ModalView for ThreadImportModal {}
 320
 321impl Render for ThreadImportModal {
 322    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 323        let has_agents = !self.agent_entries.is_empty();
 324        let disabled_import_thread = self.is_importing
 325            || !has_agents
 326            || self.unchecked_agents.len() == self.agent_entries.len();
 327
 328        let agent_rows = self
 329            .agent_entries
 330            .iter()
 331            .enumerate()
 332            .map(|(ix, entry)| {
 333                let is_checked = !self.unchecked_agents.contains(&entry.agent_id);
 334                let is_focused = self.selected_index == Some(ix);
 335
 336                ListItem::new(("thread-import-agent", ix))
 337                    .rounded()
 338                    .spacing(ListItemSpacing::Sparse)
 339                    .focused(is_focused)
 340                    .disabled(self.is_importing)
 341                    .child(
 342                        h_flex()
 343                            .w_full()
 344                            .gap_2()
 345                            .when(!is_checked, |this| this.opacity(0.6))
 346                            .child(if let Some(icon_path) = entry.icon_path.clone() {
 347                                Icon::from_external_svg(icon_path)
 348                                    .color(Color::Muted)
 349                                    .size(IconSize::Small)
 350                            } else {
 351                                Icon::new(IconName::Sparkle)
 352                                    .color(Color::Muted)
 353                                    .size(IconSize::Small)
 354                            })
 355                            .child(Label::new(entry.display_name.clone())),
 356                    )
 357                    .end_slot(Checkbox::new(
 358                        ("thread-import-agent-checkbox", ix),
 359                        if is_checked {
 360                            ToggleState::Selected
 361                        } else {
 362                            ToggleState::Unselected
 363                        },
 364                    ))
 365                    .on_click({
 366                        let agent_id = entry.agent_id.clone();
 367                        cx.listener(move |this, _event, _window, cx| {
 368                            this.toggle_agent_checked(agent_id.clone(), cx);
 369                        })
 370                    })
 371            })
 372            .collect::<Vec<_>>();
 373
 374        v_flex()
 375            .id("thread-import-modal")
 376            .key_context("ThreadImportModal")
 377            .w(rems(34.))
 378            .elevation_3(cx)
 379            .overflow_hidden()
 380            .track_focus(&self.focus_handle)
 381            .on_action(cx.listener(Self::cancel))
 382            .on_action(cx.listener(Self::confirm))
 383            .on_action(cx.listener(Self::select_next))
 384            .on_action(cx.listener(Self::select_previous))
 385            .on_action(cx.listener(Self::import_threads))
 386            .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, cx| {
 387                this.focus_handle.focus(window, cx);
 388            }))
 389            .child(
 390                Modal::new("import-threads", None)
 391                    .header(
 392                        ModalHeader::new()
 393                            .headline("Import External Agent Threads")
 394                            .description(
 395                                "Import threads from agents like Claude Agent, Codex, and more, whether started in Zed or another client. \
 396                                Choose which agents to include, and their threads will appear in your thread history."
 397                            )
 398                            .show_dismiss_button(true),
 399
 400                    )
 401                    .section(
 402                        Section::new().child(
 403                            v_flex()
 404                                .id("thread-import-agent-list")
 405                                .max_h(rems_from_px(320.))
 406                                .pb_1()
 407                                .overflow_y_scroll()
 408                                .when(has_agents, |this| this.children(agent_rows))
 409                                .when(!has_agents, |this| {
 410                                    this.child(
 411                                        Label::new("No ACP agents available.")
 412                                            .color(Color::Muted)
 413                                            .size(LabelSize::Small),
 414                                    )
 415                                }),
 416                        ),
 417                    )
 418                    .footer(
 419                        ModalFooter::new()
 420                            .when_some(self.last_error.clone(), |this, error| {
 421                                this.start_slot(
 422                                    Label::new(error)
 423                                        .size(LabelSize::Small)
 424                                        .color(Color::Error)
 425                                        .truncate(),
 426                                )
 427                            })
 428                            .end_slot(
 429                                Button::new("import-threads", "Import Threads")
 430                                    .loading(self.is_importing)
 431                                    .disabled(disabled_import_thread)
 432                                    .key_binding(
 433                                        KeyBinding::for_action(&menu::SecondaryConfirm, cx)
 434                                            .map(|kb| kb.size(rems_from_px(12.))),
 435                                    )
 436                                    .on_click(cx.listener(|this, _, window, cx| {
 437                                        this.import_threads(&menu::SecondaryConfirm, window, cx);
 438                                    })),
 439                            ),
 440                    ),
 441            )
 442    }
 443}
 444
 445fn resolve_agent_connection_stores(
 446    multi_workspace: &Entity<MultiWorkspace>,
 447    cx: &App,
 448) -> Vec<Entity<AgentConnectionStore>> {
 449    let mut stores = Vec::new();
 450    let mut included_local_store = false;
 451
 452    for workspace in multi_workspace.read(cx).workspaces() {
 453        let workspace = workspace.read(cx);
 454        let project = workspace.project().read(cx);
 455
 456        // We only want to include scores from one local workspace, since we
 457        // know that they live on the same machine
 458        let include_store = if project.is_remote() {
 459            true
 460        } else if project.is_local() && !included_local_store {
 461            included_local_store = true;
 462            true
 463        } else {
 464            false
 465        };
 466
 467        if !include_store {
 468            continue;
 469        }
 470
 471        if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
 472            stores.push(panel.read(cx).connection_store().clone());
 473        }
 474    }
 475
 476    stores
 477}
 478
 479fn find_threads_to_import(
 480    agent_ids: Vec<AgentId>,
 481    existing_sessions: HashSet<acp::SessionId>,
 482    stores: Vec<Entity<AgentConnectionStore>>,
 483    cx: &mut App,
 484) -> Task<anyhow::Result<Vec<ThreadMetadata>>> {
 485    let mut wait_for_connection_tasks = Vec::new();
 486
 487    for store in stores {
 488        let remote_connection = store
 489            .read(cx)
 490            .project()
 491            .read(cx)
 492            .remote_connection_options(cx);
 493
 494        for agent_id in agent_ids.clone() {
 495            let agent = Agent::from(agent_id.clone());
 496            let server = agent.server(<dyn Fs>::global(cx), ThreadStore::global(cx));
 497            let entry = store.update(cx, |store, cx| store.request_connection(agent, server, cx));
 498
 499            wait_for_connection_tasks.push(entry.read(cx).wait_for_connection().map({
 500                let remote_connection = remote_connection.clone();
 501                move |state| (agent_id, remote_connection, state)
 502            }));
 503        }
 504    }
 505
 506    cx.spawn(async move |cx| {
 507        let results = futures::future::join_all(wait_for_connection_tasks).await;
 508
 509        let mut page_tasks = Vec::new();
 510        for (agent_id, remote_connection, result) in results {
 511            let Some(state) = result.log_err() else {
 512                continue;
 513            };
 514            let Some(list) = cx.update(|cx| state.connection.session_list(cx)) else {
 515                continue;
 516            };
 517            page_tasks.push(cx.spawn({
 518                let list = list.clone();
 519                async move |cx| collect_all_sessions(agent_id, remote_connection, list, cx).await
 520            }));
 521        }
 522
 523        let sessions_by_agent = futures::future::join_all(page_tasks)
 524            .await
 525            .into_iter()
 526            .filter_map(|result| result.log_err())
 527            .collect();
 528
 529        Ok(collect_importable_threads(
 530            sessions_by_agent,
 531            existing_sessions,
 532        ))
 533    })
 534}
 535
 536async fn collect_all_sessions(
 537    agent_id: AgentId,
 538    remote_connection: Option<RemoteConnectionOptions>,
 539    list: std::rc::Rc<dyn acp_thread::AgentSessionList>,
 540    cx: &mut gpui::AsyncApp,
 541) -> anyhow::Result<SessionByAgent> {
 542    let mut sessions = Vec::new();
 543    let mut cursor: Option<String> = None;
 544    loop {
 545        let request = AgentSessionListRequest {
 546            cursor: cursor.clone(),
 547            ..Default::default()
 548        };
 549        let task = cx.update(|cx| list.list_sessions(request, cx));
 550        let response = task.await?;
 551        sessions.extend(response.sessions);
 552        match response.next_cursor {
 553            Some(next) if Some(&next) != cursor.as_ref() => cursor = Some(next),
 554            _ => break,
 555        }
 556    }
 557    Ok(SessionByAgent {
 558        agent_id,
 559        remote_connection,
 560        sessions,
 561    })
 562}
 563
/// All sessions one agent reported through one agent connection store.
struct SessionByAgent {
    /// Agent that reported the sessions.
    agent_id: AgentId,
    /// Remote connection options of the store's project (presumably `None`
    /// for local projects — confirm against `remote_connection_options`).
    remote_connection: Option<RemoteConnectionOptions>,
    /// Sessions gathered across every page of the listing.
    sessions: Vec<acp_thread::AgentSessionInfo>,
}
 569
 570fn collect_importable_threads(
 571    sessions_by_agent: Vec<SessionByAgent>,
 572    mut existing_sessions: HashSet<acp::SessionId>,
 573) -> Vec<ThreadMetadata> {
 574    let mut to_insert = Vec::new();
 575    for SessionByAgent {
 576        agent_id,
 577        remote_connection,
 578        sessions,
 579    } in sessions_by_agent
 580    {
 581        for session in sessions {
 582            if !existing_sessions.insert(session.session_id.clone()) {
 583                continue;
 584            }
 585            let Some(folder_paths) = session.work_dirs else {
 586                continue;
 587            };
 588            to_insert.push(ThreadMetadata {
 589                thread_id: ThreadId::new(),
 590                session_id: Some(session.session_id),
 591                agent_id: agent_id.clone(),
 592                title: session.title,
 593                updated_at: session.updated_at.unwrap_or_else(|| Utc::now()),
 594                created_at: session.created_at,
 595                interacted_at: None,
 596                worktree_paths: WorktreePaths::from_folder_paths(&folder_paths),
 597                remote_connection: remote_connection.clone(),
 598                archived: true,
 599            });
 600        }
 601    }
 602    to_insert
 603}
 604
 605pub fn import_threads_from_other_channels(_workspace: &mut Workspace, cx: &mut Context<Workspace>) {
 606    let database_dir = paths::database_dir().clone();
 607    import_threads_from_other_channels_in(database_dir, cx);
 608}
 609
 610fn import_threads_from_other_channels_in(
 611    database_dir: std::path::PathBuf,
 612    cx: &mut Context<Workspace>,
 613) {
 614    let current_channel = ReleaseChannel::global(cx);
 615
 616    let existing_thread_ids: HashSet<ThreadId> = ThreadMetadataStore::global(cx)
 617        .read(cx)
 618        .entries()
 619        .map(|metadata| metadata.thread_id)
 620        .collect();
 621
 622    let workspace_handle = cx.weak_entity();
 623    cx.spawn(async move |_this, cx| {
 624        let mut imported_threads = Vec::new();
 625
 626        for channel in &ReleaseChannel::ALL {
 627            if *channel == current_channel || *channel == ReleaseChannel::Dev {
 628                continue;
 629            }
 630
 631            match read_threads_from_channel(&database_dir, *channel) {
 632                Ok(threads) => {
 633                    let new_threads = threads
 634                        .into_iter()
 635                        .filter(|thread| !existing_thread_ids.contains(&thread.thread_id));
 636                    imported_threads.extend(new_threads);
 637                }
 638                Err(error) => {
 639                    log::warn!(
 640                        "Failed to read threads from {} channel database: {}",
 641                        channel.dev_name(),
 642                        error
 643                    );
 644                }
 645            }
 646        }
 647
 648        let imported_count = imported_threads.len();
 649
 650        cx.update(|cx| {
 651            ThreadMetadataStore::global(cx)
 652                .update(cx, |store, cx| store.save_all(imported_threads, cx));
 653
 654            show_cross_channel_import_toast(&workspace_handle, imported_count, cx);
 655        })
 656    })
 657    .detach();
 658}
 659
 660fn channel_has_threads(database_dir: &std::path::Path, channel: ReleaseChannel) -> bool {
 661    let db_path = db::db_path(database_dir, channel);
 662    if !db_path.exists() {
 663        return false;
 664    }
 665    let connection = sqlez::connection::Connection::open_file(&db_path.to_string_lossy());
 666    connection
 667        .select_row::<bool>("SELECT 1 FROM sidebar_threads LIMIT 1")
 668        .ok()
 669        .and_then(|mut query| query().ok().flatten())
 670        .unwrap_or(false)
 671}
 672
 673fn read_threads_from_channel(
 674    database_dir: &std::path::Path,
 675    channel: ReleaseChannel,
 676) -> anyhow::Result<Vec<ThreadMetadata>> {
 677    let db_path = db::db_path(database_dir, channel);
 678    if !db_path.exists() {
 679        return Ok(Vec::new());
 680    }
 681    let connection = sqlez::connection::Connection::open_file(&db_path.to_string_lossy());
 682    crate::thread_metadata_store::list_thread_metadata_from_connection(&connection)
 683}
 684
 685fn show_cross_channel_import_toast(
 686    workspace: &WeakEntity<Workspace>,
 687    imported_count: usize,
 688    cx: &mut App,
 689) {
 690    let status_toast = if imported_count == 0 {
 691        StatusToast::new("No new threads found to import.", cx, |this, _cx| {
 692            this.icon(Icon::new(IconName::Info).color(Color::Muted))
 693                .dismiss_button(true)
 694        })
 695    } else {
 696        let message = if imported_count == 1 {
 697            "Imported 1 thread from other channels.".to_string()
 698        } else {
 699            format!("Imported {imported_count} threads from other channels.")
 700        };
 701        StatusToast::new(message, cx, |this, _cx| {
 702            this.icon(Icon::new(IconName::Check).color(Color::Success))
 703                .dismiss_button(true)
 704        })
 705    };
 706
 707    workspace
 708        .update(cx, |workspace, cx| {
 709            workspace.toggle_status_toast(status_toast, cx);
 710        })
 711        .log_err();
 712}
 713
 714#[cfg(test)]
 715mod tests {
 716    use super::*;
 717    use acp_thread::AgentSessionInfo;
 718    use chrono::Utc;
 719    use gpui::TestAppContext;
 720    use std::path::Path;
 721    use workspace::PathList;
 722
 723    fn make_session(
 724        session_id: &str,
 725        title: Option<&str>,
 726        work_dirs: Option<PathList>,
 727        updated_at: Option<chrono::DateTime<Utc>>,
 728        created_at: Option<chrono::DateTime<Utc>>,
 729    ) -> AgentSessionInfo {
 730        AgentSessionInfo {
 731            session_id: acp::SessionId::new(session_id),
 732            title: title.map(|t| SharedString::from(t.to_string())),
 733            work_dirs,
 734            updated_at,
 735            created_at,
 736            meta: None,
 737        }
 738    }
 739
 740    #[test]
 741    fn test_collect_skips_sessions_already_in_existing_set() {
 742        let existing = HashSet::from_iter(vec![acp::SessionId::new("existing-1")]);
 743        let paths = PathList::new(&[Path::new("/project")]);
 744
 745        let sessions_by_agent = vec![SessionByAgent {
 746            agent_id: AgentId::new("agent-a"),
 747            remote_connection: None,
 748            sessions: vec![
 749                make_session(
 750                    "existing-1",
 751                    Some("Already There"),
 752                    Some(paths.clone()),
 753                    None,
 754                    None,
 755                ),
 756                make_session("new-1", Some("Brand New"), Some(paths), None, None),
 757            ],
 758        }];
 759
 760        let result = collect_importable_threads(sessions_by_agent, existing);
 761
 762        assert_eq!(result.len(), 1);
 763        assert_eq!(result[0].session_id.as_ref().unwrap().0.as_ref(), "new-1");
 764        assert_eq!(result[0].display_title(), "Brand New");
 765    }
 766
 767    #[test]
 768    fn test_collect_skips_sessions_without_work_dirs() {
 769        let existing = HashSet::default();
 770        let paths = PathList::new(&[Path::new("/project")]);
 771
 772        let sessions_by_agent = vec![SessionByAgent {
 773            agent_id: AgentId::new("agent-a"),
 774            remote_connection: None,
 775            sessions: vec![
 776                make_session("has-dirs", Some("With Dirs"), Some(paths), None, None),
 777                make_session("no-dirs", Some("No Dirs"), None, None, None),
 778            ],
 779        }];
 780
 781        let result = collect_importable_threads(sessions_by_agent, existing);
 782
 783        assert_eq!(result.len(), 1);
 784        assert_eq!(
 785            result[0].session_id.as_ref().unwrap().0.as_ref(),
 786            "has-dirs"
 787        );
 788    }
 789
 790    #[test]
 791    fn test_collect_marks_all_imported_threads_as_archived() {
 792        let existing = HashSet::default();
 793        let paths = PathList::new(&[Path::new("/project")]);
 794
 795        let sessions_by_agent = vec![SessionByAgent {
 796            agent_id: AgentId::new("agent-a"),
 797            remote_connection: None,
 798            sessions: vec![
 799                make_session("s1", Some("Thread 1"), Some(paths.clone()), None, None),
 800                make_session("s2", Some("Thread 2"), Some(paths), None, None),
 801            ],
 802        }];
 803
 804        let result = collect_importable_threads(sessions_by_agent, existing);
 805
 806        assert_eq!(result.len(), 2);
 807        assert!(result.iter().all(|t| t.archived));
 808    }
 809
 810    #[test]
 811    fn test_collect_assigns_correct_agent_id_per_session() {
 812        let existing = HashSet::default();
 813        let paths = PathList::new(&[Path::new("/project")]);
 814
 815        let sessions_by_agent = vec![
 816            SessionByAgent {
 817                agent_id: AgentId::new("agent-a"),
 818                remote_connection: None,
 819                sessions: vec![make_session(
 820                    "s1",
 821                    Some("From A"),
 822                    Some(paths.clone()),
 823                    None,
 824                    None,
 825                )],
 826            },
 827            SessionByAgent {
 828                agent_id: AgentId::new("agent-b"),
 829                remote_connection: None,
 830                sessions: vec![make_session("s2", Some("From B"), Some(paths), None, None)],
 831            },
 832        ];
 833
 834        let result = collect_importable_threads(sessions_by_agent, existing);
 835
 836        assert_eq!(result.len(), 2);
 837        let s1 = result
 838            .iter()
 839            .find(|t| t.session_id.as_ref().map(|s| s.0.as_ref()) == Some("s1"))
 840            .unwrap();
 841        let s2 = result
 842            .iter()
 843            .find(|t| t.session_id.as_ref().map(|s| s.0.as_ref()) == Some("s2"))
 844            .unwrap();
 845        assert_eq!(s1.agent_id.as_ref(), "agent-a");
 846        assert_eq!(s2.agent_id.as_ref(), "agent-b");
 847    }
 848
 849    #[test]
 850    fn test_collect_deduplicates_across_agents() {
 851        let existing = HashSet::default();
 852        let paths = PathList::new(&[Path::new("/project")]);
 853
 854        let sessions_by_agent = vec![
 855            SessionByAgent {
 856                agent_id: AgentId::new("agent-a"),
 857                remote_connection: None,
 858                sessions: vec![make_session(
 859                    "shared-session",
 860                    Some("From A"),
 861                    Some(paths.clone()),
 862                    None,
 863                    None,
 864                )],
 865            },
 866            SessionByAgent {
 867                agent_id: AgentId::new("agent-b"),
 868                remote_connection: None,
 869                sessions: vec![make_session(
 870                    "shared-session",
 871                    Some("From B"),
 872                    Some(paths),
 873                    None,
 874                    None,
 875                )],
 876            },
 877        ];
 878
 879        let result = collect_importable_threads(sessions_by_agent, existing);
 880
 881        assert_eq!(result.len(), 1);
 882        assert_eq!(
 883            result[0].session_id.as_ref().unwrap().0.as_ref(),
 884            "shared-session"
 885        );
 886        assert_eq!(
 887            result[0].agent_id.as_ref(),
 888            "agent-a",
 889            "first agent encountered should win"
 890        );
 891    }
 892
 893    #[test]
 894    fn test_collect_all_existing_returns_empty() {
 895        let paths = PathList::new(&[Path::new("/project")]);
 896        let existing =
 897            HashSet::from_iter(vec![acp::SessionId::new("s1"), acp::SessionId::new("s2")]);
 898
 899        let sessions_by_agent = vec![SessionByAgent {
 900            agent_id: AgentId::new("agent-a"),
 901            remote_connection: None,
 902            sessions: vec![
 903                make_session("s1", Some("T1"), Some(paths.clone()), None, None),
 904                make_session("s2", Some("T2"), Some(paths), None, None),
 905            ],
 906        }];
 907
 908        let result = collect_importable_threads(sessions_by_agent, existing);
 909        assert!(result.is_empty());
 910    }
 911
 912    fn create_channel_db(
 913        db_dir: &std::path::Path,
 914        channel: ReleaseChannel,
 915    ) -> db::sqlez::connection::Connection {
 916        let db_path = db::db_path(db_dir, channel);
 917        std::fs::create_dir_all(db_path.parent().unwrap()).unwrap();
 918        let connection = db::sqlez::connection::Connection::open_file(&db_path.to_string_lossy());
 919        crate::thread_metadata_store::run_thread_metadata_migrations(&connection);
 920        connection
 921    }
 922
 923    fn insert_thread(
 924        connection: &db::sqlez::connection::Connection,
 925        title: &str,
 926        updated_at: &str,
 927        archived: bool,
 928    ) {
 929        let thread_id = uuid::Uuid::new_v4();
 930        let session_id = uuid::Uuid::new_v4().to_string();
 931        connection
 932            .exec_bound::<(uuid::Uuid, &str, &str, &str, bool)>(
 933                "INSERT INTO sidebar_threads \
 934                 (thread_id, session_id, title, updated_at, archived) \
 935                 VALUES (?1, ?2, ?3, ?4, ?5)",
 936            )
 937            .unwrap()((thread_id, session_id.as_str(), title, updated_at, archived))
 938        .unwrap();
 939    }
 940
 941    #[test]
 942    fn test_returns_empty_when_channel_db_missing() {
 943        let dir = tempfile::tempdir().unwrap();
 944        let threads = read_threads_from_channel(dir.path(), ReleaseChannel::Nightly).unwrap();
 945        assert!(threads.is_empty());
 946    }
 947
 948    #[test]
 949    fn test_preserves_archived_state() {
 950        let dir = tempfile::tempdir().unwrap();
 951        let connection = create_channel_db(dir.path(), ReleaseChannel::Nightly);
 952
 953        insert_thread(&connection, "Active Thread", "2025-01-15T10:00:00Z", false);
 954        insert_thread(&connection, "Archived Thread", "2025-01-15T09:00:00Z", true);
 955        drop(connection);
 956
 957        let threads = read_threads_from_channel(dir.path(), ReleaseChannel::Nightly).unwrap();
 958        assert_eq!(threads.len(), 2);
 959
 960        let active = threads
 961            .iter()
 962            .find(|t| t.display_title().as_ref() == "Active Thread")
 963            .unwrap();
 964        assert!(!active.archived);
 965
 966        let archived = threads
 967            .iter()
 968            .find(|t| t.display_title().as_ref() == "Archived Thread")
 969            .unwrap();
 970        assert!(archived.archived);
 971    }
 972
    /// Installs the globals the import tests rely on:
    /// - a test `SettingsStore`
    /// - the base theme settings
    /// - the release channel, initialized with version "0.0.0"
    /// - a `FakeFs` as the global `Fs`
    /// - the global `ThreadMetadataStore`
    ///
    /// It then runs the executor until parked.
    fn init_test(cx: &mut TestAppContext) {
        let fs = fs::FakeFs::new(cx.executor());
        cx.update(|cx| {
            // NOTE(review): the settings store is installed before the other
            // initializers, presumably because they read settings — keep this order.
            let settings_store = settings::SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme_settings::init(theme::LoadThemes::JustBase, cx);
            release_channel::init("0.0.0".parse().unwrap(), cx);
            <dyn fs::Fs>::set_global(fs, cx);
            ThreadMetadataStore::init_global(cx);
        });
        // Let any work spawned by the initializers settle before the test proceeds.
        cx.run_until_parked();
    }
 985
 986    /// Returns two release channels that are not the current one and not Dev.
 987    /// This ensures tests work regardless of which release channel branch
 988    /// they run on.
 989    fn foreign_channels(cx: &TestAppContext) -> (ReleaseChannel, ReleaseChannel) {
 990        let current = cx.update(|cx| ReleaseChannel::global(cx));
 991        let mut channels = ReleaseChannel::ALL
 992            .iter()
 993            .copied()
 994            .filter(|ch| *ch != current && *ch != ReleaseChannel::Dev);
 995        (channels.next().unwrap(), channels.next().unwrap())
 996    }
 997
 998    #[gpui::test]
 999    async fn test_import_threads_from_other_channels(cx: &mut TestAppContext) {
1000        init_test(cx);
1001
1002        let dir = tempfile::tempdir().unwrap();
1003        let database_dir = dir.path().to_path_buf();
1004
1005        let (channel_a, channel_b) = foreign_channels(cx);
1006
1007        // Set up databases for two foreign channels.
1008        let db_a = create_channel_db(dir.path(), channel_a);
1009        insert_thread(&db_a, "Thread A1", "2025-01-15T10:00:00Z", false);
1010        insert_thread(&db_a, "Thread A2", "2025-01-15T11:00:00Z", true);
1011        drop(db_a);
1012
1013        let db_b = create_channel_db(dir.path(), channel_b);
1014        insert_thread(&db_b, "Thread B1", "2025-01-15T12:00:00Z", false);
1015        drop(db_b);
1016
1017        // Create a workspace and run the import.
1018        let fs = fs::FakeFs::new(cx.executor());
1019        let project = project::Project::test(fs, [], cx).await;
1020        let multi_workspace =
1021            cx.add_window(|window, cx| MultiWorkspace::test_new(project, window, cx));
1022        let workspace_entity = multi_workspace
1023            .read_with(cx, |mw, _cx| mw.workspace().clone())
1024            .unwrap();
1025        let mut vcx = gpui::VisualTestContext::from_window(multi_workspace.into(), cx);
1026
1027        workspace_entity.update_in(&mut vcx, |_workspace, _window, cx| {
1028            import_threads_from_other_channels_in(database_dir, cx);
1029        });
1030        cx.run_until_parked();
1031
1032        // Verify all three threads were imported into the store.
1033        cx.update(|cx| {
1034            let store = ThreadMetadataStore::global(cx);
1035            let store = store.read(cx);
1036            let titles: collections::HashSet<String> = store
1037                .entries()
1038                .map(|m| m.display_title().to_string())
1039                .collect();
1040
1041            assert_eq!(titles.len(), 3);
1042            assert!(titles.contains("Thread A1"));
1043            assert!(titles.contains("Thread A2"));
1044            assert!(titles.contains("Thread B1"));
1045
1046            // Verify archived state is preserved.
1047            let thread_a2 = store
1048                .entries()
1049                .find(|m| m.display_title().as_ref() == "Thread A2")
1050                .unwrap();
1051            assert!(thread_a2.archived);
1052
1053            let thread_b1 = store
1054                .entries()
1055                .find(|m| m.display_title().as_ref() == "Thread B1")
1056                .unwrap();
1057            assert!(!thread_b1.archived);
1058        });
1059    }
1060
    // A thread that is already in the global `ThreadMetadataStore` should not
    // be added again when importing from another channel's database.
    #[gpui::test]
    async fn test_import_skips_already_existing_threads(cx: &mut TestAppContext) {
        init_test(cx);

        let dir = tempfile::tempdir().unwrap();
        let database_dir = dir.path().to_path_buf();

        let (channel_a, _) = foreign_channels(cx);

        // Set up a database for a foreign channel.
        let db_a = create_channel_db(dir.path(), channel_a);
        insert_thread(&db_a, "Thread A", "2025-01-15T10:00:00Z", false);
        insert_thread(&db_a, "Thread B", "2025-01-15T11:00:00Z", false);
        // Close this connection before the database is read back below.
        drop(db_a);

        // Read the threads so we can pre-populate one into the store.
        let foreign_threads = read_threads_from_channel(dir.path(), channel_a).unwrap();
        let thread_a = foreign_threads
            .iter()
            .find(|t| t.display_title().as_ref() == "Thread A")
            .unwrap()
            .clone();

        // Pre-populate Thread A into the store.
        cx.update(|cx| {
            ThreadMetadataStore::global(cx).update(cx, |store, cx| store.save(thread_a, cx));
        });
        cx.run_until_parked();

        // Run the import.
        let fs = fs::FakeFs::new(cx.executor());
        let project = project::Project::test(fs, [], cx).await;
        let multi_workspace =
            cx.add_window(|window, cx| MultiWorkspace::test_new(project, window, cx));
        let workspace_entity = multi_workspace
            .read_with(cx, |mw, _cx| mw.workspace().clone())
            .unwrap();
        let mut vcx = gpui::VisualTestContext::from_window(multi_workspace.into(), cx);

        workspace_entity.update_in(&mut vcx, |_workspace, _window, cx| {
            import_threads_from_other_channels_in(database_dir, cx);
        });
        cx.run_until_parked();

        // Verify only Thread B was added (Thread A already existed).
        // NOTE(review): `thread_a` is a clone of the metadata that
        // `read_threads_from_channel` returned for the foreign DB, so it
        // presumably has the same thread id the import would see. If the store
        // keys entries by thread id, re-importing Thread A would overwrite it
        // rather than add a duplicate. A count of 2 may therefore not prove the
        // skip on its own — TODO confirm against the import's dedup logic.
        cx.update(|cx| {
            let store = ThreadMetadataStore::global(cx);
            let store = store.read(cx);
            assert_eq!(store.entries().count(), 2);

            let titles: collections::HashSet<String> = store
                .entries()
                .map(|m| m.display_title().to_string())
                .collect();
            assert!(titles.contains("Thread A"));
            assert!(titles.contains("Thread B"));
        });
    }
1119}