symbol_context_picker.rs

  1use std::cmp::Reverse;
  2use std::sync::Arc;
  3use std::sync::atomic::AtomicBool;
  4
  5use anyhow::Result;
  6use fuzzy::{StringMatch, StringMatchCandidate};
  7use gpui::{
  8    App, AppContext, DismissEvent, Entity, FocusHandle, Focusable, Stateful, Task, WeakEntity,
  9};
 10use ordered_float::OrderedFloat;
 11use picker::{Picker, PickerDelegate};
 12use project::{DocumentSymbol, Symbol};
 13use ui::{ListItem, prelude::*};
 14use util::ResultExt as _;
 15use workspace::Workspace;
 16
 17use crate::context_picker::ContextPicker;
 18use agent::context::AgentContextHandle;
 19use agent::context_store::ContextStore;
 20
 21pub struct SymbolContextPicker {
 22    picker: Entity<Picker<SymbolContextPickerDelegate>>,
 23}
 24
 25impl SymbolContextPicker {
 26    pub fn new(
 27        context_picker: WeakEntity<ContextPicker>,
 28        workspace: WeakEntity<Workspace>,
 29        context_store: WeakEntity<ContextStore>,
 30        window: &mut Window,
 31        cx: &mut Context<Self>,
 32    ) -> Self {
 33        let delegate = SymbolContextPickerDelegate::new(context_picker, workspace, context_store);
 34        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
 35
 36        Self { picker }
 37    }
 38}
 39
 40impl Focusable for SymbolContextPicker {
 41    fn focus_handle(&self, cx: &App) -> FocusHandle {
 42        self.picker.focus_handle(cx)
 43    }
 44}
 45
 46impl Render for SymbolContextPicker {
 47    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 48        self.picker.clone()
 49    }
 50}
 51
 52pub struct SymbolContextPickerDelegate {
 53    context_picker: WeakEntity<ContextPicker>,
 54    workspace: WeakEntity<Workspace>,
 55    context_store: WeakEntity<ContextStore>,
 56    matches: Vec<SymbolEntry>,
 57    selected_index: usize,
 58}
 59
 60impl SymbolContextPickerDelegate {
 61    pub fn new(
 62        context_picker: WeakEntity<ContextPicker>,
 63        workspace: WeakEntity<Workspace>,
 64        context_store: WeakEntity<ContextStore>,
 65    ) -> Self {
 66        Self {
 67            context_picker,
 68            workspace,
 69            context_store,
 70            matches: Vec::new(),
 71            selected_index: 0,
 72        }
 73    }
 74}
 75
 76impl PickerDelegate for SymbolContextPickerDelegate {
 77    type ListItem = ListItem;
 78
 79    fn match_count(&self) -> usize {
 80        self.matches.len()
 81    }
 82
 83    fn selected_index(&self) -> usize {
 84        self.selected_index
 85    }
 86
 87    fn set_selected_index(
 88        &mut self,
 89        ix: usize,
 90        _window: &mut Window,
 91        _cx: &mut Context<Picker<Self>>,
 92    ) {
 93        self.selected_index = ix;
 94    }
 95
 96    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
 97        "Search symbols…".into()
 98    }
 99
