symbol_context_picker.rs

  1use std::cmp::Reverse;
  2use std::sync::Arc;
  3use std::sync::atomic::AtomicBool;
  4
  5use anyhow::Result;
  6use fuzzy::{StringMatch, StringMatchCandidate};
  7use gpui::{
  8    App, AppContext, DismissEvent, Entity, FocusHandle, Focusable, Stateful, Task, WeakEntity,
  9};
 10use ordered_float::OrderedFloat;
 11use picker::{Picker, PickerDelegate};
 12use project::{DocumentSymbol, Symbol};
 13use text::OffsetRangeExt;
 14use ui::{ListItem, prelude::*};
 15use util::ResultExt as _;
 16use workspace::Workspace;
 17
 18use crate::context_picker::ContextPicker;
 19use crate::context_store::ContextStore;
 20
 21pub struct SymbolContextPicker {
 22    picker: Entity<Picker<SymbolContextPickerDelegate>>,
 23}
 24
 25impl SymbolContextPicker {
 26    pub fn new(
 27        context_picker: WeakEntity<ContextPicker>,
 28        workspace: WeakEntity<Workspace>,
 29        context_store: WeakEntity<ContextStore>,
 30        window: &mut Window,
 31        cx: &mut Context<Self>,
 32    ) -> Self {
 33        let delegate = SymbolContextPickerDelegate::new(context_picker, workspace, context_store);
 34        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
 35
 36        Self { picker }
 37    }
 38}
 39
 40impl Focusable for SymbolContextPicker {
 41    fn focus_handle(&self, cx: &App) -> FocusHandle {
 42        self.picker.focus_handle(cx)
 43    }
 44}
 45
 46impl Render for SymbolContextPicker {
 47    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 48        self.picker.clone()
 49    }
 50}
 51
 52pub struct SymbolContextPickerDelegate {
 53    context_picker: WeakEntity<ContextPicker>,
 54    workspace: WeakEntity<Workspace>,
 55    context_store: WeakEntity<ContextStore>,
 56    matches: Vec<SymbolEntry>,
 57    selected_index: usize,
 58}
 59
 60impl SymbolContextPickerDelegate {
 61    pub fn new(
 62        context_picker: WeakEntity<ContextPicker>,
 63        workspace: WeakEntity<Workspace>,
 64        context_store: WeakEntity<ContextStore>,
 65    ) -> Self {
 66        Self {
 67            context_picker,
 68            workspace,
 69            context_store,
 70            matches: Vec::new(),
 71            selected_index: 0,
 72        }
 73    }
 74}
 75
 76impl PickerDelegate for SymbolContextPickerDelegate {
 77    type ListItem = ListItem;
 78
 79    fn match_count(&self) -> usize {
 80        self.matches.len()
 81    }
 82
 83    fn selected_index(&self) -> usize {
 84        self.selected_index
 85    }
 86
 87    fn set_selected_index(
 88        &mut self,
 89        ix: usize,
 90        _window: &mut Window,
 91        _cx: &mut Context<Picker<Self>>,
 92    ) {
 93        self.selected_index = ix;
 94    }
 95
 96    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
 97        "Search symbols…".into()
 98    }
 99
