thread_import.rs

  1use acp_thread::AgentSessionListRequest;
  2use agent::ThreadStore;
  3use agent_client_protocol as acp;
  4use chrono::Utc;
  5use collections::HashSet;
  6use fs::Fs;
  7use futures::FutureExt as _;
  8use gpui::{
  9    App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent,
 10    Render, SharedString, Task, WeakEntity, Window,
 11};
 12use notifications::status_toast::{StatusToast, ToastIcon};
 13use project::{AgentId, AgentRegistryStore, AgentServerStore};
 14use ui::{Checkbox, CommonAnimationExt as _, KeyBinding, ListItem, ListItemSpacing, prelude::*};
 15use util::ResultExt;
 16use workspace::{ModalView, MultiWorkspace, Workspace};
 17
 18use crate::{
 19    Agent, AgentPanel,
 20    agent_connection_store::AgentConnectionStore,
 21    thread_metadata_store::{ThreadMetadata, ThreadMetadataStore},
 22};
 23
/// An agent row shown in the import modal.
#[derive(Clone)]
struct AgentEntry {
    /// The agent whose sessions are imported when this row is checked.
    agent_id: AgentId,
    /// Label shown in the row.
    display_name: SharedString,
    /// External SVG icon shown in the row; `None` falls back to a generic icon.
    icon_path: Option<SharedString>,
}
 30
