zeta2: Include type definitions in related files (#49748)

Ben Kunkle created

Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/edit_prediction_context/src/edit_prediction_context.rs       |  88 
crates/edit_prediction_context/src/edit_prediction_context_tests.rs | 280 
crates/edit_prediction_context/src/fake_definition_lsp.rs           | 129 
3 files changed, 476 insertions(+), 21 deletions(-)

Detailed changes

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -65,12 +65,16 @@ struct Identifier {
 
 enum DefinitionTask {
     CacheHit(Arc<CacheEntry>),
-    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
+    CacheMiss {
+        definitions: Task<Result<Option<Vec<LocationLink>>>>,
+        type_definitions: Task<Result<Option<Vec<LocationLink>>>>,
+    },
 }
 
 #[derive(Debug)]
 struct CacheEntry {
     definitions: SmallVec<[CachedDefinition; 1]>,
+    type_definitions: SmallVec<[CachedDefinition; 1]>,
 }
 
 #[derive(Clone, Debug)]
@@ -232,13 +236,22 @@ impl RelatedExcerptStore {
                     let task = if let Some(entry) = this.cache.get(&identifier) {
                         DefinitionTask::CacheHit(entry.clone())
                     } else {
-                        DefinitionTask::CacheMiss(
-                            this.project
-                                .update(cx, |project, cx| {
-                                    project.definitions(&buffer, identifier.range.start, cx)
-                                })
-                                .ok()?,
-                        )
+                        let definitions = this
+                            .project
+                            .update(cx, |project, cx| {
+                                project.definitions(&buffer, identifier.range.start, cx)
+                            })
+                            .ok()?;
+                        let type_definitions = this
+                            .project
+                            .update(cx, |project, cx| {
+                                project.type_definitions(&buffer, identifier.range.start, cx)
+                            })
+                            .ok()?;
+                        DefinitionTask::CacheMiss {
+                            definitions,
+                            type_definitions,
+                        }
                     };
 
                     let cx = async_cx.clone();
@@ -248,19 +261,50 @@ impl RelatedExcerptStore {
                             DefinitionTask::CacheHit(cache_entry) => {
                                 Some((identifier, cache_entry, None))
                             }
-                            DefinitionTask::CacheMiss(task) => {
-                                let locations = task.await.log_err()??;
+                            DefinitionTask::CacheMiss {
+                                definitions,
+                                type_definitions,
+                            } => {
+                                let (definition_locations, type_definition_locations) =
+                                    futures::join!(definitions, type_definitions);
                                 let duration = start_time.elapsed();
+
+                                let definition_locations =
+                                    definition_locations.log_err().flatten().unwrap_or_default();
+                                let type_definition_locations = type_definition_locations
+                                    .log_err()
+                                    .flatten()
+                                    .unwrap_or_default();
+
                                 Some(cx.update(|cx| {
+                                    let definitions: SmallVec<[CachedDefinition; 1]> =
+                                        definition_locations
+                                            .into_iter()
+                                            .filter_map(|location| {
+                                                process_definition(location, &project, cx)
+                                            })
+                                            .collect();
+
+                                    let type_definitions: SmallVec<[CachedDefinition; 1]> =
+                                        type_definition_locations
+                                            .into_iter()
+                                            .filter_map(|location| {
+                                                process_definition(location, &project, cx)
+                                            })
+                                            .filter(|type_def| {
+                                                !definitions.iter().any(|def| {
+                                                    def.buffer.entity_id()
+                                                        == type_def.buffer.entity_id()
+                                                        && def.anchor_range == type_def.anchor_range
+                                                })
+                                            })
+                                            .collect();
+
                                     (
                                         identifier,
                                         Arc::new(CacheEntry {
-                                            definitions: locations
-                                                .into_iter()
-                                                .filter_map(|location| {
-                                                    process_definition(location, &project, cx)
-                                                })
-                                                .collect(),
+                                            definitions,
+                                            type_definitions,
                                         }),
                                         Some(duration),
                                     )
@@ -323,7 +367,11 @@ async fn rebuild_related_files(
     let mut snapshots = HashMap::default();
     let mut worktree_root_names = HashMap::default();
     for entry in new_entries.values() {
-        for definition in &entry.definitions {
+        for definition in entry
+            .definitions
+            .iter()
+            .chain(entry.type_definitions.iter())
+        {
             if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                 definition
                     .buffer
@@ -354,7 +402,11 @@ async fn rebuild_related_files(
                 HashMap::<EntityId, (Entity<Buffer>, Vec<Range<Point>>)>::default();
             let mut paths_by_buffer = HashMap::default();
             for entry in new_entries.values_mut() {
-                for definition in &entry.definitions {
+                for definition in entry
+                    .definitions
+                    .iter()
+                    .chain(entry.type_definitions.iter())
+                {
                     let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                         continue;
                     };

crates/edit_prediction_context/src/edit_prediction_context_tests.rs 🔗

@@ -75,6 +75,13 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
                     "root/src/person.rs",
                     &[
                         indoc! {"
+                        pub struct Person {
+                            first_name: String,
+                            last_name: String,
+                            email: String,
+                            age: u32,
+                        }
+
                         impl Person {
                             pub fn get_first_name(&self) -> &str {
                                 &self.first_name
@@ -133,6 +140,13 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
                     "root/src/person.rs",
                     &[
                         indoc! {"
+                        pub struct Person {
+                            first_name: String,
+                            last_name: String,
+                            email: String,
+                            age: u32,
+                        }
+
                         impl Person {
                             pub fn get_first_name(&self) -> &str {
                                 &self.first_name
@@ -353,6 +367,272 @@ async fn test_fake_definition_lsp(cx: &mut TestAppContext) {
     assert_definitions(&definitions, &["pub fn to_string(&self) -> String {"], cx);
 }
 
+#[gpui::test]
+async fn test_fake_type_definition_lsp(cx: &mut TestAppContext) {
+    init_test(cx);
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(path!("/root"), test_project_1()).await;
+
+    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+    let mut servers = setup_fake_lsp(&project, cx);
+
+    let (buffer, _handle) = project
+        .update(cx, |project, cx| {
+            project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx)
+        })
+        .await
+        .unwrap();
+
+    let _server = servers.next().await.unwrap();
+    cx.run_until_parked();
+
+    let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
+
+    // Type definition on a type name returns its own definition
+    // (same as regular definition)
+    let type_defs = project
+        .update(cx, |project, cx| {
+            let offset = buffer_text.find("Address {").expect("Address { not found");
+            project.type_definitions(&buffer, offset, cx)
+        })
+        .await
+        .unwrap()
+        .unwrap();
+    assert_definitions(&type_defs, &["pub struct Address {"], cx);
+
+    // Type definition on a field resolves through the type annotation.
+    // company.rs has `owner: Arc<Person>`, so type-def of `owner` → Person.
+    let (company_buffer, _handle) = project
+        .update(cx, |project, cx| {
+            project.open_local_buffer_with_lsp(path!("/root/src/company.rs"), cx)
+        })
+        .await
+        .unwrap();
+    cx.run_until_parked();
+
+    let company_text = company_buffer.read_with(cx, |buffer, _| buffer.text());
+    let type_defs = project
+        .update(cx, |project, cx| {
+            let offset = company_text.find("owner").expect("owner not found");
+            project.type_definitions(&company_buffer, offset, cx)
+        })
+        .await
+        .unwrap()
+        .unwrap();
+    assert_definitions(&type_defs, &["pub struct Person {"], cx);
+
+    // Type definition on another field: `address: Address` → Address.
+    let type_defs = project
+        .update(cx, |project, cx| {
+            let offset = company_text.find("address").expect("address not found");
+            project.type_definitions(&company_buffer, offset, cx)
+        })
+        .await
+        .unwrap()
+        .unwrap();
+    assert_definitions(&type_defs, &["pub struct Address {"], cx);
+
+    // Type definition on a lowercase name with no type annotation returns empty.
+    let type_defs = project
+        .update(cx, |project, cx| {
+            let offset = buffer_text.find("main").expect("main not found");
+            project.type_definitions(&buffer, offset, cx)
+        })
+        .await;
+    let is_empty = match &type_defs {
+        Ok(Some(defs)) => defs.is_empty(),
+        Ok(None) => true,
+        Err(_) => false,
+    };
+    assert!(is_empty, "expected no type definitions for `main`");
+}
+
+#[gpui::test]
+async fn test_type_definitions_in_related_files(cx: &mut TestAppContext) {
+    init_test(cx);
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        path!("/root"),
+        json!({
+            "src": {
+                "config.rs": indoc! {r#"
+                    pub struct Config {
+                        debug: bool,
+                        verbose: bool,
+                    }
+                "#},
+                "widget.rs": indoc! {r#"
+                    use super::config::Config;
+
+                    pub struct Widget {
+                        config: Config,
+                        name: String,
+                    }
+
+                    impl Widget {
+                        pub fn render(&self) {
+                            if self.config.debug {
+                                println!("debug mode");
+                            }
+                        }
+                    }
+                "#},
+            },
+        }),
+    )
+    .await;
+
+    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+    let mut servers = setup_fake_lsp(&project, cx);
+
+    let (buffer, _handle) = project
+        .update(cx, |project, cx| {
+            project.open_local_buffer_with_lsp(path!("/root/src/widget.rs"), cx)
+        })
+        .await
+        .unwrap();
+
+    let _server = servers.next().await.unwrap();
+    cx.run_until_parked();
+
+    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(&project, cx));
+    related_excerpt_store.update(cx, |store, cx| {
+        let position = {
+            let buffer = buffer.read(cx);
+            let offset = buffer
+                .text()
+                .find("self.config.debug")
+                .expect("self.config.debug not found");
+            buffer.anchor_before(offset)
+        };
+
+        store.set_identifier_line_count(0);
+        store.refresh(buffer.clone(), position, cx);
+    });
+
+    cx.executor().advance_clock(DEBOUNCE_DURATION);
+    // config.rs appears ONLY because the fake LSP resolves the type annotation
+    // `config: Config` to `pub struct Config` via GotoTypeDefinition.
+    // widget.rs appears from regular definitions of Widget / render.
+    related_excerpt_store.update(cx, |store, cx| {
+        let excerpts = store.related_files(cx);
+        assert_related_files(
+            &excerpts,
+            &[
+                (
+                    "root/src/config.rs",
+                    &[indoc! {"
+                        pub struct Config {
+                            debug: bool,
+                            verbose: bool,
+                        }"}],
+                ),
+                (
+                    "root/src/widget.rs",
+                    &[
+                        indoc! {"
+                        pub struct Widget {
+                            config: Config,
+                            name: String,
+                        }
+
+                        impl Widget {
+                            pub fn render(&self) {"},
+                        indoc! {"
+                            }
+                        }"},
+                    ],
+                ),
+            ],
+        );
+    });
+}
+
+#[gpui::test]
+async fn test_type_definition_deduplication(cx: &mut TestAppContext) {
+    init_test(cx);
+    let fs = FakeFs::new(cx.executor());
+
+    // In this project the only identifier near the cursor whose type definition
+    // resolves is `TypeA`, and its GotoTypeDefinition returns the exact same
+    // location as GotoDefinition. After deduplication the CacheEntry for `TypeA`
+    // should have an empty `type_definitions` vec, meaning the type-definition
+    // path contributes nothing extra to the related-file output.
+    fs.insert_tree(
+        path!("/root"),
+        json!({
+            "src": {
+                "types.rs": indoc! {r#"
+                    pub struct TypeA {
+                        value: i32,
+                    }
+
+                    pub struct TypeB {
+                        label: String,
+                    }
+                "#},
+                "main.rs": indoc! {r#"
+                    use super::types::TypeA;
+
+                    fn work() {
+                        let item: TypeA = unimplemented!();
+                        println!("{}", item.value);
+                    }
+                "#},
+            },
+        }),
+    )
+    .await;
+
+    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+    let mut servers = setup_fake_lsp(&project, cx);
+
+    let (buffer, _handle) = project
+        .update(cx, |project, cx| {
+            project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx)
+        })
+        .await
+        .unwrap();
+
+    let _server = servers.next().await.unwrap();
+    cx.run_until_parked();
+
+    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(&project, cx));
+    related_excerpt_store.update(cx, |store, cx| {
+        let position = {
+            let buffer = buffer.read(cx);
+            let offset = buffer.text().find("let item").expect("let item not found");
+            buffer.anchor_before(offset)
+        };
+
+        store.set_identifier_line_count(0);
+        store.refresh(buffer.clone(), position, cx);
+    });
+
+    cx.executor().advance_clock(DEBOUNCE_DURATION);
+    // types.rs appears because `TypeA` has a regular definition there.
+    // `item`'s type definition also resolves to TypeA in types.rs, but
+    // deduplication removes it since it points to the same location.
+    // TypeB should NOT appear because nothing references it.
+    related_excerpt_store.update(cx, |store, cx| {
+        let excerpts = store.related_files(cx);
+        assert_related_files(
+            &excerpts,
+            &[
+                ("root/src/main.rs", &["fn work() {", "}"]),
+                (
+                    "root/src/types.rs",
+                    &[indoc! {"
+                        pub struct TypeA {
+                            value: i32,
+                        }"}],
+                ),
+            ],
+        );
+    });
+}
+
 fn init_test(cx: &mut TestAppContext) {
     let settings_store = cx.update(|cx| SettingsStore::test(cx));
     cx.set_global(settings_store);

crates/edit_prediction_context/src/fake_definition_lsp.rs 🔗

@@ -9,9 +9,9 @@ use project::Fs;
 use std::{ops::Range, path::PathBuf, sync::Arc};
 use tree_sitter::{Parser, QueryCursor, StreamingIterator, Tree};
 
-/// Registers a fake language server that implements go-to-definition using tree-sitter,
-/// making the assumption that all names are unique, and all variables' types are
-/// explicitly declared.
+/// Registers a fake language server that implements go-to-definition and
+/// go-to-type-definition using tree-sitter, making the assumption that all
+/// names are unique, and all variables' types are explicitly declared.
 pub fn register_fake_definition_server(
     language_registry: &Arc<LanguageRegistry>,
     language: Arc<Language>,
@@ -34,6 +34,7 @@ pub fn register_fake_definition_server(
             },
             capabilities: lsp::ServerCapabilities {
                 definition_provider: Some(lsp::OneOf::Left(true)),
+                type_definition_provider: Some(lsp::TypeDefinitionProviderCapability::Simple(true)),
                 text_document_sync: Some(TextDocumentSyncCapability::Kind(
                     TextDocumentSyncKind::FULL,
                 )),
@@ -153,6 +154,17 @@ pub fn register_fake_definition_server(
                             async move { Ok(result) }
                         }
                     });
+
+                    server.set_request_handler::<lsp::request::GotoTypeDefinition, _, _>({
+                        let index = index.clone();
+                        move |params, _cx| {
+                            let result = index.lock().get_type_definitions(
+                                params.text_document_position_params.text_document.uri,
+                                params.text_document_position_params.position,
+                            );
+                            async move { Ok(result) }
+                        }
+                    });
                 }
             })),
         },
@@ -162,6 +174,7 @@ pub fn register_fake_definition_server(
 struct DefinitionIndex {
     language: Arc<Language>,
     definitions: HashMap<String, Vec<lsp::Location>>,
+    type_annotations: HashMap<String, String>,
     files: HashMap<Uri, FileEntry>,
 }
 
@@ -176,6 +189,7 @@ impl DefinitionIndex {
         Self {
             language,
             definitions: HashMap::default(),
+            type_annotations: HashMap::default(),
             files: HashMap::default(),
         }
     }
@@ -228,6 +242,13 @@ impl DefinitionIndex {
                 .or_insert_with(Vec::new)
                 .push(location);
         }
+
+        for (identifier_name, type_name) in extract_type_annotations(content) {
+            self.type_annotations
+                .entry(identifier_name)
+                .or_insert(type_name);
+        }
+
         self.files.insert(
             uri,
             FileEntry {
@@ -249,6 +270,108 @@ impl DefinitionIndex {
         let locations = self.definitions.get(name).cloned()?;
         Some(lsp::GotoDefinitionResponse::Array(locations))
     }
+
+    fn get_type_definitions(
+        &mut self,
+        uri: Uri,
+        position: lsp::Position,
+    ) -> Option<lsp::GotoDefinitionResponse> {
+        let entry = self.files.get(&uri)?;
+        let name = word_at_position(&entry.contents, position)?;
+
+        if let Some(type_name) = self.type_annotations.get(name) {
+            if let Some(locations) = self.definitions.get(type_name) {
+                return Some(lsp::GotoDefinitionResponse::Array(locations.clone()));
+            }
+        }
+
+        // If the identifier itself is an uppercase name (a type), return its own definition.
+        // This mirrors real LSP behavior where GotoTypeDefinition on a type name
+        // resolves to that type's definition.
+        if name.starts_with(|c: char| c.is_uppercase()) {
+            if let Some(locations) = self.definitions.get(name) {
+                return Some(lsp::GotoDefinitionResponse::Array(locations.clone()));
+            }
+        }
+
+        None
+    }
+}
+
+/// Extracts `identifier_name -> type_name` mappings from field declarations
+/// and function parameters. For example, `owner: Arc<Person>` produces
+/// `"owner" -> "Person"` by unwrapping common generic wrappers.
+fn extract_type_annotations(content: &str) -> Vec<(String, String)> {
+    let mut annotations = Vec::new();
+    for line in content.lines() {
+        let trimmed = line.trim();
+        if trimmed.starts_with("//")
+            || trimmed.starts_with("use ")
+            || trimmed.starts_with("pub use ")
+        {
+            continue;
+        }
+
+        let Some(colon_idx) = trimmed.find(':') else {
+            continue;
+        };
+
+        // The part before `:` should end with an identifier name.
+        let left = trimmed[..colon_idx].trim();
+        let Some(name) = left.split_whitespace().last() else {
+            continue;
+        };
+
+        if name.is_empty() || !name.chars().all(|c| c.is_alphanumeric() || c == '_') {
+            continue;
+        }
+
+        // Skip names that start uppercase — they're type names, not variables/fields.
+        if name.starts_with(|c: char| c.is_uppercase()) {
+            continue;
+        }
+
+        let right = trimmed[colon_idx + 1..].trim();
+        let type_name = extract_base_type_name(right);
+
+        if !type_name.is_empty() && type_name.starts_with(|c: char| c.is_uppercase()) {
+            annotations.push((name.to_string(), type_name));
+        }
+    }
+    annotations
+}
+
+/// Unwraps common generic wrappers (Arc, Box, Rc, Option, Vec) and trait
+/// object prefixes (dyn, impl) to find the concrete type name. For example:
+/// `Arc<Person>` → `"Person"`, `Box<dyn Trait>` → `"Trait"`.
+fn extract_base_type_name(type_str: &str) -> String {
+    let trimmed = type_str
+        .trim()
+        .trim_start_matches('&')
+        .trim_start_matches("mut ")
+        .trim_end_matches(',')
+        .trim_end_matches('{')
+        .trim_end_matches(')')
+        .trim()
+        .trim_start_matches("dyn ")
+        .trim_start_matches("impl ")
+        .trim();
+
+    if let Some(angle_start) = trimmed.find('<') {
+        let outer = &trimmed[..angle_start];
+        if matches!(outer, "Arc" | "Box" | "Rc" | "Option" | "Vec" | "Cow") {
+            let inner_end = trimmed.rfind('>').unwrap_or(trimmed.len());
+            let inner = &trimmed[angle_start + 1..inner_end];
+            return extract_base_type_name(inner);
+        }
+        return outer.to_string();
+    }
+
+    trimmed
+        .split(|c: char| !c.is_alphanumeric() && c != '_')
+        .next()
+        .unwrap_or("")
+        .to_string()
 }
 
 fn extract_declarations_from_tree(