100    fn update_matches(
101        &mut self,
102        query: String,
103        window: &mut Window,
104        cx: &mut Context<Picker<Self>>,
105    ) -> Task<()> {
106        let Some(workspace) = self.workspace.upgrade() else {
107            return Task::ready(());
108        };
109
110        let search_task = search_symbols(query, Arc::<AtomicBool>::default(), &workspace, cx);
111        let context_store = self.context_store.clone();
112        cx.spawn_in(window, async move |this, cx| {
113            let symbols = search_task.await;
114
115            let symbol_entries = context_store
116                .read_with(cx, |context_store, cx| {
117                    compute_symbol_entries(symbols, context_store, cx)
118                })
119                .log_err()
120                .unwrap_or_default();
121
122            this.update(cx, |this, _cx| {
123                this.delegate.matches = symbol_entries;
124            })
125            .log_err();
126        })
127    }
128
129    fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
130        let Some(mat) = self.matches.get(self.selected_index) else {
131            return;
132        };
133        let Some(workspace) = self.workspace.upgrade() else {
134            return;
135        };
136
137        let add_symbol_task = add_symbol(
138            mat.symbol.clone(),
139            true,
140            workspace,
141            self.context_store.clone(),
142            cx,
143        );
144
145        let selected_index = self.selected_index;
146        cx.spawn(async move |this, cx| {
147            let (_, included) = add_symbol_task.await?;
148            this.update(cx, |this, _| {
149                if let Some(mat) = this.delegate.matches.get_mut(selected_index) {
150                    mat.is_included = included;
151                }
152            })
153        })
154        .detach_and_log_err(cx);
155    }
156
157    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
158        self.context_picker
159            .update(cx, |_, cx| {
160                cx.emit(DismissEvent);
161            })
162            .ok();
163    }
164
165    fn render_match(
166        &self,
167        ix: usize,
168        selected: bool,
169        _window: &mut Window,
170        _: &mut Context<Picker<Self>>,
171    ) -> Option<Self::ListItem> {
172        let mat = &self.matches.get(ix)?;
173
174        Some(ListItem::new(ix).inset(true).toggle_state(selected).child(
175            render_symbol_context_entry(ElementId::named_usize("symbol-ctx-picker", ix), mat),
176        ))
177    }
178}
179
180pub(crate) struct SymbolEntry {
181    pub symbol: Symbol,
182    pub is_included: bool,
183}
184
185pub(crate) fn add_symbol(
186    symbol: Symbol,
187    remove_if_exists: bool,
188    workspace: Entity<Workspace>,
189    context_store: WeakEntity<ContextStore>,
190    cx: &mut App,
191) -> Task<Result<(Option<AgentContextHandle>, bool)>> {
192    let project = workspace.read(cx).project().clone();
193    let open_buffer_task = project.update(cx, |project, cx| {
194        project.open_buffer(symbol.path.clone(), cx)
195    });
196    cx.spawn(async move |cx| {
197        let buffer = open_buffer_task.await?;
198        let document_symbols = project
199            .update(cx, |project, cx| project.document_symbols(&buffer, cx))?
200            .await?;
201
202        // Try to find a matching document symbol. Document symbols include
203        // not only the symbol itself (e.g. function name), but they also
204        // include the context that they contain (e.g. function body).
205        let (name, range, enclosing_range) = if let Some(DocumentSymbol {
206            name,
207            range,
208            selection_range,
209            ..
210        }) =
211            find_matching_symbol(&symbol, document_symbols.as_slice())
212        {
213            (name, selection_range, range)
214        } else {
215            // If we do not find a matching document symbol, fall back to
216            // just the symbol itself
217            (symbol.name, symbol.range.clone(), symbol.range)
218        };
219
220        let (range, enclosing_range) = buffer.read_with(cx, |buffer, _| {
221            (
222                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
223                buffer.anchor_after(enclosing_range.start)
224                    ..buffer.anchor_before(enclosing_range.end),
225            )
226        })?;
227
228        context_store.update(cx, move |context_store, cx| {
229            context_store.add_symbol(
230                buffer,
231                name.into(),
232                range,
233                enclosing_range,
234                remove_if_exists,
235                cx,
236            )
237        })
238    })
239}
240
241fn find_matching_symbol(symbol: &Symbol, candidates: &[DocumentSymbol]) -> Option<DocumentSymbol> {
242    let mut candidates = candidates.iter();
243    let mut candidate = candidates.next()?;
244
245    loop {
246        if candidate.range.start > symbol.range.end {
247            return None;
248        }
249        if candidate.range.end < symbol.range.start {
250            candidate = candidates.next()?;
251            continue;
252        }
253        if candidate.selection_range == symbol.range {
254            return Some(candidate.clone());
255        }
256        if candidate.range.start <= symbol.range.start && symbol.range.end <= candidate.range.end {
257            candidates = candidate.children.iter();
258            candidate = candidates.next()?;
259            continue;
260        }
261        return None;
262    }
263}
264
265pub struct SymbolMatch {
266    pub symbol: Symbol,
267}
268
269pub(crate) fn search_symbols(
270    query: String,
271    cancellation_flag: Arc<AtomicBool>,
272    workspace: &Entity<Workspace>,
273    cx: &mut App,
274) -> Task<Vec<SymbolMatch>> {
275    let symbols_task = workspace.update(cx, |workspace, cx| {
276        workspace
277            .project()
278            .update(cx, |project, cx| project.symbols(&query, cx))
279    });
280    let project = workspace.read(cx).project().clone();
281    cx.spawn(async move |cx| {
282        let Some(symbols) = symbols_task.await.log_err() else {
283            return Vec::new();
284        };
285        let Some((visible_match_candidates, external_match_candidates)): Option<(Vec<_>, Vec<_>)> =
286            project
287                .update(cx, |project, cx| {
288                    symbols
289                        .iter()
290                        .enumerate()
291                        .map(|(id, symbol)| {
292                            StringMatchCandidate::new(id, symbol.label.filter_text())
293                        })
294                        .partition(|candidate| {
295                            project
296                                .entry_for_path(&symbols[candidate.id].path, cx)
297                                .is_some_and(|e| !e.is_ignored)
298                        })
299                })
300                .log_err()
301        else {
302            return Vec::new();
303        };
304
305        const MAX_MATCHES: usize = 100;
306        let mut visible_matches = cx.background_executor().block(fuzzy::match_strings(
307            &visible_match_candidates,
308            &query,
309            false,
310            true,
311            MAX_MATCHES,
312            &cancellation_flag,
313            cx.background_executor().clone(),
314        ));
315        let mut external_matches = cx.background_executor().block(fuzzy::match_strings(
316            &external_match_candidates,
317            &query,
318            false,
319            true,
320            MAX_MATCHES - visible_matches.len().min(MAX_MATCHES),
321            &cancellation_flag,
322            cx.background_executor().clone(),
323        ));
324        let sort_key_for_match = |mat: &StringMatch| {
325            let symbol = &symbols[mat.candidate_id];
326            (Reverse(OrderedFloat(mat.score)), symbol.label.filter_text())
327        };
328
329        visible_matches.sort_unstable_by_key(sort_key_for_match);
330        external_matches.sort_unstable_by_key(sort_key_for_match);
331        let mut matches = visible_matches;
332        matches.append(&mut external_matches);
333
334        matches
335            .into_iter()
336            .map(|mut mat| {
337                let symbol = symbols[mat.candidate_id].clone();
338                let filter_start = symbol.label.filter_range.start;
339                for position in &mut mat.positions {
340                    *position += filter_start;
341                }
342                SymbolMatch { symbol }
343            })
344            .collect()
345    })
346}
347
348fn compute_symbol_entries(
349    symbols: Vec<SymbolMatch>,
350    context_store: &ContextStore,
351    cx: &App,
352) -> Vec<SymbolEntry> {
353    symbols
354        .into_iter()
355        .map(|SymbolMatch { symbol, .. }| SymbolEntry {
356            is_included: context_store.includes_symbol(&symbol, cx),
357            symbol,
358        })
359        .collect::<Vec<_>>()
360}
361
362pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful<Div> {
363    let path = entry
364        .symbol
365        .path
366        .path
367        .file_name()
368        .map(|s| s.to_string_lossy())
369        .unwrap_or_default();
370    let symbol_location = format!("{} L{}", path, entry.symbol.range.start.0.row + 1);
371
372    h_flex()
373        .id(id)
374        .gap_1p5()
375        .w_full()
376        .child(
377            Icon::new(IconName::Code)
378                .size(IconSize::Small)
379                .color(Color::Muted),
380        )
381        .child(
382            h_flex()
383                .gap_1()
384                .child(Label::new(&entry.symbol.name))
385                .child(
386                    Label::new(symbol_location)
387                        .size(LabelSize::Small)
388                        .color(Color::Muted),
389                ),
390        )
391        .when(entry.is_included, |el| {
392            el.child(
393                h_flex()
394                    .w_full()
395                    .justify_end()
396                    .gap_0p5()
397                    .child(
398                        Icon::new(IconName::Check)
399                            .size(IconSize::Small)
400                            .color(Color::Success),
401                    )
402                    .child(Label::new("Added").size(LabelSize::Small)),
403            )
404        })
405}