thread_import.rs

  1use acp_thread::AgentSessionListRequest;
  2use agent::ThreadStore;
  3use agent_client_protocol as acp;
  4use chrono::Utc;
  5use collections::HashSet;
  6use db::kvp::Dismissable;
  7use fs::Fs;
  8use futures::FutureExt as _;
  9use gpui::{
 10    App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent,
 11    Render, SharedString, Task, WeakEntity, Window,
 12};
 13use notifications::status_toast::{StatusToast, ToastIcon};
 14use project::{AgentId, AgentRegistryStore, AgentServerStore};
 15use ui::{
 16    Checkbox, KeyBinding, ListItem, ListItemSpacing, Modal, ModalFooter, ModalHeader, Section,
 17    prelude::*,
 18};
 19use util::ResultExt;
 20use workspace::{ModalView, MultiWorkspace, PathList, Workspace};
 21
 22use crate::{
 23    Agent, AgentPanel,
 24    agent_connection_store::AgentConnectionStore,
 25    thread_metadata_store::{ThreadMetadata, ThreadMetadataStore},
 26};
 27
 28pub struct AcpThreadImportOnboarding;
 29
 30impl AcpThreadImportOnboarding {
 31    pub fn dismissed(cx: &App) -> bool {
 32        <Self as Dismissable>::dismissed(cx)
 33    }
 34
 35    pub fn dismiss(cx: &mut App) {
 36        <Self as Dismissable>::set_dismissed(true, cx);
 37    }
 38}
 39
// Hooks the onboarding marker into `db::kvp::Dismissable`, which provides
// `dismissed` / `set_dismissed` keyed by `KEY`.
impl Dismissable for AcpThreadImportOnboarding {
    /// Key under which the dismissed flag for this onboarding is recorded.
    const KEY: &'static str = "dismissed-acp-thread-import";
}
 43
/// One row of the import modal: an external agent and how to display it.
#[derive(Clone)]
struct AgentEntry {
    /// Identifies the agent when requesting connections and tracking its checkbox.
    agent_id: AgentId,
    /// Name shown in the row (resolved in `ThreadImportModal::new`).
    display_name: SharedString,
    /// External SVG icon for the row; a generic icon is rendered when `None`.
    icon_path: Option<SharedString>,
}
 50
