thread_import.rs

  1use acp_thread::AgentSessionListRequest;
  2use agent::ThreadStore;
  3use agent_client_protocol as acp;
  4use chrono::Utc;
  5use collections::HashSet;
  6use db::kvp::Dismissable;
  7use fs::Fs;
  8use futures::FutureExt as _;
  9use gpui::{
 10    App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent,
 11    Render, SharedString, Task, WeakEntity, Window,
 12};
 13use notifications::status_toast::{StatusToast, ToastIcon};
 14use project::{AgentId, AgentRegistryStore, AgentServerStore};
 15use ui::{
 16    Checkbox, KeyBinding, ListItem, ListItemSpacing, Modal, ModalFooter, ModalHeader, Section,
 17    prelude::*,
 18};
 19use util::ResultExt;
 20use workspace::{ModalView, MultiWorkspace, Workspace};
 21
 22use crate::{
 23    Agent, AgentPanel,
 24    agent_connection_store::AgentConnectionStore,
 25    thread_metadata_store::{ThreadMetadata, ThreadMetadataStore},
 26};
 27
 28pub struct AcpThreadImportOnboarding;
 29
 30impl AcpThreadImportOnboarding {
 31    pub fn dismissed(cx: &App) -> bool {
 32        <Self as Dismissable>::dismissed(cx)
 33    }
 34
 35    pub fn dismiss(cx: &mut App) {
 36        <Self as Dismissable>::set_dismissed(true, cx);
 37    }
 38}
 39
 40impl Dismissable for AcpThreadImportOnboarding {
 41    const KEY: &'static str = "dismissed-acp-thread-import";
 42}
 43
 44#[derive(Clone)]
 45struct AgentEntry {
 46    agent_id: AgentId,
 47    display_name: SharedString,
 48    icon_path: Option<SharedString>,
 49}
 50
 51pub struct ThreadImportModal {
 52    focus_handle: FocusHandle,
 53    workspace: WeakEntity<Workspace>,
 54    multi_workspace: WeakEntity<MultiWorkspace>,
 55    agent_entries: Vec<AgentEntry>,
 56    unchecked_agents: HashSet<AgentId>,
 57    selected_index: Option<usize>,
 58    is_importing: bool,
 59    last_error: Option<SharedString>,
 60}
 61
 62impl ThreadImportModal {
 63    pub fn new(
 64        agent_server_store: Entity<AgentServerStore>,
 65        agent_registry_store: Entity<AgentRegistryStore>,
 66        workspace: WeakEntity<Workspace>,
 67        multi_workspace: WeakEntity<MultiWorkspace>,
 68        _window: &mut Window,
 69        cx: &mut Context<Self>,
 70    ) -> Self {
 71        AcpThreadImportOnboarding::dismiss(cx);
 72
 73        let agent_entries = agent_server_store
 74            .read(cx)
 75            .external_agents()
 76            .map(|agent_id| {
 77                let display_name = agent_server_store
 78                    .read(cx)
 79                    .agent_display_name(agent_id)
 80                    .or_else(|| {
 81                        agent_registry_store
 82                            .read(cx)
 83                            .agent(agent_id)
 84                            .map(|agent| agent.name().clone())
 85                    })
 86                    .unwrap_or_else(|| agent_id.0.clone());
 87                let icon_path = agent_server_store
 88                    .read(cx)
 89                    .agent_icon(agent_id)
 90                    .or_else(|| {
 91                        agent_registry_store
 92                            .read(cx)
 93                            .agent(agent_id)
 94                            .and_then(|agent| agent.icon_path().cloned())
 95                    });
 96
 97                AgentEntry {
 98                    agent_id: agent_id.clone(),
 99                    display_name,
100                    icon_path,
101                }
102            })
103            .collect::<Vec<_>>();
104
105        Self {
106            focus_handle: cx.focus_handle(),
107            workspace,
108            multi_workspace,
109            agent_entries,
110            unchecked_agents: HashSet::default(),
111            selected_index: None,
112            is_importing: false,
113            last_error: None,
114        }
115    }
116
117    fn agent_ids(&self) -> Vec<AgentId> {
118        self.agent_entries
119            .iter()
120            .map(|entry| entry.agent_id.clone())
121            .collect()
122    }
123
124    fn set_agent_checked(&mut self, agent_id: AgentId, state: ToggleState, cx: &mut Context<Self>) {
125        match state {
126            ToggleState::Selected => {
127                self.unchecked_agents.remove(&agent_id);
128            }
129            ToggleState::Unselected | ToggleState::Indeterminate => {
130                self.unchecked_agents.insert(agent_id);
131            }
132        }
133        cx.notify();
134    }
135
136    fn toggle_agent_checked(&mut self, agent_id: AgentId, cx: &mut Context<Self>) {
137        if self.unchecked_agents.contains(&agent_id) {
138            self.unchecked_agents.remove(&agent_id);
139        } else {
140            self.unchecked_agents.insert(agent_id);
141        }
142        cx.notify();
143    }
144
145    fn select_next(&mut self, _: &menu::SelectNext, _window: &mut Window, cx: &mut Context<Self>) {
146        if self.agent_entries.is_empty() {
147            return;
148        }
149        self.selected_index = Some(match self.selected_index {
150            Some(ix) if ix + 1 >= self.agent_entries.len() => 0,
151            Some(ix) => ix + 1,
152            None => 0,
153        });
154        cx.notify();
155    }
156
157    fn select_previous(
158        &mut self,
159        _: &menu::SelectPrevious,
160        _window: &mut Window,
161        cx: &mut Context<Self>,
162    ) {
163        if self.agent_entries.is_empty() {
164            return;
165        }
166        self.selected_index = Some(match self.selected_index {
167            Some(0) => self.agent_entries.len() - 1,
168            Some(ix) => ix - 1,
169            None => self.agent_entries.len() - 1,
170        });
171        cx.notify();
172    }
173
174    fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
175        if let Some(ix) = self.selected_index {
176            if let Some(entry) = self.agent_entries.get(ix) {
177                self.toggle_agent_checked(entry.agent_id.clone(), cx);
178            }
179        }
180    }
181
182    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
183        cx.emit(DismissEvent);
184    }
185
186    fn import_threads(
187        &mut self,
188        _: &menu::SecondaryConfirm,
189        _: &mut Window,
190        cx: &mut Context<Self>,
191    ) {
192        if self.is_importing {
193            return;
194        }
195
196        let Some(multi_workspace) = self.multi_workspace.upgrade() else {
197            self.is_importing = false;
198            cx.notify();
199            return;
200        };
201
202        let stores = resolve_agent_connection_stores(&multi_workspace, cx);
203        if stores.is_empty() {
204            log::error!("Did not find any workspaces to import from");
205            self.is_importing = false;
206            cx.notify();
207            return;
208        }
209
210        self.is_importing = true;
211        self.last_error = None;
212        cx.notify();
213
214        let agent_ids = self
215            .agent_ids()
216            .into_iter()
217            .filter(|agent_id| !self.unchecked_agents.contains(agent_id))
218            .collect::<Vec<_>>();
219
220        let existing_sessions = ThreadMetadataStore::global(cx)
221            .read(cx)
222            .entry_ids()
223            .collect::<HashSet<_>>();
224
225        let task = find_threads_to_import(agent_ids, existing_sessions, stores, cx);
226        cx.spawn(async move |this, cx| {
227            let result = task.await;
228            this.update(cx, |this, cx| match result {
229                Ok(threads) => {
230                    let imported_count = threads.len();
231                    ThreadMetadataStore::global(cx)
232                        .update(cx, |store, cx| store.save_all(threads, cx));
233                    this.is_importing = false;
234                    this.last_error = None;
235                    this.show_imported_threads_toast(imported_count, cx);
236                    cx.emit(DismissEvent);
237                }
238                Err(error) => {
239                    this.is_importing = false;
240                    this.last_error = Some(error.to_string().into());
241                    cx.notify();
242                }
243            })
244        })
245        .detach_and_log_err(cx);
246    }
247
248    fn show_imported_threads_toast(&self, imported_count: usize, cx: &mut App) {
249        let status_toast = if imported_count == 0 {
250            StatusToast::new("No threads found to import.", cx, |this, _cx| {
251                this.icon(ToastIcon::new(IconName::Info).color(Color::Muted))
252                    .dismiss_button(true)
253            })
254        } else {
255            let message = if imported_count == 1 {
256                "Imported 1 thread.".to_string()
257            } else {
258                format!("Imported {imported_count} threads.")
259            };
260            StatusToast::new(message, cx, |this, _cx| {
261                this.icon(ToastIcon::new(IconName::Check).color(Color::Success))
262                    .dismiss_button(true)
263            })
264        };
265
266        self.workspace
267            .update(cx, |workspace, cx| {
268                workspace.toggle_status_toast(status_toast, cx);
269            })
270            .log_err();
271    }
272}
273
274impl EventEmitter<DismissEvent> for ThreadImportModal {}
275
276impl Focusable for ThreadImportModal {
277    fn focus_handle(&self, _cx: &App) -> FocusHandle {
278        self.focus_handle.clone()
279    }
280}
281
282impl ModalView for ThreadImportModal {}
283
284impl Render for ThreadImportModal {
285    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
286        let agent_rows = self
287            .agent_entries
288            .iter()
289            .enumerate()
290            .map(|(ix, entry)| {
291                let is_checked = !self.unchecked_agents.contains(&entry.agent_id);
292                let is_focused = self.selected_index == Some(ix);
293
294                ListItem::new(("thread-import-agent", ix))
295                    .rounded()
296                    .spacing(ListItemSpacing::Sparse)
297                    .focused(is_focused)
298                    .child(
299                        h_flex()
300                            .w_full()
301                            .gap_2()
302                            .when(!is_checked, |this| this.opacity(0.6))
303                            .child(if let Some(icon_path) = entry.icon_path.clone() {
304                                Icon::from_external_svg(icon_path)
305                                    .color(Color::Muted)
306                                    .size(IconSize::Small)
307                            } else {
308                                Icon::new(IconName::Sparkle)
309                                    .color(Color::Muted)
310                                    .size(IconSize::Small)
311                            })
312                            .child(Label::new(entry.display_name.clone())),
313                    )
314                    .end_slot(
315                        Checkbox::new(
316                            ("thread-import-agent-checkbox", ix),
317                            if is_checked {
318                                ToggleState::Selected
319                            } else {
320                                ToggleState::Unselected
321                            },
322                        )
323                        .on_click({
324                            let agent_id = entry.agent_id.clone();
325                            cx.listener(move |this, state: &ToggleState, _window, cx| {
326                                this.set_agent_checked(agent_id.clone(), *state, cx);
327                            })
328                        }),
329                    )
330                    .on_click({
331                        let agent_id = entry.agent_id.clone();
332                        cx.listener(move |this, _event, _window, cx| {
333                            this.toggle_agent_checked(agent_id.clone(), cx);
334                        })
335                    })
336            })
337            .collect::<Vec<_>>();
338
339        let has_agents = !self.agent_entries.is_empty();
340        let disabled_import_thread = self.is_importing
341            || !has_agents
342            || self.unchecked_agents.len() == self.agent_entries.len();
343
344        v_flex()
345            .id("thread-import-modal")
346            .key_context("ThreadImportModal")
347            .w(rems(34.))
348            .elevation_3(cx)
349            .overflow_hidden()
350            .track_focus(&self.focus_handle)
351            .on_action(cx.listener(Self::cancel))
352            .on_action(cx.listener(Self::confirm))
353            .on_action(cx.listener(Self::select_next))
354            .on_action(cx.listener(Self::select_previous))
355            .on_action(cx.listener(Self::import_threads))
356            .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, cx| {
357                this.focus_handle.focus(window, cx);
358            }))
359            .child(
360                Modal::new("import-threads", None)
361                    .header(
362                        ModalHeader::new()
363                            .headline("Import ACP Threads")
364                            .description(
365                                "Import threads from your ACP agents — whether started in Zed or another client. \
366                                Choose which agents to include, and their threads will appear in your archive."
367                            )
368                            .show_dismiss_button(true),
369
370                    )
371                    .section(
372                        Section::new().child(
373                            v_flex()
374                                .id("thread-import-agent-list")
375                                .max_h(rems_from_px(320.))
376                                .pb_2()
377                                .overflow_y_scroll()
378                                .when(has_agents, |this| this.children(agent_rows))
379                                .when(!has_agents, |this| {
380                                    this.child(
381                                        Label::new("No ACP agents available.")
382                                            .color(Color::Muted)
383                                            .size(LabelSize::Small),
384                                    )
385                                }),
386                        ),
387                    )
388                    .footer(
389                        ModalFooter::new()
390                            .when_some(self.last_error.clone(), |this, error| {
391                                this.start_slot(
392                                    Label::new(error)
393                                        .size(LabelSize::Small)
394                                        .color(Color::Error)
395                                        .truncate(),
396                                )
397                            })
398                            .end_slot(
399                                Button::new("import-threads", "Import Threads")
400                                    .loading(self.is_importing)
401                                    .disabled(disabled_import_thread)
402                                    .key_binding(
403                                        KeyBinding::for_action(&menu::SecondaryConfirm, cx)
404                                            .map(|kb| kb.size(rems_from_px(12.))),
405                                    )
406                                    .on_click(cx.listener(|this, _, window, cx| {
407                                        this.import_threads(&menu::SecondaryConfirm, window, cx);
408                                    })),
409                            ),
410                    ),
411            )
412    }
413}
414
415fn resolve_agent_connection_stores(
416    multi_workspace: &Entity<MultiWorkspace>,
417    cx: &App,
418) -> Vec<Entity<AgentConnectionStore>> {
419    let mut stores = Vec::new();
420    let mut included_local_store = false;
421
422    for workspace in multi_workspace.read(cx).workspaces() {
423        let workspace = workspace.read(cx);
424        let project = workspace.project().read(cx);
425
426        // We only want to include scores from one local workspace, since we
427        // know that they live on the same machine
428        let include_store = if project.is_remote() {
429            true
430        } else if project.is_local() && !included_local_store {
431            included_local_store = true;
432            true
433        } else {
434            false
435        };
436
437        if !include_store {
438            continue;
439        }
440
441        if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
442            stores.push(panel.read(cx).connection_store().clone());
443        }
444    }
445
446    stores
447}
448
449fn find_threads_to_import(
450    agent_ids: Vec<AgentId>,
451    existing_sessions: HashSet<acp::SessionId>,
452    stores: Vec<Entity<AgentConnectionStore>>,
453    cx: &mut App,
454) -> Task<anyhow::Result<Vec<ThreadMetadata>>> {
455    let mut wait_for_connection_tasks = Vec::new();
456
457    for store in stores {
458        for agent_id in agent_ids.clone() {
459            let agent = Agent::from(agent_id.clone());
460            let server = agent.server(<dyn Fs>::global(cx), ThreadStore::global(cx));
461            let entry = store.update(cx, |store, cx| store.request_connection(agent, server, cx));
462            wait_for_connection_tasks
463                .push(entry.read(cx).wait_for_connection().map(|s| (agent_id, s)));
464        }
465    }
466
467    let mut session_list_tasks = Vec::new();
468    cx.spawn(async move |cx| {
469        let results = futures::future::join_all(wait_for_connection_tasks).await;
470        for (agent, result) in results {
471            let Some(state) = result.log_err() else {
472                continue;
473            };
474            let Some(list) = cx.update(|cx| state.connection.session_list(cx)) else {
475                continue;
476            };
477            let task = cx.update(|cx| {
478                list.list_sessions(AgentSessionListRequest::default(), cx)
479                    .map(|r| (agent, r))
480            });
481            session_list_tasks.push(task);
482        }
483
484        let mut sessions_by_agent = Vec::new();
485        let results = futures::future::join_all(session_list_tasks).await;
486        for (agent_id, result) in results {
487            let Some(response) = result.log_err() else {
488                continue;
489            };
490            sessions_by_agent.push((agent_id, response.sessions));
491        }
492
493        Ok(collect_importable_threads(
494            sessions_by_agent,
495            existing_sessions,
496        ))
497    })
498}
499
500fn collect_importable_threads(
501    sessions_by_agent: Vec<(AgentId, Vec<acp_thread::AgentSessionInfo>)>,
502    mut existing_sessions: HashSet<acp::SessionId>,
503) -> Vec<ThreadMetadata> {
504    let mut to_insert = Vec::new();
505    for (agent_id, sessions) in sessions_by_agent {
506        for session in sessions {
507            if !existing_sessions.insert(session.session_id.clone()) {
508                continue;
509            }
510            let Some(folder_paths) = session.work_dirs else {
511                continue;
512            };
513            to_insert.push(ThreadMetadata {
514                session_id: session.session_id,
515                agent_id: agent_id.clone(),
516                title: session
517                    .title
518                    .unwrap_or_else(|| crate::DEFAULT_THREAD_TITLE.into()),
519                updated_at: session.updated_at.unwrap_or_else(|| Utc::now()),
520                created_at: session.created_at,
521                folder_paths,
522                archived: true,
523            });
524        }
525    }
526    to_insert
527}
528
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::AgentSessionInfo;
    use chrono::Utc;
    use std::path::Path;
    use workspace::PathList;

    /// Builds an `AgentSessionInfo` with the given fields and no extra metadata.
    fn make_session(
        session_id: &str,
        title: Option<&str>,
        work_dirs: Option<PathList>,
        updated_at: Option<chrono::DateTime<Utc>>,
        created_at: Option<chrono::DateTime<Utc>>,
    ) -> AgentSessionInfo {
        AgentSessionInfo {
            session_id: acp::SessionId::new(session_id),
            title: title.map(|t| SharedString::from(t.to_string())),
            work_dirs,
            updated_at,
            created_at,
            meta: None,
        }
    }

    #[test]
    fn test_collect_skips_sessions_already_in_existing_set() {
        let existing = HashSet::from_iter(vec![acp::SessionId::new("existing-1")]);
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![
                make_session(
                    "existing-1",
                    Some("Already There"),
                    Some(paths.clone()),
                    None,
                    None,
                ),
                make_session("new-1", Some("Brand New"), Some(paths), None, None),
            ],
        )];

        let result = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].session_id.0.as_ref(), "new-1");
        assert_eq!(result[0].title.as_ref(), "Brand New");
    }

    #[test]
    fn test_collect_skips_sessions_without_work_dirs() {
        let existing = HashSet::default();
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![
                make_session("has-dirs", Some("With Dirs"), Some(paths), None, None),
                make_session("no-dirs", Some("No Dirs"), None, None, None),
            ],
        )];

        let result = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].session_id.0.as_ref(), "has-dirs");
    }

    #[test]
    fn test_collect_marks_all_imported_threads_as_archived() {
        let existing = HashSet::default();
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![
                make_session("s1", Some("Thread 1"), Some(paths.clone()), None, None),
                make_session("s2", Some("Thread 2"), Some(paths), None, None),
            ],
        )];

        let result = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(result.len(), 2);
        assert!(result.iter().all(|t| t.archived));
    }

    #[test]
    fn test_collect_assigns_correct_agent_id_per_session() {
        let existing = HashSet::default();
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![
            (
                AgentId::new("agent-a"),
                vec![make_session(
                    "s1",
                    Some("From A"),
                    Some(paths.clone()),
                    None,
                    None,
                )],
            ),
            (
                AgentId::new("agent-b"),
                vec![make_session("s2", Some("From B"), Some(paths), None, None)],
            ),
        ];

        let result = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(result.len(), 2);
        let s1 = result
            .iter()
            .find(|t| t.session_id.0.as_ref() == "s1")
            .unwrap();
        let s2 = result
            .iter()
            .find(|t| t.session_id.0.as_ref() == "s2")
            .unwrap();
        assert_eq!(s1.agent_id.as_ref(), "agent-a");
        assert_eq!(s2.agent_id.as_ref(), "agent-b");
    }

    #[test]
    fn test_collect_deduplicates_across_agents() {
        let existing = HashSet::default();
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![
            (
                AgentId::new("agent-a"),
                vec![make_session(
                    "shared-session",
                    Some("From A"),
                    Some(paths.clone()),
                    None,
                    None,
                )],
            ),
            (
                AgentId::new("agent-b"),
                vec![make_session(
                    "shared-session",
                    Some("From B"),
                    Some(paths),
                    None,
                    None,
                )],
            ),
        ];

        let result = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].session_id.0.as_ref(), "shared-session");
        assert_eq!(
            result[0].agent_id.as_ref(),
            "agent-a",
            "first agent encountered should win"
        );
    }

    #[test]
    fn test_collect_all_existing_returns_empty() {
        let paths = PathList::new(&[Path::new("/project")]);
        let existing =
            HashSet::from_iter(vec![acp::SessionId::new("s1"), acp::SessionId::new("s2")]);

        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![
                make_session("s1", Some("T1"), Some(paths.clone()), None, None),
                make_session("s2", Some("T2"), Some(paths), None, None),
            ],
        )];

        let result = collect_importable_threads(sessions_by_agent, existing);
        assert!(result.is_empty());
    }

    #[test]
    fn test_collect_uses_default_title_when_missing() {
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![make_session("untitled", None, Some(paths), None, None)],
        )];

        let result = collect_importable_threads(sessions_by_agent, HashSet::default());

        let expected_title: SharedString = crate::DEFAULT_THREAD_TITLE.into();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].title, expected_title);
    }

    #[test]
    fn test_collect_preserves_session_timestamps() {
        let paths = PathList::new(&[Path::new("/project")]);
        let created_at = Utc::now();
        let updated_at = Utc::now();

        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![make_session(
                "dated",
                Some("Dated"),
                Some(paths),
                Some(updated_at),
                Some(created_at),
            )],
        )];

        let result = collect_importable_threads(sessions_by_agent, HashSet::default());

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].updated_at, updated_at);
        assert_eq!(result[0].created_at, Some(created_at));
    }

    #[test]
    fn test_collect_defaults_missing_updated_at_to_now() {
        let paths = PathList::new(&[Path::new("/project")]);

        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![make_session("undated", Some("Undated"), Some(paths), None, None)],
        )];

        let before = Utc::now();
        let result = collect_importable_threads(sessions_by_agent, HashSet::default());
        let after = Utc::now();

        assert_eq!(result.len(), 1);
        assert!(before <= result[0].updated_at && result[0].updated_at <= after);
        assert_eq!(result[0].created_at, None);
    }
}