/// Modal that lets the user pick external agents and import their existing sessions as
/// archived threads.
pub struct ThreadImportModal {
    focus_handle: FocusHandle,
    /// Workspace used to display the completion toast.
    workspace: WeakEntity<Workspace>,
    /// Source of the workspaces whose agent connection stores are queried during import.
    multi_workspace: WeakEntity<MultiWorkspace>,
    /// Agents listed in the modal, in display order.
    agent_entries: Vec<AgentEntry>,
    /// Agents the user opted out of; every listed agent not in this set is imported.
    unchecked_agents: HashSet<AgentId>,
    /// Whether an import is currently in flight.
    is_importing: bool,
    /// Message from the most recent failed import, shown in the footer.
    last_error: Option<SharedString>,
}
 40
 41impl ThreadImportModal {
 42    pub fn new(
 43        agent_server_store: Entity<AgentServerStore>,
 44        agent_registry_store: Entity<AgentRegistryStore>,
 45        workspace: WeakEntity<Workspace>,
 46        multi_workspace: WeakEntity<MultiWorkspace>,
 47        _window: &mut Window,
 48        cx: &mut Context<Self>,
 49    ) -> Self {
 50        let agent_entries = agent_server_store
 51            .read(cx)
 52            .external_agents()
 53            .map(|agent_id| {
 54                let display_name = agent_server_store
 55                    .read(cx)
 56                    .agent_display_name(agent_id)
 57                    .or_else(|| {
 58                        agent_registry_store
 59                            .read(cx)
 60                            .agent(agent_id)
 61                            .map(|agent| agent.name().clone())
 62                    })
 63                    .unwrap_or_else(|| agent_id.0.clone());
 64                let icon_path = agent_server_store
 65                    .read(cx)
 66                    .agent_icon(agent_id)
 67                    .or_else(|| {
 68                        agent_registry_store
 69                            .read(cx)
 70                            .agent(agent_id)
 71                            .and_then(|agent| agent.icon_path().cloned())
 72                    });
 73
 74                AgentEntry {
 75                    agent_id: agent_id.clone(),
 76                    display_name,
 77                    icon_path,
 78                }
 79            })
 80            .collect::<Vec<_>>();
 81
 82        Self {
 83            focus_handle: cx.focus_handle(),
 84            workspace,
 85            multi_workspace,
 86            agent_entries,
 87            unchecked_agents: HashSet::default(),
 88            is_importing: false,
 89            last_error: None,
 90        }
 91    }
 92
 93    fn agent_ids(&self) -> Vec<AgentId> {
 94        self.agent_entries
 95            .iter()
 96            .map(|entry| entry.agent_id.clone())
 97            .collect()
 98    }
 99
100    fn set_agent_checked(&mut self, agent_id: AgentId, state: ToggleState, cx: &mut Context<Self>) {
101        match state {
102            ToggleState::Selected => {
103                self.unchecked_agents.remove(&agent_id);
104            }
105            ToggleState::Unselected | ToggleState::Indeterminate => {
106                self.unchecked_agents.insert(agent_id);
107            }
108        }
109        cx.notify();
110    }
111
112    fn toggle_agent_checked(&mut self, agent_id: AgentId, cx: &mut Context<Self>) {
113        if self.unchecked_agents.contains(&agent_id) {
114            self.unchecked_agents.remove(&agent_id);
115        } else {
116            self.unchecked_agents.insert(agent_id);
117        }
118        cx.notify();
119    }
120
121    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
122        cx.emit(DismissEvent);
123    }
124
125    fn import_threads(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
126        if self.is_importing {
127            return;
128        }
129
130        let Some(multi_workspace) = self.multi_workspace.upgrade() else {
131            self.is_importing = false;
132            cx.notify();
133            return;
134        };
135
136        let stores = resolve_agent_connection_stores(&multi_workspace, cx);
137        if stores.is_empty() {
138            log::error!("Did not find any workspaces to import from");
139            self.is_importing = false;
140            cx.notify();
141            return;
142        }
143
144        self.is_importing = true;
145        self.last_error = None;
146        cx.notify();
147
148        let agent_ids = self
149            .agent_ids()
150            .into_iter()
151            .filter(|agent_id| !self.unchecked_agents.contains(agent_id))
152            .collect::<Vec<_>>();
153
154        let existing_sessions = ThreadMetadataStore::global(cx)
155            .read(cx)
156            .entry_ids()
157            .collect::<HashSet<_>>();
158
159        let task = find_threads_to_import(agent_ids, existing_sessions, stores, cx);
160        cx.spawn(async move |this, cx| {
161            let result = task.await;
162            this.update(cx, |this, cx| match result {
163                Ok(threads) => {
164                    let imported_count = threads.len();
165                    ThreadMetadataStore::global(cx)
166                        .update(cx, |store, cx| store.save_all(threads, cx));
167                    this.is_importing = false;
168                    this.last_error = None;
169                    this.show_imported_threads_toast(imported_count, cx);
170                    cx.emit(DismissEvent);
171                }
172                Err(error) => {
173                    this.is_importing = false;
174                    this.last_error = Some(error.to_string().into());
175                    cx.notify();
176                }
177            })
178        })
179        .detach_and_log_err(cx);
180    }
181
182    fn show_imported_threads_toast(&self, imported_count: usize, cx: &mut App) {
183        let status_toast = if imported_count == 0 {
184            StatusToast::new("No threads found to import.", cx, |this, _cx| {
185                this.icon(ToastIcon::new(IconName::Info).color(Color::Info))
186            })
187        } else {
188            let message = if imported_count == 1 {
189                "Imported 1 thread.".to_string()
190            } else {
191                format!("Imported {imported_count} threads.")
192            };
193            StatusToast::new(message, cx, |this, _cx| {
194                this.icon(ToastIcon::new(IconName::Check).color(Color::Success))
195            })
196        };
197
198        self.workspace
199            .update(cx, |workspace, cx| {
200                workspace.toggle_status_toast(status_toast, cx);
201            })
202            .log_err();
203    }
204}
205
// The modal emits `DismissEvent` when cancelled and after a successful import.
impl EventEmitter<DismissEvent> for ThreadImportModal {}

// Exposes the handle that the modal's root element tracks in `render`.
impl Focusable for ThreadImportModal {
    fn focus_handle(&self, _cx: &App) -> FocusHandle {
        self.focus_handle.clone()
    }
}