/// Modal that lets the user choose external ACP agents and import their
/// existing sessions as archived threads.
pub struct ThreadImportModal {
    focus_handle: FocusHandle,
    /// Workspace that receives the post-import status toast.
    workspace: WeakEntity<Workspace>,
    /// Source of the workspaces whose agent connection stores are queried.
    multi_workspace: WeakEntity<MultiWorkspace>,
    /// Agents listed in the modal, in display order.
    agent_entries: Vec<AgentEntry>,
    /// Agents the user opted out of; every agent starts out checked.
    unchecked_agents: HashSet<AgentId>,
    /// Row highlighted by keyboard navigation, if any.
    selected_index: Option<usize>,
    /// True while an import task is in flight.
    is_importing: bool,
    /// Message from the last failed import, shown in the footer.
    last_error: Option<SharedString>,
}
 61
 62impl ThreadImportModal {
 63    pub fn new(
 64        agent_server_store: Entity<AgentServerStore>,
 65        agent_registry_store: Entity<AgentRegistryStore>,
 66        workspace: WeakEntity<Workspace>,
 67        multi_workspace: WeakEntity<MultiWorkspace>,
 68        _window: &mut Window,
 69        cx: &mut Context<Self>,
 70    ) -> Self {
 71        AcpThreadImportOnboarding::dismiss(cx);
 72
 73        let agent_entries = agent_server_store
 74            .read(cx)
 75            .external_agents()
 76            .map(|agent_id| {
 77                let display_name = agent_server_store
 78                    .read(cx)
 79                    .agent_display_name(agent_id)
 80                    .or_else(|| {
 81                        agent_registry_store
 82                            .read(cx)
 83                            .agent(agent_id)
 84                            .map(|agent| agent.name().clone())
 85                    })
 86                    .unwrap_or_else(|| agent_id.0.clone());
 87                let icon_path = agent_server_store
 88                    .read(cx)
 89                    .agent_icon(agent_id)
 90                    .or_else(|| {
 91                        agent_registry_store
 92                            .read(cx)
 93                            .agent(agent_id)
 94                            .and_then(|agent| agent.icon_path().cloned())
 95                    });
 96
 97                AgentEntry {
 98                    agent_id: agent_id.clone(),
 99                    display_name,
100                    icon_path,
101                }
102            })
103            .collect::<Vec<_>>();
104
105        Self {
106            focus_handle: cx.focus_handle(),
107            workspace,
108            multi_workspace,
109            agent_entries,
110            unchecked_agents: HashSet::default(),
111            selected_index: None,
112            is_importing: false,
113            last_error: None,
114        }
115    }
116
117    fn agent_ids(&self) -> Vec<AgentId> {
118        self.agent_entries
119            .iter()
120            .map(|entry| entry.agent_id.clone())
121            .collect()
122    }
123
124    fn toggle_agent_checked(&mut self, agent_id: AgentId, cx: &mut Context<Self>) {
125        if self.unchecked_agents.contains(&agent_id) {
126            self.unchecked_agents.remove(&agent_id);
127        } else {
128            self.unchecked_agents.insert(agent_id);
129        }
130        cx.notify();
131    }
132
133    fn select_next(&mut self, _: &menu::SelectNext, _window: &mut Window, cx: &mut Context<Self>) {
134        if self.agent_entries.is_empty() {
135            return;
136        }
137        self.selected_index = Some(match self.selected_index {
138            Some(ix) if ix + 1 >= self.agent_entries.len() => 0,
139            Some(ix) => ix + 1,
140            None => 0,
141        });
142        cx.notify();
143    }
144
145    fn select_previous(
146        &mut self,
147        _: &menu::SelectPrevious,
148        _window: &mut Window,
149        cx: &mut Context<Self>,
150    ) {
151        if self.agent_entries.is_empty() {
152            return;
153        }
154        self.selected_index = Some(match self.selected_index {
155            Some(0) => self.agent_entries.len() - 1,
156            Some(ix) => ix - 1,
157            None => self.agent_entries.len() - 1,
158        });
159        cx.notify();
160    }
161
162    fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
163        if let Some(ix) = self.selected_index {
164            if let Some(entry) = self.agent_entries.get(ix) {
165                self.toggle_agent_checked(entry.agent_id.clone(), cx);
166            }
167        }
168    }
169
170    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
171        cx.emit(DismissEvent);
172    }
173
174    fn import_threads(
175        &mut self,
176        _: &menu::SecondaryConfirm,
177        _: &mut Window,
178        cx: &mut Context<Self>,
179    ) {
180        if self.is_importing {
181            return;
182        }
183
184        let Some(multi_workspace) = self.multi_workspace.upgrade() else {
185            self.is_importing = false;
186            cx.notify();
187            return;
188        };
189
190        let stores = resolve_agent_connection_stores(&multi_workspace, cx);
191        if stores.is_empty() {
192            log::error!("Did not find any workspaces to import from");
193            self.is_importing = false;
194            cx.notify();
195            return;
196        }
197
198        self.is_importing = true;
199        self.last_error = None;
200        cx.notify();
201
202        let agent_ids = self
203            .agent_ids()
204            .into_iter()
205            .filter(|agent_id| !self.unchecked_agents.contains(agent_id))
206            .collect::<Vec<_>>();
207
208        let existing_sessions = ThreadMetadataStore::global(cx)
209            .read(cx)
210            .entry_ids()
211            .collect::<HashSet<_>>();
212
213        let task = find_threads_to_import(agent_ids, existing_sessions, stores, cx);
214        cx.spawn(async move |this, cx| {
215            let result = task.await;
216            this.update(cx, |this, cx| match result {
217                Ok(threads) => {
218                    let imported_count = threads.len();
219                    ThreadMetadataStore::global(cx)
220                        .update(cx, |store, cx| store.save_all(threads, cx));
221                    this.is_importing = false;
222                    this.last_error = None;
223                    this.show_imported_threads_toast(imported_count, cx);
224                    cx.emit(DismissEvent);
225                }
226                Err(error) => {
227                    this.is_importing = false;
228                    this.last_error = Some(error.to_string().into());
229                    cx.notify();
230                }
231            })
232        })
233        .detach_and_log_err(cx);
234    }
235
236    fn show_imported_threads_toast(&self, imported_count: usize, cx: &mut App) {
237        let status_toast = if imported_count == 0 {
238            StatusToast::new("No threads found to import.", cx, |this, _cx| {
239                this.icon(ToastIcon::new(IconName::Info).color(Color::Muted))
240                    .dismiss_button(true)
241            })
242        } else {
243            let message = if imported_count == 1 {
244                "Imported 1 thread.".to_string()
245            } else {
246                format!("Imported {imported_count} threads.")
247            };
248            StatusToast::new(message, cx, |this, _cx| {
249                this.icon(ToastIcon::new(IconName::Check).color(Color::Success))
250                    .dismiss_button(true)
251            })
252        };
253
254        self.workspace
255            .update(cx, |workspace, cx| {
256                workspace.toggle_status_toast(status_toast, cx);
257            })
258            .log_err();
259    }
260}
261
262impl EventEmitter<DismissEvent> for ThreadImportModal {}
263
264impl Focusable for ThreadImportModal {
265    fn focus_handle(&self, _cx: &App) -> FocusHandle {
266        self.focus_handle.clone()
267    }
268}
269
270impl ModalView for ThreadImportModal {}
271
impl Render for ThreadImportModal {
    /// Renders the modal: a header, a scrollable list of agent rows with
    /// checkboxes, and a footer holding the last import error (if any) plus the
    /// "Import Threads" button.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let has_agents = !self.agent_entries.is_empty();
        // The import button is unusable while an import runs, when there are no
        // agents, or when every agent has been unchecked.
        let disabled_import_thread = self.is_importing
            || !has_agents
            || self.unchecked_agents.len() == self.agent_entries.len();

