thread_import.rs

  1use acp_thread::AgentSessionListRequest;
  2use agent::ThreadStore;
  3use agent_client_protocol as acp;
  4use chrono::Utc;
  5use collections::HashSet;
  6use db::kvp::Dismissable;
  7use fs::Fs;
  8use futures::FutureExt as _;
  9use gpui::{
 10    App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent,
 11    Render, SharedString, Task, WeakEntity, Window,
 12};
 13use notifications::status_toast::{StatusToast, ToastIcon};
 14use project::{AgentId, AgentRegistryStore, AgentServerStore};
 15use remote::RemoteConnectionOptions;
 16use ui::{
 17    Checkbox, KeyBinding, ListItem, ListItemSpacing, Modal, ModalFooter, ModalHeader, Section,
 18    prelude::*,
 19};
 20use util::ResultExt;
 21use workspace::{ModalView, MultiWorkspace, Workspace};
 22
 23use crate::{
 24    Agent, AgentPanel,
 25    agent_connection_store::AgentConnectionStore,
 26    thread_metadata_store::{ThreadMetadata, ThreadMetadataStore, ThreadWorktreePaths},
 27};
 28
 29pub struct AcpThreadImportOnboarding;
 30
 31impl AcpThreadImportOnboarding {
 32    pub fn dismissed(cx: &App) -> bool {
 33        <Self as Dismissable>::dismissed(cx)
 34    }
 35
 36    pub fn dismiss(cx: &mut App) {
 37        <Self as Dismissable>::set_dismissed(true, cx);
 38    }
 39}
 40
 41impl Dismissable for AcpThreadImportOnboarding {
 42    const KEY: &'static str = "dismissed-acp-thread-import";
 43}
 44
 45#[derive(Clone)]
 46struct AgentEntry {
 47    agent_id: AgentId,
 48    display_name: SharedString,
 49    icon_path: Option<SharedString>,
 50}
 51
 52pub struct ThreadImportModal {
 53    focus_handle: FocusHandle,
 54    workspace: WeakEntity<Workspace>,
 55    multi_workspace: WeakEntity<MultiWorkspace>,
 56    agent_entries: Vec<AgentEntry>,
 57    unchecked_agents: HashSet<AgentId>,
 58    selected_index: Option<usize>,
 59    is_importing: bool,
 60    last_error: Option<SharedString>,
 61}
 62
 63impl ThreadImportModal {
 64    pub fn new(
 65        agent_server_store: Entity<AgentServerStore>,
 66        agent_registry_store: Entity<AgentRegistryStore>,
 67        workspace: WeakEntity<Workspace>,
 68        multi_workspace: WeakEntity<MultiWorkspace>,
 69        _window: &mut Window,
 70        cx: &mut Context<Self>,
 71    ) -> Self {
 72        AcpThreadImportOnboarding::dismiss(cx);
 73
 74        let agent_entries = agent_server_store
 75            .read(cx)
 76            .external_agents()
 77            .map(|agent_id| {
 78                let display_name = agent_server_store
 79                    .read(cx)
 80                    .agent_display_name(agent_id)
 81                    .or_else(|| {
 82                        agent_registry_store
 83                            .read(cx)
 84                            .agent(agent_id)
 85                            .map(|agent| agent.name().clone())
 86                    })
 87                    .unwrap_or_else(|| agent_id.0.clone());
 88                let icon_path = agent_server_store
 89                    .read(cx)
 90                    .agent_icon(agent_id)
 91                    .or_else(|| {
 92                        agent_registry_store
 93                            .read(cx)
 94                            .agent(agent_id)
 95                            .and_then(|agent| agent.icon_path().cloned())
 96                    });
 97
 98                AgentEntry {
 99                    agent_id: agent_id.clone(),
100                    display_name,
101                    icon_path,
102                }
103            })
104            .collect::<Vec<_>>();
105
106        Self {
107            focus_handle: cx.focus_handle(),
108            workspace,
109            multi_workspace,
110            agent_entries,
111            unchecked_agents: HashSet::default(),
112            selected_index: None,
113            is_importing: false,
114            last_error: None,
115        }
116    }
117
118    fn agent_ids(&self) -> Vec<AgentId> {
119        self.agent_entries
120            .iter()
121            .map(|entry| entry.agent_id.clone())
122            .collect()
123    }
124
125    fn toggle_agent_checked(&mut self, agent_id: AgentId, cx: &mut Context<Self>) {
126        if self.unchecked_agents.contains(&agent_id) {
127            self.unchecked_agents.remove(&agent_id);
128        } else {
129            self.unchecked_agents.insert(agent_id);
130        }
131        cx.notify();
132    }
133
134    fn select_next(&mut self, _: &menu::SelectNext, _window: &mut Window, cx: &mut Context<Self>) {
135        if self.agent_entries.is_empty() {
136            return;
137        }
138        self.selected_index = Some(match self.selected_index {
139            Some(ix) if ix + 1 >= self.agent_entries.len() => 0,
140            Some(ix) => ix + 1,
141            None => 0,
142        });
143        cx.notify();
144    }
145
146    fn select_previous(
147        &mut self,
148        _: &menu::SelectPrevious,
149        _window: &mut Window,
150        cx: &mut Context<Self>,
151    ) {
152        if self.agent_entries.is_empty() {
153            return;
154        }
155        self.selected_index = Some(match self.selected_index {
156            Some(0) => self.agent_entries.len() - 1,
157            Some(ix) => ix - 1,
158            None => self.agent_entries.len() - 1,
159        });
160        cx.notify();
161    }
162
163    fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
164        if let Some(ix) = self.selected_index {
165            if let Some(entry) = self.agent_entries.get(ix) {
166                self.toggle_agent_checked(entry.agent_id.clone(), cx);
167            }
168        }
169    }
170
171    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
172        cx.emit(DismissEvent);
173    }
174
175    fn import_threads(
176        &mut self,
177        _: &menu::SecondaryConfirm,
178        _: &mut Window,
179        cx: &mut Context<Self>,
180    ) {
181        if self.is_importing {
182            return;
183        }
184
185        let Some(multi_workspace) = self.multi_workspace.upgrade() else {
186            self.is_importing = false;
187            cx.notify();
188            return;
189        };
190
191        let stores = resolve_agent_connection_stores(&multi_workspace, cx);
192        if stores.is_empty() {
193            log::error!("Did not find any workspaces to import from");
194            self.is_importing = false;
195            cx.notify();
196            return;
197        }
198
199        self.is_importing = true;
200        self.last_error = None;
201        cx.notify();
202
203        let agent_ids = self
204            .agent_ids()
205            .into_iter()
206            .filter(|agent_id| !self.unchecked_agents.contains(agent_id))
207            .collect::<Vec<_>>();
208
209        let existing_sessions = ThreadMetadataStore::global(cx)
210            .read(cx)
211            .entry_ids()
212            .collect::<HashSet<_>>();
213
214        let task = find_threads_to_import(agent_ids, existing_sessions, stores, cx);
215        cx.spawn(async move |this, cx| {
216            let result = task.await;
217            this.update(cx, |this, cx| match result {
218                Ok(threads) => {
219                    let imported_count = threads.len();
220                    ThreadMetadataStore::global(cx)
221                        .update(cx, |store, cx| store.save_all(threads, cx));
222                    this.is_importing = false;
223                    this.last_error = None;
224                    this.show_imported_threads_toast(imported_count, cx);
225                    cx.emit(DismissEvent);
226                }
227                Err(error) => {
228                    this.is_importing = false;
229                    this.last_error = Some(error.to_string().into());
230                    cx.notify();
231                }
232            })
233        })
234        .detach_and_log_err(cx);
235    }
236
237    fn show_imported_threads_toast(&self, imported_count: usize, cx: &mut App) {
238        let status_toast = if imported_count == 0 {
239            StatusToast::new("No threads found to import.", cx, |this, _cx| {
240                this.icon(ToastIcon::new(IconName::Info).color(Color::Muted))
241                    .dismiss_button(true)
242            })
243        } else {
244            let message = if imported_count == 1 {
245                "Imported 1 thread.".to_string()
246            } else {
247                format!("Imported {imported_count} threads.")
248            };
249            StatusToast::new(message, cx, |this, _cx| {
250                this.icon(ToastIcon::new(IconName::Check).color(Color::Success))
251                    .dismiss_button(true)
252            })
253        };
254
255        self.workspace
256            .update(cx, |workspace, cx| {
257                workspace.toggle_status_toast(status_toast, cx);
258            })
259            .log_err();
260    }
261}
262
263impl EventEmitter<DismissEvent> for ThreadImportModal {}
264
265impl Focusable for ThreadImportModal {
266    fn focus_handle(&self, _cx: &App) -> FocusHandle {
267        self.focus_handle.clone()
268    }
269}
270
271impl ModalView for ThreadImportModal {}
272
273impl Render for ThreadImportModal {
274    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
275        let has_agents = !self.agent_entries.is_empty();
276        let disabled_import_thread = self.is_importing
277            || !has_agents
278            || self.unchecked_agents.len() == self.agent_entries.len();
279
280        let agent_rows = self
281            .agent_entries
282            .iter()
283            .enumerate()
284            .map(|(ix, entry)| {
285                let is_checked = !self.unchecked_agents.contains(&entry.agent_id);
286                let is_focused = self.selected_index == Some(ix);
287
288                ListItem::new(("thread-import-agent", ix))
289                    .rounded()
290                    .spacing(ListItemSpacing::Sparse)
291                    .focused(is_focused)
292                    .disabled(self.is_importing)
293                    .child(
294                        h_flex()
295                            .w_full()
296                            .gap_2()
297                            .when(!is_checked, |this| this.opacity(0.6))
298                            .child(if let Some(icon_path) = entry.icon_path.clone() {
299                                Icon::from_external_svg(icon_path)
300                                    .color(Color::Muted)
301                                    .size(IconSize::Small)
302                            } else {
303                                Icon::new(IconName::Sparkle)
304                                    .color(Color::Muted)
305                                    .size(IconSize::Small)
306                            })
307                            .child(Label::new(entry.display_name.clone())),
308                    )
309                    .end_slot(Checkbox::new(
310                        ("thread-import-agent-checkbox", ix),
311                        if is_checked {
312                            ToggleState::Selected
313                        } else {
314                            ToggleState::Unselected
315                        },
316                    ))
317                    .on_click({
318                        let agent_id = entry.agent_id.clone();
319                        cx.listener(move |this, _event, _window, cx| {
320                            this.toggle_agent_checked(agent_id.clone(), cx);
321                        })
322                    })
323            })
324            .collect::<Vec<_>>();
325
326        v_flex()
327            .id("thread-import-modal")
328            .key_context("ThreadImportModal")
329            .w(rems(34.))
330            .elevation_3(cx)
331            .overflow_hidden()
332            .track_focus(&self.focus_handle)
333            .on_action(cx.listener(Self::cancel))
334            .on_action(cx.listener(Self::confirm))
335            .on_action(cx.listener(Self::select_next))
336            .on_action(cx.listener(Self::select_previous))
337            .on_action(cx.listener(Self::import_threads))
338            .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, cx| {
339                this.focus_handle.focus(window, cx);
340            }))
341            .child(
342                Modal::new("import-threads", None)
343                    .header(
344                        ModalHeader::new()
345                            .headline("Import ACP Threads")
346                            .description(
347                                "Import threads from your ACP agents — whether started in Zed or another client. \
348                                Choose which agents to include, and their threads will appear in your archive."
349                            )
350                            .show_dismiss_button(true),
351
352                    )
353                    .section(
354                        Section::new().child(
355                            v_flex()
356                                .id("thread-import-agent-list")
357                                .max_h(rems_from_px(320.))
358                                .pb_1()
359                                .overflow_y_scroll()
360                                .when(has_agents, |this| this.children(agent_rows))
361                                .when(!has_agents, |this| {
362                                    this.child(
363                                        Label::new("No ACP agents available.")
364                                            .color(Color::Muted)
365                                            .size(LabelSize::Small),
366                                    )
367                                }),
368                        ),
369                    )
370                    .footer(
371                        ModalFooter::new()
372                            .when_some(self.last_error.clone(), |this, error| {
373                                this.start_slot(
374                                    Label::new(error)
375                                        .size(LabelSize::Small)
376                                        .color(Color::Error)
377                                        .truncate(),
378                                )
379                            })
380                            .end_slot(
381                                Button::new("import-threads", "Import Threads")
382                                    .loading(self.is_importing)
383                                    .disabled(disabled_import_thread)
384                                    .key_binding(
385                                        KeyBinding::for_action(&menu::SecondaryConfirm, cx)
386                                            .map(|kb| kb.size(rems_from_px(12.))),
387                                    )
388                                    .on_click(cx.listener(|this, _, window, cx| {
389                                        this.import_threads(&menu::SecondaryConfirm, window, cx);
390                                    })),
391                            ),
392                    ),
393            )
394    }
395}
396
397fn resolve_agent_connection_stores(
398    multi_workspace: &Entity<MultiWorkspace>,
399    cx: &App,
400) -> Vec<Entity<AgentConnectionStore>> {
401    let mut stores = Vec::new();
402    let mut included_local_store = false;
403
404    for workspace in multi_workspace.read(cx).workspaces() {
405        let workspace = workspace.read(cx);
406        let project = workspace.project().read(cx);
407
408        // We only want to include scores from one local workspace, since we
409        // know that they live on the same machine
410        let include_store = if project.is_remote() {
411            true
412        } else if project.is_local() && !included_local_store {
413            included_local_store = true;
414            true
415        } else {
416            false
417        };
418
419        if !include_store {
420            continue;
421        }
422
423        if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
424            stores.push(panel.read(cx).connection_store().clone());
425        }
426    }
427
428    stores
429}
430
431fn find_threads_to_import(
432    agent_ids: Vec<AgentId>,
433    existing_sessions: HashSet<acp::SessionId>,
434    stores: Vec<Entity<AgentConnectionStore>>,
435    cx: &mut App,
436) -> Task<anyhow::Result<Vec<ThreadMetadata>>> {
437    let mut wait_for_connection_tasks = Vec::new();
438
439    for store in stores {
440        let remote_connection = store
441            .read(cx)
442            .project()
443            .read(cx)
444            .remote_connection_options(cx);
445
446        for agent_id in agent_ids.clone() {
447            let agent = Agent::from(agent_id.clone());
448            let server = agent.server(<dyn Fs>::global(cx), ThreadStore::global(cx));
449            let entry = store.update(cx, |store, cx| store.request_connection(agent, server, cx));
450
451            wait_for_connection_tasks.push(entry.read(cx).wait_for_connection().map({
452                let remote_connection = remote_connection.clone();
453                move |state| (agent_id, remote_connection, state)
454            }));
455        }
456    }
457
458    let mut session_list_tasks = Vec::new();
459    cx.spawn(async move |cx| {
460        let results = futures::future::join_all(wait_for_connection_tasks).await;
461        for (agent_id, remote_connection, result) in results {
462            let Some(state) = result.log_err() else {
463                continue;
464            };
465            let Some(list) = cx.update(|cx| state.connection.session_list(cx)) else {
466                continue;
467            };
468            let task = cx.update(|cx| {
469                list.list_sessions(AgentSessionListRequest::default(), cx)
470                    .map({
471                        let remote_connection = remote_connection.clone();
472                        move |response| (agent_id, remote_connection, response)
473                    })
474            });
475            session_list_tasks.push(task);
476        }
477
478        let mut sessions_by_agent = Vec::new();
479        let results = futures::future::join_all(session_list_tasks).await;
480        for (agent_id, remote_connection, result) in results {
481            let Some(response) = result.log_err() else {
482                continue;
483            };
484            sessions_by_agent.push(SessionByAgent {
485                agent_id,
486                remote_connection,
487                sessions: response.sessions,
488            });
489        }
490
491        Ok(collect_importable_threads(
492            sessions_by_agent,
493            existing_sessions,
494        ))
495    })
496}
497
498struct SessionByAgent {
499    agent_id: AgentId,
500    remote_connection: Option<RemoteConnectionOptions>,
501    sessions: Vec<acp_thread::AgentSessionInfo>,
502}
503
504fn collect_importable_threads(
505    sessions_by_agent: Vec<SessionByAgent>,
506    mut existing_sessions: HashSet<acp::SessionId>,
507) -> Vec<ThreadMetadata> {
508    let mut to_insert = Vec::new();
509    for SessionByAgent {
510        agent_id,
511        remote_connection,
512        sessions,
513    } in sessions_by_agent
514    {
515        for session in sessions {
516            if !existing_sessions.insert(session.session_id.clone()) {
517                continue;
518            }
519            let Some(folder_paths) = session.work_dirs else {
520                continue;
521            };
522            to_insert.push(ThreadMetadata {
523                session_id: session.session_id,
524                agent_id: agent_id.clone(),
525                title: session
526                    .title
527                    .unwrap_or_else(|| crate::DEFAULT_THREAD_TITLE.into()),
528                updated_at: session.updated_at.unwrap_or_else(|| Utc::now()),
529                created_at: session.created_at,
530                worktree_paths: ThreadWorktreePaths::from_folder_paths(&folder_paths),
531                remote_connection: remote_connection.clone(),
532                archived: true,
533            });
534        }
535    }
536    to_insert
537}
538
#[cfg(test)]
mod tests {
    //! Unit tests for `collect_importable_threads`, the pure part of the import flow.

