tree_sitter_index.rs

  1use anyhow::Result;
  2use collections::HashMap;
  3use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity};
  4use language::{Buffer, BufferEvent, BufferSnapshot, OutlineItem};
  5use project::buffer_store::{BufferStore, BufferStoreEvent};
  6use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
  7use project::{PathChange, Project, ProjectEntryId, ProjectPath};
  8use std::ops::Range;
  9use std::sync::Arc;
 10use text::{Anchor, OffsetRangeExt as _};
 11use util::ResultExt as _;
 12
// TODO:
//
// * Add an efficient way to look up outline parents (see the `parents` field / `outline_id` in
//   `zeta_context/src/outline.rs`, as well as the logic for figuring them out). These could be
//   indexes into `declarations` instead of the `OutlineId` mechanism.
//
// * Skip indexing for remote projects.
 20
// Potential future optimizations:
//
// * Cache the buffers used for files.
//
// * Parse files directly instead of loading them into a `Rope`.
 26
 27pub struct TreeSitterIndex {
 28    files: HashMap<ProjectEntryId, FileState>,
 29    buffers: HashMap<WeakEntity<Buffer>, BufferState>,
 30    project: WeakEntity<Project>,
 31}
 32
 33#[derive(Debug, Default)]
 34struct FileState {
 35    declarations: Vec<FileDeclaration>,
 36    task: Option<Task<()>>,
 37}
 38
 39#[derive(Default)]
 40struct BufferState {
 41    declarations: Vec<BufferDeclaration>,
 42    task: Option<Task<()>>,
 43}
 44
 45#[derive(Debug, Clone)]
 46pub enum Declaration {
 47    File {
 48        file: ProjectEntryId,
 49        declaration: FileDeclaration,
 50    },
 51    Buffer {
 52        buffer: WeakEntity<Buffer>,
 53        declaration: BufferDeclaration,
 54    },
 55}
 56
 57#[derive(Debug, Clone)]
 58pub struct FileDeclaration {
 59    identifier: Identifier,
 60    item_range: Range<usize>,
 61    annotation_range: Option<Range<usize>>,
 62    signature_range: Range<usize>,
 63    signature_text: String,
 64}
 65
 66#[derive(Debug, Clone)]
 67pub struct BufferDeclaration {
 68    identifier: Identifier,
 69    item_range: Range<Anchor>,
 70    annotation_range: Option<Range<Anchor>>,
 71    signature_range: Range<Anchor>,
 72    signature_text: String,
 73}
 74
 75pub struct DeclarationText {
 76    text: String,
 77    // Offset range within the `text` field containing the lines of the signature.
 78    signature_range: Range<usize>,
 79}
 80
 81#[derive(Debug, Clone, Eq, PartialEq, Hash)]
 82pub struct Identifier(Arc<str>);
 83
 84impl<T: Into<Arc<str>>> From<T> for Identifier {
 85    fn from(value: T) -> Self {
 86        Identifier(value.into())
 87    }
 88}
 89
 90impl TreeSitterIndex {
 91    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
 92        let mut this = Self {
 93            project: project.downgrade(),
 94            files: HashMap::default(),
 95            buffers: HashMap::default(),
 96        };
 97
 98        let worktree_store = project.read(cx).worktree_store();
 99        cx.subscribe(&worktree_store, Self::handle_worktree_store_event)
100            .detach();
101
102        for worktree in worktree_store
103            .read(cx)
104            .worktrees()
105            .map(|w| w.read(cx).snapshot())
106            .collect::<Vec<_>>()
107        {
108            // todo! bg?
109            for entry in worktree.files(false, 0) {
110                this.update_file(
111                    entry.id,
112                    ProjectPath {
113                        worktree_id: worktree.id(),
114                        path: entry.path.clone(),
115                    },
116                    cx,
117                );
118            }
119        }
120
121        let buffer_store = project.read(cx).buffer_store().clone();
122        for buffer in buffer_store.read(cx).buffers().collect::<Vec<_>>() {
123            this.register_buffer(&buffer, cx);
124        }
125        cx.subscribe(&buffer_store, Self::handle_buffer_store_event)
126            .detach();
127
128        this
129    }
130
131    pub fn declarations_for_identifier<const N: usize>(
132        &self,
133        identifier: impl Into<Identifier>,
134        cx: &App,
135    ) -> Vec<Declaration> {
136        assert!(N < 32);
137
138        let identifier = identifier.into();
139        let mut declarations = Vec::with_capacity(N);
140        // THEORY: set would be slower given the avg. number of buffers
141        let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
142
143        for (buffer, buffer_state) in &self.buffers {
144            let mut included = false;
145            for declaration in &buffer_state.declarations {
146                if declaration.identifier == identifier {
147                    declarations.push(Declaration::Buffer {
148                        buffer: buffer.clone(),
149                        declaration: declaration.clone(),
150                    });
151                    included = true;
152
153                    if declarations.len() == N {
154                        return declarations;
155                    }
156                }
157            }
158            if included
159                && let Ok(Some(entry)) = buffer.read_with(cx, |buffer, cx| {
160                    project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
161                })
162            {
163                included_buffer_entry_ids.push(entry);
164            }
165        }
166
167        for (file, file_state) in &self.files {
168            if included_buffer_entry_ids.contains(file) {
169                continue;
170            }
171
172            for declaration in &file_state.declarations {
173                if declaration.identifier == identifier {
174                    declarations.push(Declaration::File {
175                        file: *file,
176                        declaration: declaration.clone(),
177                    });
178
179                    if declarations.len() == N {
180                        return declarations;
181                    }
182                }
183            }
184        }
185
186        declarations
187    }
188
189    fn handle_worktree_store_event(
190        &mut self,
191        _worktree_store: Entity<WorktreeStore>,
192        event: &WorktreeStoreEvent,
193        cx: &mut Context<Self>,
194    ) {
195        use WorktreeStoreEvent::*;
196        match event {
197            WorktreeUpdatedEntries(worktree_id, updated_entries_set) => {
198                for (path, entry_id, path_change) in updated_entries_set.iter() {
199                    if let PathChange::Removed = path_change {
200                        self.files.remove(entry_id);
201                    } else {
202                        let project_path = ProjectPath {
203                            worktree_id: *worktree_id,
204                            path: path.clone(),
205                        };
206                        self.update_file(*entry_id, project_path, cx);
207                    }
208                }
209            }
210            WorktreeDeletedEntry(_worktree_id, project_entry_id) => {
211                // TODO: Is this needed?
212                self.files.remove(project_entry_id);
213            }
214            _ => {}
215        }
216    }
217
218    fn handle_buffer_store_event(
219        &mut self,
220        _buffer_store: Entity<BufferStore>,
221        event: &BufferStoreEvent,
222        cx: &mut Context<Self>,
223    ) {
224        use BufferStoreEvent::*;
225        match event {
226            BufferAdded(buffer) => self.register_buffer(buffer, cx),
227            BufferOpened { .. }
228            | BufferChangedFilePath { .. }
229            | BufferDropped { .. }
230            | SharedBufferClosed { .. } => {}
231        }
232    }
233
234    fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
235        self.buffers
236            .insert(buffer.downgrade(), BufferState::default());
237        let weak_buf = buffer.downgrade();
238        cx.observe_release(buffer, move |this, _buffer, _cx| {
239            this.buffers.remove(&weak_buf);
240        })
241        .detach();
242        cx.subscribe(buffer, Self::handle_buffer_event).detach();
243        self.update_buffer(buffer.clone(), cx);
244    }
245
246    fn handle_buffer_event(
247        &mut self,
248        buffer: Entity<Buffer>,
249        event: &BufferEvent,
250        cx: &mut Context<Self>,
251    ) {
252        match event {
253            BufferEvent::Edited => self.update_buffer(buffer, cx),
254            _ => {}
255        }
256    }
257
258    fn update_buffer(&mut self, buffer_entity: Entity<Buffer>, cx: &Context<Self>) {
259        let buffer = buffer_entity.read(cx);
260
261        let snapshot = buffer.snapshot();
262        let parse_task = cx.background_spawn(async move {
263            snapshot
264                .outline(None)
265                .items
266                .into_iter()
267                .filter_map(BufferDeclaration::try_from_outline_item)
268                .collect()
269        });
270
271        let task = cx.spawn({
272            let weak_buffer = buffer_entity.downgrade();
273            async move |this, cx| {
274                let declarations = parse_task.await;
275                this.update(cx, |this, _cx| {
276                    this.buffers
277                        .entry(weak_buffer)
278                        .or_insert_with(Default::default)
279                        .declarations = declarations;
280                })
281                .ok();
282            }
283        });
284
285        self.buffers
286            .entry(buffer_entity.downgrade())
287            .or_insert_with(Default::default)
288            .task = Some(task);
289    }
290
291    fn update_file(
292        &mut self,
293        entry_id: ProjectEntryId,
294        project_path: ProjectPath,
295        cx: &mut Context<Self>,
296    ) {
297        let Some(project) = self.project.upgrade() else {
298            return;
299        };
300        let project = project.read(cx);
301        let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) else {
302            return;
303        };
304        let language_registry = project.languages().clone();
305
306        let snapshot_task = worktree.update(cx, |worktree, cx| {
307            let load_task = worktree.load_file(&project_path.path, cx);
308            cx.spawn(async move |_this, cx| {
309                let loaded_file = load_task.await?;
310                let language = language_registry
311                    .language_for_file_path(&project_path.path)
312                    .await
313                    .log_err();
314
315                let buffer = cx.new(|cx| {
316                    let mut buffer = Buffer::local(loaded_file.text, cx);
317                    buffer.set_language(language, cx);
318                    buffer
319                })?;
320                buffer.read_with(cx, |buffer, _cx| buffer.snapshot())
321            })
322        });
323
324        let parse_task: Task<Result<Vec<FileDeclaration>>> = cx.background_spawn(async move {
325            let snapshot = snapshot_task.await?;
326            Ok(snapshot
327                .outline(None)
328                .items
329                .into_iter()
330                .filter_map(BufferDeclaration::try_from_outline_item)
331                .map(|declaration| declaration.into_file_declaration(&snapshot))
332                .collect())
333        });
334
335        let task = cx.spawn({
336            async move |this, cx| {
337                // TODO: how to handle errors?
338                let Ok(declarations) = parse_task.await else {
339                    return;
340                };
341                this.update(cx, |this, _cx| {
342                    this.files
343                        .entry(entry_id)
344                        .or_insert_with(Default::default)
345                        .declarations = declarations;
346                })
347                .ok();
348            }
349        });
350
351        self.files
352            .entry(entry_id)
353            .or_insert_with(Default::default)
354            .task = Some(task);
355    }
356}
357
358impl BufferDeclaration {
359    pub fn try_from_outline_item(item: OutlineItem<Anchor>) -> Option<Self> {
360        // todo! what to do about multiple names?
361        let name_range = item.name_ranges.get(0)?;
362        Some(BufferDeclaration {
363            identifier: Identifier(item.text[name_range.clone()].into()),
364            item_range: item.range,
365            annotation_range: item.annotation_range,
366            signature_range: item.signature_range?,
367            // todo! this should instead be the signature_range but expanded to line boundaries.
368            signature_text: item.text.clone(),
369        })
370    }
371
372    pub fn into_file_declaration(self, snapshot: &BufferSnapshot) -> FileDeclaration {
373        FileDeclaration {
374            identifier: self.identifier,
375            item_range: self.item_range.to_offset(snapshot),
376            annotation_range: self.annotation_range.map(|range| range.to_offset(snapshot)),
377            signature_range: self.signature_range.to_offset(snapshot),
378            signature_text: self.signature_text.clone(),
379        }
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use std::{path::Path, sync::Arc};

    use futures::channel::oneshot;
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project, ProjectItem};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::tree_sitter_index::TreeSitterIndex;

    /// Files that were never opened as buffers are served from the file index.
    // NOTE(review): the positional checks below depend on the iteration order
    // of the index's file map — confirm that order is stable in test builds.
    #[gpui::test]
    async fn test_unopen_indexed_files(cx: &mut TestAppContext) {
        let (project, index) = init_test(cx).await;

        index.read_with(cx, |index, cx| {
            let found = index.declarations_for_identifier::<8>("main", cx);
            assert_eq!(found.len(), 2);

            let main_in_c = expect_file_decl("c.rs", &found[0], &project, cx);
            assert_eq!(main_in_c.identifier, "main".into());
            assert_eq!(main_in_c.item_range, 32..279);

            let main_in_a = expect_file_decl("a.rs", &found[1], &project, cx);
            assert_eq!(main_in_a.identifier, "main".into());
            assert_eq!(main_in_a.item_range, 0..97);
        });
    }

    /// The const-generic limit caps the number of returned declarations.
    #[gpui::test]
    async fn test_declarations_limt(cx: &mut TestAppContext) {
        let (_, index) = init_test(cx).await;

        // TODO: also test the limit with open buffers.
        index.read_with(cx, |index, cx| {
            let found = index.declarations_for_identifier::<1>("main", cx);
            assert_eq!(found.len(), 1);
        });
    }

    /// While a file is open as a buffer, its buffer declarations replace the
    /// file declarations; once the buffer is released the file declarations
    /// are reported again.
    #[gpui::test]
    async fn test_buffer_shadow(cx: &mut TestAppContext) {
        let (project, index) = init_test(cx).await;

        let open_task = project.update(cx, |project, cx| {
            let project_path = project.find_project_path("c.rs", cx).unwrap();
            project.open_buffer(project_path, cx)
        });
        let buffer = open_task.await.unwrap();

        cx.run_until_parked();

        index.read_with(cx, |index, cx| {
            let found = index.declarations_for_identifier::<8>("main", cx);
            assert_eq!(found.len(), 2);

            let main_in_c = expect_buffer_decl("c.rs", &found[0], cx);
            assert_eq!(main_in_c.identifier, "main".into());
            assert_eq!(main_in_c.item_range.to_offset(&buffer.read(cx)), 32..279);

            expect_file_decl("a.rs", &found[1], &project, cx);
        });

        // Drop the buffer and wait for release
        let (release_tx, release_rx) = oneshot::channel();
        cx.update(|cx| {
            cx.observe_release(&buffer, |_, _| {
                release_tx.send(()).ok();
            })
            .detach();
        });
        drop(buffer);
        cx.run_until_parked();
        release_rx.await.ok();
        cx.run_until_parked();

        index.read_with(cx, |index, cx| {
            let found = index.declarations_for_identifier::<8>("main", cx);
            assert_eq!(found.len(), 2);
            expect_file_decl("c.rs", &found[0], &project, cx);
            expect_file_decl("a.rs", &found[1], &project, cx);
        });
    }

    /// Asserts that `declaration` is a buffer declaration whose buffer lives at
    /// `path`, and returns the inner declaration.
    fn expect_buffer_decl<'a>(
        path: &str,
        declaration: &'a Declaration,
        cx: &App,
    ) -> &'a BufferDeclaration {
        let Declaration::Buffer {
            declaration: buffer_declaration,
            buffer,
        } = declaration
        else {
            panic!("Expected a buffer declaration, found {:?}", declaration);
        };
        assert_eq!(
            buffer
                .upgrade()
                .unwrap()
                .read(cx)
                .project_path(cx)
                .unwrap()
                .path
                .as_ref(),
            Path::new(path),
        );
        buffer_declaration
    }

    /// Asserts that `declaration` is a file declaration whose entry resolves to
    /// `path`, and returns the inner declaration.
    fn expect_file_decl<'a>(
        path: &str,
        declaration: &'a Declaration,
        project: &Entity<Project>,
        cx: &App,
    ) -> &'a FileDeclaration {
        let Declaration::File {
            declaration: file_declaration,
            file,
        } = declaration
        else {
            panic!("Expected a file declaration, found {:?}", declaration);
        };
        assert_eq!(
            project
                .read(cx)
                .path_for_entry(*file, cx)
                .unwrap()
                .path
                .as_ref(),
            Path::new(path),
        );
        file_declaration
    }

    /// Builds a fake project containing three Rust files, registers the Rust
    /// language, creates the index, and lets pending work settle.
    async fn init_test(cx: &mut TestAppContext) -> (Entity<Project>, Entity<TreeSitterIndex>) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let languages = project.read_with(cx, |project, _| project.languages().clone());
        languages.add(Arc::new(rust_lang()));

        let index = cx.new(|cx| TreeSitterIndex::new(&project, cx));
        cx.run_until_parked();

        (project, index)
    }

    /// Rust language definition with the outline query the index relies on.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
            .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
            .unwrap()
    }
}
621
622/*
623#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize)]
624#[serde(transparent)]
625pub struct Identifier(pub Arc<str>);
626
627#[derive(Debug)]
628pub struct IdentifierIndex {
629    pub identifier_to_definitions:
630        HashMap<(Identifier, LanguageName), MultiMap<Arc<Path>, OutlineItem>>,
631    pub path_to_source: HashMap<Arc<Path>, String>,
632    pub path_to_items: HashMap<Arc<Path>, Vec<OutlineItem>>,
633    pub outline_id_to_item: HashMap<OutlineId, OutlineItem>,
634}
635
636impl IdentifierIndex {
637    pub fn index_path(languages: &[Arc<Language>], path: &Path) -> Result<IdentifierIndex> {
638        let mut identifier_to_definitions = HashMap::new();
639        let mut path_to_source = HashMap::new();
640        let mut path_to_items = HashMap::new();
641        let mut outline_id_to_item = HashMap::new();
642
643        for entry in Walk::new(path)
644            .into_iter()
645            .filter_map(|e| e.ok())
646            .filter(|e| e.metadata().unwrap().is_file())
647        {
648            let file_path = entry.path();
649            let Some(language) = language_for_file(languages, file_path) else {
650                continue;
651            };
652            if !language.supports_references {
653                continue;
654            }
655            let source = fs::read_to_string(file_path)
656                .map_err(|e| anyhow!("Failed to read file {:?}: {}", file_path, e))?;
657            let tree = parse_source(&language, &source);
658
659            let mut outline_items = query_outline_items(&language, &tree, &source);
660            outline_items.sort_by_key(|item| item.item_range.start);
661            for outline_item in outline_items.iter() {
662                let identifier = Identifier(outline_item.name(&source).into());
663                let definitions: &mut MultiMap<Arc<Path>, OutlineItem> = identifier_to_definitions
664                    .entry((identifier, language.name.clone()))
665                    .or_default();
666                definitions.insert(file_path.into(), outline_item.clone());
667                outline_id_to_item.insert(outline_item.id, outline_item.clone());
668            }
669            path_to_source.insert(file_path.into(), source);
670            path_to_items.insert(file_path.into(), outline_items);
671        }
672
673        Ok(IdentifierIndex {
674            identifier_to_definitions,
675            path_to_source,
676            path_to_items,
677            outline_id_to_item,
678        })
679    }
680}
681*/