        let agent_rows = self
            .agent_entries
            .iter()
            .enumerate()
            .map(|(ix, entry)| {
                let is_checked = !self.unchecked_agents.contains(&entry.agent_id);
                // Rows highlighted by keyboard navigation (`selected_index`).
                let is_focused = self.selected_index == Some(ix);

                ListItem::new(("thread-import-agent", ix))
                    .rounded()
                    .spacing(ListItemSpacing::Sparse)
                    .focused(is_focused)
                    .disabled(self.is_importing)
                    .child(
                        h_flex()
                            .w_full()
                            .gap_2()
                            // Dim agents that are excluded from the import.
                            .when(!is_checked, |this| this.opacity(0.6))
                            // Prefer the agent's own SVG icon; otherwise use a generic one.
                            .child(if let Some(icon_path) = entry.icon_path.clone() {
                                Icon::from_external_svg(icon_path)
                                    .color(Color::Muted)
                                    .size(IconSize::Small)
                            } else {
                                Icon::new(IconName::Sparkle)
                                    .color(Color::Muted)
                                    .size(IconSize::Small)
                            })
                            .child(Label::new(entry.display_name.clone())),
                    )
                    .end_slot(Checkbox::new(
                        ("thread-import-agent-checkbox", ix),
                        if is_checked {
                            ToggleState::Selected
                        } else {
                            ToggleState::Unselected
                        },
                    ))
                    // Clicking anywhere on the row toggles the agent.
                    .on_click({
                        let agent_id = entry.agent_id.clone();
                        cx.listener(move |this, _event, _window, cx| {
                            this.toggle_agent_checked(agent_id.clone(), cx);
                        })
                    })
            })
            .collect::<Vec<_>>();