    use super::*;
    use acp_thread::AgentSessionInfo;
    use chrono::Utc;
    use std::path::Path;
    use workspace::PathList;

    /// Builds an `AgentSessionInfo` with the given fields and no metadata.
    fn make_session(
        session_id: &str,
        title: Option<&str>,
        work_dirs: Option<PathList>,
        updated_at: Option<chrono::DateTime<Utc>>,
        created_at: Option<chrono::DateTime<Utc>>,
    ) -> AgentSessionInfo {
        AgentSessionInfo {
            session_id: acp::SessionId::new(session_id),
            title: title.map(|t| SharedString::from(t.to_string())),
            work_dirs,
            updated_at,
            created_at,
            meta: None,
        }
    }

    #[test]
    fn test_collect_skips_sessions_already_in_existing_set() {
        let existing = HashSet::from_iter(vec![acp::SessionId::new("existing-1")]);
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![SessionByAgent {
            agent_id: AgentId::new("agent-a"),
            remote_connection: None,
            sessions: vec![
                make_session(
                    "existing-1",
                    Some("Already There"),
                    Some(paths.clone()),
                    None,
                    None,
                ),
                make_session("new-1", Some("Brand New"), Some(paths), None, None),
            ],
        }];

        let result = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].session_id.0.as_ref(), "new-1");
        assert_eq!(result[0].title.as_ref(), "Brand New");
    }

    #[test]
    fn test_collect_skips_sessions_without_work_dirs() {
        let existing = HashSet::default();
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![SessionByAgent {
            agent_id: AgentId::new("agent-a"),
            remote_connection: None,
            sessions: vec![
                make_session("has-dirs", Some("With Dirs"), Some(paths), None, None),
                make_session("no-dirs", Some("No Dirs"), None, None, None),
            ],
        }];

        let result = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].session_id.0.as_ref(), "has-dirs");
    }

    #[test]
    fn test_collect_marks_all_imported_threads_as_archived() {
        let existing = HashSet::default();
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![SessionByAgent {
            agent_id: AgentId::new("agent-a"),
            remote_connection: None,
            sessions: vec![
                make_session("s1", Some("Thread 1"), Some(paths.clone()), None, None),
                make_session("s2", Some("Thread 2"), Some(paths), None, None),
            ],
        }];

        let result = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(result.len(), 2);
        assert!(result.iter().all(|t| t.archived));
    }

    #[test]
    fn test_collect_assigns_correct_agent_id_per_session() {
        let existing = HashSet::default();
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![
            SessionByAgent {
                agent_id: AgentId::new("agent-a"),
                remote_connection: None,
                sessions: vec![make_session(
                    "s1",
                    Some("From A"),
                    Some(paths.clone()),
                    None,
                    None,
                )],
            },
            SessionByAgent {
                agent_id: AgentId::new("agent-b"),
                remote_connection: None,
                sessions: vec![make_session("s2", Some("From B"), Some(paths), None, None)],
            },
        ];

        let result = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(result.len(), 2);
        let s1 = result
            .iter()
            .find(|t| t.session_id.0.as_ref() == "s1")
            .unwrap();
        let s2 = result
            .iter()
            .find(|t| t.session_id.0.as_ref() == "s2")
            .unwrap();
        assert_eq!(s1.agent_id.as_ref(), "agent-a");
        assert_eq!(s2.agent_id.as_ref(), "agent-b");
    }

    #[test]
    fn test_collect_deduplicates_across_agents() {
        let existing = HashSet::default();
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![
            SessionByAgent {
                agent_id: AgentId::new("agent-a"),
                remote_connection: None,
                sessions: vec![make_session(
                    "shared-session",
                    Some("From A"),
                    Some(paths.clone()),
                    None,
                    None,
                )],
            },
            SessionByAgent {
                agent_id: AgentId::new("agent-b"),
                remote_connection: None,
                sessions: vec![make_session(
                    "shared-session",
                    Some("From B"),
                    Some(paths),
                    None,
                    None,
                )],
            },
        ];

        let result = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].session_id.0.as_ref(), "shared-session");
        assert_eq!(
            result[0].agent_id.as_ref(),
            "agent-a",
            "first agent encountered should win"
        );
    }

    #[test]
    fn test_collect_all_existing_returns_empty() {
        let paths = PathList::new(&[Path::new("/project")]);
        let existing =
            HashSet::from_iter(vec![acp::SessionId::new("s1"), acp::SessionId::new("s2")]);

        let sessions_by_agent = vec![SessionByAgent {
            agent_id: AgentId::new("agent-a"),
            remote_connection: None,
            sessions: vec![
                make_session("s1", Some("T1"), Some(paths.clone()), None, None),
                make_session("s2", Some("T2"), Some(paths), None, None),
            ],
        }];

        let result = collect_importable_threads(sessions_by_agent, existing);
        assert!(result.is_empty());
    }

    #[test]
    fn test_collect_preserves_session_timestamps_and_connection() {
        let paths = PathList::new(&[Path::new("/project")]);
        let updated_at = chrono::DateTime::from_timestamp(1_700_000_000, 0).unwrap();
        let created_at = chrono::DateTime::from_timestamp(1_600_000_000, 0).unwrap();

        let sessions_by_agent = vec![SessionByAgent {
            agent_id: AgentId::new("agent-a"),
            remote_connection: None,
            sessions: vec![make_session(
                "dated",
                Some("Dated"),
                Some(paths),
                Some(updated_at),
                Some(created_at),
            )],
        }];

        let result = collect_importable_threads(sessions_by_agent, HashSet::default());

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].updated_at, updated_at);
        assert_eq!(result[0].created_at, Some(created_at));
        assert!(result[0].remote_connection.is_none());
    }
}