// Uses the default `ModalView` behavior.
impl ModalView for ThreadImportModal {}
215
impl Render for ThreadImportModal {
    /// Renders the modal. It has a header, a scrollable list of agents with checkboxes, and a
    /// footer. The footer holds the last import error (if any), a spinner while importing, and
    /// the Import button.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // One row per agent; a row is checked unless its agent is in `unchecked_agents`.
        let agent_rows = self
            .agent_entries
            .iter()
            .enumerate()
            .map(|(ix, entry)| {
                let is_checked = !self.unchecked_agents.contains(&entry.agent_id);

                ListItem::new(("thread-import-agent", ix))
                    .inset(true)
                    .spacing(ListItemSpacing::Sparse)
                    .start_slot(
                        Checkbox::new(
                            ("thread-import-agent-checkbox", ix),
                            if is_checked {
                                ToggleState::Selected
                            } else {
                                ToggleState::Unselected
                            },
                        )
                        // The checkbox hands its `ToggleState` to the handler, which applies it
                        // as-is (presumably the state being toggled to — confirm in `Checkbox`).
                        .on_click({
                            let agent_id = entry.agent_id.clone();
                            cx.listener(move |this, state: &ToggleState, _window, cx| {
                                this.set_agent_checked(agent_id.clone(), *state, cx);
                            })
                        }),
                    )
                    // Clicking the row flips the agent's current checked state.
                    .on_click({
                        let agent_id = entry.agent_id.clone();
                        cx.listener(move |this, _event, _window, cx| {
                            this.toggle_agent_checked(agent_id.clone(), cx);
                        })
                    })
                    .child(
                        h_flex()
                            .w_full()
                            .gap_2()
                            // Agents without an icon fall back to a generic one.
                            .child(if let Some(icon_path) = entry.icon_path.clone() {
                                Icon::from_external_svg(icon_path)
                                    .color(Color::Muted)
                                    .size(IconSize::Small)
                            } else {
                                Icon::new(IconName::Sparkle)
                                    .color(Color::Muted)
                                    .size(IconSize::Small)
                            })
                            .child(Label::new(entry.display_name.clone())),
                    )
            })
            .collect::<Vec<_>>();

        let has_agents = !self.agent_entries.is_empty();