        v_flex()
            .id("thread-import-modal")
            .key_context("ThreadImportModal")
            .w(rems(34.))
            .elevation_3(cx)
            .overflow_hidden()
            .track_focus(&self.focus_handle)
            // Keyboard actions: Cancel closes, Confirm toggles the selected row,
            // SelectNext/SelectPrevious move the selection, and
            // SecondaryConfirm runs the import (same as the footer button).
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::select_next))
            .on_action(cx.listener(Self::select_previous))
            .on_action(cx.listener(Self::import_threads))
            // Any click inside the modal moves focus back to it, presumably so
            // the modal's key bindings stay active.
            .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, cx| {
                this.focus_handle.focus(window, cx);
            }))
            .child(
                Modal::new("import-threads", None)
                    .header(
                        ModalHeader::new()
                            .headline("Import ACP Threads")
                            .description(
                                "Import threads from your ACP agents — whether started in Zed or another client. \
                                Choose which agents to include, and their threads will appear in your archive."
                            )
                            .show_dismiss_button(true),

                    )
                    .section(
                        Section::new().child(
                            v_flex()
                                .id("thread-import-agent-list")
                                .max_h(rems_from_px(320.))
                                .pb_1()
                                .overflow_y_scroll()
                                .when(has_agents, |this| this.children(agent_rows))
                                .when(!has_agents, |this| {
                                    this.child(
                                        Label::new("No ACP agents available.")
                                            .color(Color::Muted)
                                            .size(LabelSize::Small),
                                    )
                                }),
                        ),
                    )
                    .footer(
                        ModalFooter::new()
                            // Surface the last failed import next to the button.
                            .when_some(self.last_error.clone(), |this, error| {
                                this.start_slot(
                                    Label::new(error)
                                        .size(LabelSize::Small)
                                        .color(Color::Error)
                                        .truncate(),
                                )
                            })
                            .end_slot(
                                Button::new("import-threads", "Import Threads")
                                    .loading(self.is_importing)
                                    .disabled(disabled_import_thread)
                                    .key_binding(
                                        KeyBinding::for_action(&menu::SecondaryConfirm, cx)
                                            .map(|kb| kb.size(rems_from_px(12.))),
                                    )
                                    .on_click(cx.listener(|this, _, window, cx| {
                                        this.import_threads(&menu::SecondaryConfirm, window, cx);
                                    })),
                            ),
                    ),
            )
    }
}
395
396fn resolve_agent_connection_stores(
397    multi_workspace: &Entity<MultiWorkspace>,
398    cx: &App,
399) -> Vec<Entity<AgentConnectionStore>> {
400    let mut stores = Vec::new();
401    let mut included_local_store = false;
402
403    for workspace in multi_workspace.read(cx).workspaces() {
404        let workspace = workspace.read(cx);
405        let project = workspace.project().read(cx);
406
407        // We only want to include scores from one local workspace, since we
408        // know that they live on the same machine
409        let include_store = if project.is_remote() {
410            true
411        } else if project.is_local() && !included_local_store {
412            included_local_store = true;
413            true
414        } else {
415            false
416        };
417
418        if !include_store {
419            continue;
420        }
421
422        if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
423            stores.push(panel.read(cx).connection_store().clone());
424        }
425    }
426
427    stores
428}
429
430fn find_threads_to_import(
431    agent_ids: Vec<AgentId>,
432    existing_sessions: HashSet<acp::SessionId>,
433    stores: Vec<Entity<AgentConnectionStore>>,
434    cx: &mut App,
435) -> Task<anyhow::Result<Vec<ThreadMetadata>>> {
436    let mut wait_for_connection_tasks = Vec::new();
437
438    for store in stores {
439        for agent_id in agent_ids.clone() {
440            let agent = Agent::from(agent_id.clone());
441            let server = agent.server(<dyn Fs>::global(cx), ThreadStore::global(cx));
442            let entry = store.update(cx, |store, cx| store.request_connection(agent, server, cx));
443            wait_for_connection_tasks
444                .push(entry.read(cx).wait_for_connection().map(|s| (agent_id, s)));
445        }
446    }
447
448    let mut session_list_tasks = Vec::new();
449    cx.spawn(async move |cx| {
450        let results = futures::future::join_all(wait_for_connection_tasks).await;
451        for (agent, result) in results {
452            let Some(state) = result.log_err() else {
453                continue;
454            };
455            let Some(list) = cx.update(|cx| state.connection.session_list(cx)) else {
456                continue;
457            };
458            let task = cx.update(|cx| {
459                list.list_sessions(AgentSessionListRequest::default(), cx)
460                    .map(|r| (agent, r))
461            });
462            session_list_tasks.push(task);
463        }
464
465        let mut sessions_by_agent = Vec::new();
466        let results = futures::future::join_all(session_list_tasks).await;
467        for (agent_id, result) in results {
468            let Some(response) = result.log_err() else {
469                continue;
470            };
471            sessions_by_agent.push((agent_id, response.sessions));
472        }
473
474        Ok(collect_importable_threads(
475            sessions_by_agent,
476            existing_sessions,
477        ))
478    })
479}
480
481fn collect_importable_threads(
482    sessions_by_agent: Vec<(AgentId, Vec<acp_thread::AgentSessionInfo>)>,
483    mut existing_sessions: HashSet<acp::SessionId>,
484) -> Vec<ThreadMetadata> {
485    let mut to_insert = Vec::new();
486    for (agent_id, sessions) in sessions_by_agent {
487        for session in sessions {
488            if !existing_sessions.insert(session.session_id.clone()) {
489                continue;
490            }
491            let Some(folder_paths) = session.work_dirs else {
492                continue;
493            };
494            to_insert.push(ThreadMetadata {
495                session_id: session.session_id,
496                agent_id: agent_id.clone(),
497                title: session
498                    .title
499                    .unwrap_or_else(|| crate::DEFAULT_THREAD_TITLE.into()),
500                updated_at: session.updated_at.unwrap_or_else(|| Utc::now()),
501                created_at: session.created_at,
502                folder_paths,
503                main_worktree_paths: PathList::default(),
504                archived: true,
505            });
506        }
507    }
508    to_insert
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::AgentSessionInfo;
    use chrono::Utc;
    use std::path::Path;
    use workspace::PathList;

    /// Builds a session with the given id, optional title, optional working
    /// directories and optional timestamps; `meta` is always `None`.
    fn make_session(
        session_id: &str,
        title: Option<&str>,
        work_dirs: Option<PathList>,
        updated_at: Option<chrono::DateTime<Utc>>,
        created_at: Option<chrono::DateTime<Utc>>,
    ) -> AgentSessionInfo {
        let title = title.map(|text| SharedString::from(text.to_string()));
        AgentSessionInfo {
            session_id: acp::SessionId::new(session_id),
            title,
            work_dirs,
            updated_at,
            created_at,
            meta: None,
        }
    }

    /// The single-folder path list shared by the test sessions.
    fn project_paths() -> PathList {
        PathList::new(&[Path::new("/project")])
    }

    /// A titled session rooted at `/project` without timestamps.
    fn project_session(session_id: &str, title: &str) -> AgentSessionInfo {
        make_session(session_id, Some(title), Some(project_paths()), None, None)
    }

    /// The session ids of `threads`, in order.
    fn session_ids(threads: &[ThreadMetadata]) -> Vec<&str> {
        threads
            .iter()
            .map(|thread| thread.session_id.0.as_ref())
            .collect()
    }

    #[test]
    fn test_collect_skips_sessions_already_in_existing_set() {
        let existing: HashSet<acp::SessionId> =
            [acp::SessionId::new("existing-1")].into_iter().collect();
        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![
                project_session("existing-1", "Already There"),
                project_session("new-1", "Brand New"),
            ],
        )];

        let imported = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(session_ids(&imported), ["new-1"]);
        assert_eq!(imported[0].title.as_ref(), "Brand New");
    }

    #[test]
    fn test_collect_skips_sessions_without_work_dirs() {
        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![
                project_session("has-dirs", "With Dirs"),
                make_session("no-dirs", Some("No Dirs"), None, None, None),
            ],
        )];

        let imported = collect_importable_threads(sessions_by_agent, HashSet::default());

        assert_eq!(session_ids(&imported), ["has-dirs"]);
    }

    #[test]
    fn test_collect_marks_all_imported_threads_as_archived() {
        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![
                project_session("s1", "Thread 1"),
                project_session("s2", "Thread 2"),
            ],
        )];

        let imported = collect_importable_threads(sessions_by_agent, HashSet::default());

        assert_eq!(imported.len(), 2);
        assert!(imported.iter().all(|thread| thread.archived));
    }

    #[test]
    fn test_collect_assigns_correct_agent_id_per_session() {
        let sessions_by_agent = vec![
            (
                AgentId::new("agent-a"),
                vec![project_session("s1", "From A")],
            ),
            (
                AgentId::new("agent-b"),
                vec![project_session("s2", "From B")],
            ),
        ];

        let imported = collect_importable_threads(sessions_by_agent, HashSet::default());
        assert_eq!(imported.len(), 2);

        let thread_for = |id: &str| {
            imported
                .iter()
                .find(|thread| thread.session_id.0.as_ref() == id)
                .expect("session should have been imported")
        };
        assert_eq!(thread_for("s1").agent_id.as_ref(), "agent-a");
        assert_eq!(thread_for("s2").agent_id.as_ref(), "agent-b");
    }

    #[test]
    fn test_collect_deduplicates_across_agents() {
        let sessions_by_agent = vec![
            (
                AgentId::new("agent-a"),
                vec![project_session("shared-session", "From A")],
            ),
            (
                AgentId::new("agent-b"),
                vec![project_session("shared-session", "From B")],
            ),
        ];

        let imported = collect_importable_threads(sessions_by_agent, HashSet::default());

        assert_eq!(session_ids(&imported), ["shared-session"]);
        assert_eq!(
            imported[0].agent_id.as_ref(),
            "agent-a",
            "first agent encountered should win"
        );
    }

    #[test]
    fn test_collect_all_existing_returns_empty() {
        let existing: HashSet<acp::SessionId> =
            [acp::SessionId::new("s1"), acp::SessionId::new("s2")]
                .into_iter()
                .collect();
        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![project_session("s1", "T1"), project_session("s2", "T2")],
        )];

        let imported = collect_importable_threads(sessions_by_agent, existing);

        assert!(imported.is_empty());
    }
}