100    fn update_matches(
101        &mut self,
102        query: String,
103        window: &mut Window,
104        cx: &mut Context<Picker<Self>>,
105    ) -> Task<()> {
106        let Some(workspace) = self.workspace.upgrade() else {
107            return Task::ready(());
108        };
109
110        let search_task = search_symbols(query, Arc::<AtomicBool>::default(), &workspace, cx);
111        let context_store = self.context_store.clone();
112        cx.spawn_in(window, async move |this, cx| {
113            let symbols = search_task.await;
114
115            let symbol_entries = context_store
116                .read_with(cx, |context_store, cx| {
117                    compute_symbol_entries(symbols, context_store, cx)
118                })
119                .log_err()
120                .unwrap_or_default();
121
122            this.update(cx, |this, _cx| {
123                this.delegate.matches = symbol_entries;
124            })
125            .log_err();
126        })
127    }
128
129    fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
130        let Some(mat) = self.matches.get(self.selected_index) else {
131            return;
132        };
133        let Some(workspace) = self.workspace.upgrade() else {
134            return;
135        };
136
137        let add_symbol_task = add_symbol(
138            mat.symbol.clone(),
139            true,
140            workspace,
141            self.context_store.clone(),
142            cx,
143        );
144
145        let selected_index = self.selected_index;
146        cx.spawn(async move |this, cx| {
147            let included = add_symbol_task.await?;
148            this.update(cx, |this, _| {
149                if let Some(mat) = this.delegate.matches.get_mut(selected_index) {
150                    mat.is_included = included;
151                }
152            })
153        })
154        .detach_and_log_err(cx);
155    }
156
157    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
158        self.context_picker
159            .update(cx, |_, cx| {
160                cx.emit(DismissEvent);
161            })
162            .ok();
163    }
164
165    fn render_match(
166        &self,
167        ix: usize,
168        selected: bool,
169        _window: &mut Window,
170        _: &mut Context<Picker<Self>>,
171    ) -> Option<Self::ListItem> {
172        let mat = &self.matches[ix];
173
174        Some(ListItem::new(ix).inset(true).toggle_state(selected).child(
175            render_symbol_context_entry(
176                ElementId::NamedInteger("symbol-ctx-picker".into(), ix),
177                mat,
178            ),
179        ))
180    }
181}
182
183pub(crate) struct SymbolEntry {
184    pub symbol: Symbol,
185    pub is_included: bool,
186}
187
188pub(crate) fn add_symbol(
189    symbol: Symbol,
190    remove_if_exists: bool,
191    workspace: Entity<Workspace>,
192    context_store: WeakEntity<ContextStore>,
193    cx: &mut App,
194) -> Task<Result<bool>> {
195    let project = workspace.read(cx).project().clone();
196    let open_buffer_task = project.update(cx, |project, cx| {
197        project.open_buffer(symbol.path.clone(), cx)
198    });
199    cx.spawn(async move |cx| {
200        let buffer = open_buffer_task.await?;
201        let document_symbols = project
202            .update(cx, |project, cx| project.document_symbols(&buffer, cx))?
203            .await?;
204
205        // Try to find a matching document symbol. Document symbols include
206        // not only the symbol itself (e.g. function name), but they also
207        // include the context that they contain (e.g. function body).
208        let (name, range, enclosing_range) = if let Some(DocumentSymbol {
209            name,
210            range,
211            selection_range,
212            ..
213        }) =
214            find_matching_symbol(&symbol, document_symbols.as_slice())
215        {
216            (name, selection_range, range)
217        } else {
218            // If we do not find a matching document symbol, fall back to
219            // just the symbol itself
220            (symbol.name, symbol.range.clone(), symbol.range)
221        };
222
223        let (range, enclosing_range) = buffer.read_with(cx, |buffer, _| {
224            (
225                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
226                buffer.anchor_after(enclosing_range.start)
227                    ..buffer.anchor_before(enclosing_range.end),
228            )
229        })?;
230
231        context_store
232            .update(cx, move |context_store, cx| {
233                context_store.add_symbol(
234                    buffer,
235                    name.into(),
236                    range,
237                    enclosing_range,
238                    remove_if_exists,
239                    cx,
240                )
241            })?
242            .await
243    })
244}
245
246fn find_matching_symbol(symbol: &Symbol, candidates: &[DocumentSymbol]) -> Option<DocumentSymbol> {
247    let mut candidates = candidates.iter();
248    let mut candidate = candidates.next()?;
249
250    loop {
251        if candidate.range.start > symbol.range.end {
252            return None;
253        }
254        if candidate.range.end < symbol.range.start {
255            candidate = candidates.next()?;
256            continue;
257        }
258        if candidate.selection_range == symbol.range {
259            return Some(candidate.clone());
260        }
261        if candidate.range.start <= symbol.range.start && symbol.range.end <= candidate.range.end {
262            candidates = candidate.children.iter();
263            candidate = candidates.next()?;
264            continue;
265        }
266        return None;
267    }
268}
269
270pub struct SymbolMatch {
271    pub symbol: Symbol,
272}
273
274pub(crate) fn search_symbols(
275    query: String,
276    cancellation_flag: Arc<AtomicBool>,
277    workspace: &Entity<Workspace>,
278    cx: &mut App,
279) -> Task<Vec<SymbolMatch>> {
280    let symbols_task = workspace.update(cx, |workspace, cx| {
281        workspace
282            .project()
283            .update(cx, |project, cx| project.symbols(&query, cx))
284    });
285    let project = workspace.read(cx).project().clone();
286    cx.spawn(async move |cx| {
287        let Some(symbols) = symbols_task.await.log_err() else {
288            return Vec::new();
289        };
290        let Some((visible_match_candidates, external_match_candidates)): Option<(Vec<_>, Vec<_>)> =
291            project
292                .update(cx, |project, cx| {
293                    symbols
294                        .iter()
295                        .enumerate()
296                        .map(|(id, symbol)| {
297                            StringMatchCandidate::new(id, &symbol.label.filter_text())
298                        })
299                        .partition(|candidate| {
300                            project
301                                .entry_for_path(&symbols[candidate.id].path, cx)
302                                .map_or(false, |e| !e.is_ignored)
303                        })
304                })
305                .log_err()
306        else {
307            return Vec::new();
308        };
309
310        const MAX_MATCHES: usize = 100;
311        let mut visible_matches = cx.background_executor().block(fuzzy::match_strings(
312            &visible_match_candidates,
313            &query,
314            false,
315            MAX_MATCHES,
316            &cancellation_flag,
317            cx.background_executor().clone(),
318        ));
319        let mut external_matches = cx.background_executor().block(fuzzy::match_strings(
320            &external_match_candidates,
321            &query,
322            false,
323            MAX_MATCHES - visible_matches.len().min(MAX_MATCHES),
324            &cancellation_flag,
325            cx.background_executor().clone(),
326        ));
327        let sort_key_for_match = |mat: &StringMatch| {
328            let symbol = &symbols[mat.candidate_id];
329            (Reverse(OrderedFloat(mat.score)), symbol.label.filter_text())
330        };
331
332        visible_matches.sort_unstable_by_key(sort_key_for_match);
333        external_matches.sort_unstable_by_key(sort_key_for_match);
334        let mut matches = visible_matches;
335        matches.append(&mut external_matches);
336
337        matches
338            .into_iter()
339            .map(|mut mat| {
340                let symbol = symbols[mat.candidate_id].clone();
341                let filter_start = symbol.label.filter_range.start;
342                for position in &mut mat.positions {
343                    *position += filter_start;
344                }
345                SymbolMatch { symbol }
346            })
347            .collect()
348    })
349}
350
351fn compute_symbol_entries(
352    symbols: Vec<SymbolMatch>,
353    context_store: &ContextStore,
354    cx: &App,
355) -> Vec<SymbolEntry> {
356    let mut symbol_entries = Vec::with_capacity(symbols.len());
357    for SymbolMatch { symbol, .. } in symbols {
358        let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path);
359        let is_included = if let Some(symbols_for_path) = symbols_for_path {
360            let mut is_included = false;
361            for included_symbol_id in symbols_for_path {
362                if included_symbol_id.name.as_ref() == symbol.name.as_str() {
363                    if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) {
364                        let snapshot = buffer.read(cx).snapshot();
365                        let included_symbol_range =
366                            included_symbol_id.range.to_point_utf16(&snapshot);
367
368                        if included_symbol_range.start == symbol.range.start.0
369                            && included_symbol_range.end == symbol.range.end.0
370                        {
371                            is_included = true;
372                            break;
373                        }
374                    }
375                }
376            }
377            is_included
378        } else {
379            false
380        };
381
382        symbol_entries.push(SymbolEntry {
383            symbol,
384            is_included,
385        })
386    }
387    symbol_entries
388}
389
390pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful<Div> {
391    let path = entry
392        .symbol
393        .path
394        .path
395        .file_name()
396        .map(|s| s.to_string_lossy())
397        .unwrap_or_default();
398    let symbol_location = format!("{} L{}", path, entry.symbol.range.start.0.row + 1);
399
400    h_flex()
401        .id(id)
402        .gap_1p5()
403        .w_full()
404        .child(
405            Icon::new(IconName::Code)
406                .size(IconSize::Small)
407                .color(Color::Muted),
408        )
409        .child(
410            h_flex()
411                .gap_1()
412                .child(Label::new(&entry.symbol.name))
413                .child(
414                    Label::new(symbol_location)
415                        .size(LabelSize::Small)
416                        .color(Color::Muted),
417                ),
418        )
419        .when(entry.is_included, |el| {
420            el.child(
421                h_flex()
422                    .w_full()
423                    .justify_end()
424                    .gap_0p5()
425                    .child(
426                        Icon::new(IconName::Check)
427                            .size(IconSize::Small)
428                            .color(Color::Success),
429                    )
430                    .child(Label::new("Added").size(LabelSize::Small)),
431            )
432        })
433}