        v_flex()
            .id("thread-import-modal")
            .key_context("ThreadImportModal")
            .w(rems(34.))
            .elevation_3(cx)
            .overflow_hidden()
            .track_focus(&self.focus_handle)
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::import_threads))
            // Any mouse-down inside the modal moves focus back to it, so its key context and
            // actions stay active.
            .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, cx| {
                this.focus_handle.focus(window, cx);
            }))
            // Header
            .child(
                v_flex()
                    .p_4()
                    .pb_2()
                    .gap_1()
                    .child(Headline::new("Import Threads").size(HeadlineSize::Small))
                    .child(
                        Label::new(
                            "Select the agents whose threads you'd like to import. \
                             Imported threads will appear in your thread archive.",
                        )
                        .color(Color::Muted)
                        .size(LabelSize::Small),
                    ),
            )
            // Agent list (or a placeholder when no agents are available)
            .child(
                v_flex()
                    .id("thread-import-agent-list")
                    .px_2()
                    .max_h(rems(20.))
                    .overflow_y_scroll()
                    .when(has_agents, |this| this.children(agent_rows))
                    .when(!has_agents, |this| {
                        this.child(
                            div().p_4().child(
                                Label::new("No ACP agents available.")
                                    .color(Color::Muted)
                                    .size(LabelSize::Small),
                            ),
                        )
                    }),
            )
            // Footer: error message on the left, spinner and Import button on the right
            .child(
                h_flex()
                    .w_full()
                    .p_3()
                    .gap_2()
                    .items_center()
                    .border_t_1()
                    .border_color(cx.theme().colors().border_variant)
                    .child(div().flex_1().min_w_0().when_some(
                        self.last_error.clone(),
                        |this, error| {
                            this.child(
                                Label::new(error)
                                    .size(LabelSize::Small)
                                    .color(Color::Error)
                                    .truncate(),
                            )
                        },
                    ))
                    .child(
                        h_flex()
                            .gap_2()
                            .items_center()
                            .when(self.is_importing, |this| {
                                this.child(
                                    Icon::new(IconName::ArrowCircle)
                                        .size(IconSize::Small)
                                        .color(Color::Muted)
                                        .with_rotate_animation(2),
                                )
                            })
                            .child(
                                Button::new("import-threads", "Import Threads")
                                    .disabled(self.is_importing || !has_agents)
                                    .key_binding(
                                        KeyBinding::for_action(&menu::Confirm, cx)
                                            .map(|kb| kb.size(rems_from_px(12.))),
                                    )
                                    .on_click(cx.listener(|this, _, window, cx| {
                                        this.import_threads(&menu::Confirm, window, cx);
                                    })),
                            ),
                    ),
            )
    }
}
363
364fn resolve_agent_connection_stores(
365    multi_workspace: &Entity<MultiWorkspace>,
366    cx: &App,
367) -> Vec<Entity<AgentConnectionStore>> {
368    let mut stores = Vec::new();
369    let mut included_local_store = false;
370
371    for workspace in multi_workspace.read(cx).workspaces() {
372        let workspace = workspace.read(cx);
373        let project = workspace.project().read(cx);
374
375        // We only want to include scores from one local workspace, since we
376        // know that they live on the same machine
377        let include_store = if project.is_remote() {
378            true
379        } else if project.is_local() && !included_local_store {
380            included_local_store = true;
381            true
382        } else {
383            false
384        };
385
386        if !include_store {
387            continue;
388        }
389
390        if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
391            stores.push(panel.read(cx).connection_store().clone());
392        }
393    }
394
395    stores
396}
397
398fn find_threads_to_import(
399    agent_ids: Vec<AgentId>,
400    existing_sessions: HashSet<acp::SessionId>,
401    stores: Vec<Entity<AgentConnectionStore>>,
402    cx: &mut App,
403) -> Task<anyhow::Result<Vec<ThreadMetadata>>> {
404    let mut wait_for_connection_tasks = Vec::new();
405
406    for store in stores {
407        for agent_id in agent_ids.clone() {
408            let agent = Agent::from(agent_id.clone());
409            let server = agent.server(<dyn Fs>::global(cx), ThreadStore::global(cx));
410            let entry = store.update(cx, |store, cx| store.request_connection(agent, server, cx));
411            wait_for_connection_tasks
412                .push(entry.read(cx).wait_for_connection().map(|s| (agent_id, s)));
413        }
414    }
415
416    let mut session_list_tasks = Vec::new();
417    cx.spawn(async move |cx| {
418        let results = futures::future::join_all(wait_for_connection_tasks).await;
419        for (agent, result) in results {
420            let Some(state) = result.log_err() else {
421                continue;
422            };
423            let Some(list) = cx.update(|cx| state.connection.session_list(cx)) else {
424                continue;
425            };
426            let task = cx.update(|cx| {
427                list.list_sessions(AgentSessionListRequest::default(), cx)
428                    .map(|r| (agent, r))
429            });
430            session_list_tasks.push(task);
431        }
432
433        let mut sessions_by_agent = Vec::new();
434        let results = futures::future::join_all(session_list_tasks).await;
435        for (agent_id, result) in results {
436            let Some(response) = result.log_err() else {
437                continue;
438            };
439            sessions_by_agent.push((agent_id, response.sessions));
440        }
441
442        Ok(collect_importable_threads(
443            sessions_by_agent,
444            existing_sessions,
445        ))
446    })
447}
448
449fn collect_importable_threads(
450    sessions_by_agent: Vec<(AgentId, Vec<acp_thread::AgentSessionInfo>)>,
451    mut existing_sessions: HashSet<acp::SessionId>,
452) -> Vec<ThreadMetadata> {
453    let mut to_insert = Vec::new();
454    for (agent_id, sessions) in sessions_by_agent {
455        for session in sessions {
456            if !existing_sessions.insert(session.session_id.clone()) {
457                continue;
458            }
459            let Some(folder_paths) = session.work_dirs else {
460                continue;
461            };
462            to_insert.push(ThreadMetadata {
463                session_id: session.session_id,
464                agent_id: agent_id.clone(),
465                title: session
466                    .title
467                    .unwrap_or_else(|| crate::DEFAULT_THREAD_TITLE.into()),
468                updated_at: session.updated_at.unwrap_or_else(|| Utc::now()),
469                created_at: session.created_at,
470                folder_paths,
471                archived: true,
472            });
473        }
474    }
475    to_insert
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::AgentSessionInfo;
    use chrono::Utc;
    use std::path::Path;
    use workspace::PathList;

    /// Builds an [`AgentSessionInfo`] from the given fields, leaving `meta` empty.
    fn make_session(
        session_id: &str,
        title: Option<&str>,
        work_dirs: Option<PathList>,
        updated_at: Option<chrono::DateTime<Utc>>,
        created_at: Option<chrono::DateTime<Utc>>,
    ) -> AgentSessionInfo {
        AgentSessionInfo {
            session_id: acp::SessionId::new(session_id),
            title: title.map(|title| SharedString::from(title.to_string())),
            work_dirs,
            updated_at,
            created_at,
            meta: None,
        }
    }

    /// The single-folder path list used by every test session that has work dirs.
    fn project_paths() -> PathList {
        PathList::new(&[Path::new("/project")])
    }

    /// A titled session rooted in [`project_paths`] with no timestamps.
    fn titled_session(session_id: &str, title: &str) -> AgentSessionInfo {
        make_session(session_id, Some(title), Some(project_paths()), None, None)
    }

    /// Finds the imported thread with `session_id`, panicking if it is missing.
    fn thread_with_id<'a>(
        threads: &'a [ThreadMetadata],
        session_id: &str,
    ) -> &'a ThreadMetadata {
        threads
            .iter()
            .find(|thread| thread.session_id.0.as_ref() == session_id)
            .unwrap_or_else(|| panic!("no imported thread with session id {session_id}"))
    }

    #[test]
    fn test_collect_skips_sessions_already_in_existing_set() {
        let existing = HashSet::from_iter([acp::SessionId::new("existing-1")]);
        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![
                titled_session("existing-1", "Already There"),
                titled_session("new-1", "Brand New"),
            ],
        )];

        let imported = collect_importable_threads(sessions_by_agent, existing);

        assert_eq!(imported.len(), 1);
        assert_eq!(imported[0].session_id.0.as_ref(), "new-1");
        assert_eq!(imported[0].title.as_ref(), "Brand New");
    }

    #[test]
    fn test_collect_skips_sessions_without_work_dirs() {
        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![
                titled_session("has-dirs", "With Dirs"),
                make_session("no-dirs", Some("No Dirs"), None, None, None),
            ],
        )];

        let imported = collect_importable_threads(sessions_by_agent, HashSet::default());

        assert_eq!(imported.len(), 1);
        assert_eq!(imported[0].session_id.0.as_ref(), "has-dirs");
    }

    #[test]
    fn test_collect_marks_all_imported_threads_as_archived() {
        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![
                titled_session("s1", "Thread 1"),
                titled_session("s2", "Thread 2"),
            ],
        )];

        let imported = collect_importable_threads(sessions_by_agent, HashSet::default());

        assert_eq!(imported.len(), 2);
        assert!(imported.iter().all(|thread| thread.archived));
    }

    #[test]
    fn test_collect_assigns_correct_agent_id_per_session() {
        let sessions_by_agent = vec![
            (AgentId::new("agent-a"), vec![titled_session("s1", "From A")]),
            (AgentId::new("agent-b"), vec![titled_session("s2", "From B")]),
        ];

        let imported = collect_importable_threads(sessions_by_agent, HashSet::default());

        assert_eq!(imported.len(), 2);
        assert_eq!(thread_with_id(&imported, "s1").agent_id.as_ref(), "agent-a");
        assert_eq!(thread_with_id(&imported, "s2").agent_id.as_ref(), "agent-b");
    }

    #[test]
    fn test_collect_deduplicates_across_agents() {
        let sessions_by_agent = vec![
            (
                AgentId::new("agent-a"),
                vec![titled_session("shared-session", "From A")],
            ),
            (
                AgentId::new("agent-b"),
                vec![titled_session("shared-session", "From B")],
            ),
        ];

        let imported = collect_importable_threads(sessions_by_agent, HashSet::default());

        assert_eq!(imported.len(), 1);
        assert_eq!(imported[0].session_id.0.as_ref(), "shared-session");
        assert_eq!(
            imported[0].agent_id.as_ref(),
            "agent-a",
            "first agent encountered should win"
        );
    }

    #[test]
    fn test_collect_all_existing_returns_empty() {
        let existing =
            HashSet::from_iter([acp::SessionId::new("s1"), acp::SessionId::new("s2")]);
        let sessions_by_agent = vec![(
            AgentId::new("agent-a"),
            vec![titled_session("s1", "T1"), titled_session("s2", "T2")],
        )];

        assert!(collect_importable_threads(sessions_by_agent, existing).is_empty());
    }
}