zeta_cli: Add retrieval-stats command for comparing with language server symbol resolution (#39164)

Michael Sloan and Agus created

Release Notes:

- N/A

---------

Co-authored-by: Agus <agus@zed.dev>

Change summary

Cargo.lock                                                    |   2 
crates/edit_prediction_context/src/declaration.rs             |   7 
crates/edit_prediction_context/src/declaration_scoring.rs     |   9 
crates/edit_prediction_context/src/edit_prediction_context.rs |  25 
crates/edit_prediction_context/src/reference.rs               |  16 
crates/edit_prediction_context/src/syntax_index.rs            |  21 
crates/zeta_cli/Cargo.toml                                    |   2 
crates/zeta_cli/src/main.rs                                   | 341 ++++
8 files changed, 408 insertions(+), 15 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -20730,7 +20730,9 @@ dependencies = [
  "language_model",
  "language_models",
  "languages",
+ "log",
  "node_runtime",
+ "ordered-float 2.10.1",
  "paths",
  "project",
  "prompt_store",

crates/edit_prediction_context/src/declaration.rs 🔗

@@ -55,6 +55,13 @@ impl Declaration {
         }
     }
 
+    pub fn as_file(&self) -> Option<&FileDeclaration> {
+        match self {
+            Declaration::Buffer { .. } => None,
+            Declaration::File { declaration, .. } => Some(declaration),
+        }
+    }
+
     pub fn project_entry_id(&self) -> ProjectEntryId {
         match self {
             Declaration::File {

crates/edit_prediction_context/src/declaration_scoring.rs 🔗

@@ -1,9 +1,10 @@
 use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
+use collections::HashMap;
 use itertools::Itertools as _;
 use language::BufferSnapshot;
 use ordered_float::OrderedFloat;
 use serde::Serialize;
-use std::{cmp::Reverse, collections::HashMap, ops::Range};
+use std::{cmp::Reverse, ops::Range};
 use strum::EnumIter;
 use text::{Point, ToPoint};
 
@@ -251,6 +252,7 @@ fn score_declaration(
 pub struct DeclarationScores {
     pub signature: f32,
     pub declaration: f32,
+    pub retrieval: f32,
 }
 
 impl DeclarationScores {
@@ -258,7 +260,7 @@ impl DeclarationScores {
         // TODO: handle truncation
 
         // Score related to how likely this is the correct declaration, range 0 to 1
-        let accuracy_score = if components.is_same_file {
+        let retrieval = if components.is_same_file {
             // TODO: use declaration_line_distance_rank
             1.0 / components.same_file_declaration_count as f32
         } else {
@@ -274,13 +276,14 @@ impl DeclarationScores {
         };
 
         // For now instead of linear combination, the scores are just multiplied together.
-        let combined_score = 10.0 * accuracy_score * distance_score;
+        let combined_score = 10.0 * retrieval * distance_score;
 
         DeclarationScores {
             signature: combined_score * components.excerpt_vs_signature_weighted_overlap,
             // declaration score gets boosted both by being multiplied by 2 and by there being more
             // weighted overlap.
             declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap,
+            retrieval,
         }
     }
 }

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -4,10 +4,11 @@ mod excerpt;
 mod outline;
 mod reference;
 mod syntax_index;
-mod text_similarity;
+pub mod text_similarity;
 
 use std::sync::Arc;
 
+use collections::HashMap;
 use gpui::{App, AppContext as _, Entity, Task};
 use language::BufferSnapshot;
 use text::{Point, ToOffset as _};
@@ -54,6 +55,26 @@ impl EditPredictionContext {
         buffer: &BufferSnapshot,
         excerpt_options: &EditPredictionExcerptOptions,
         index_state: Option<&SyntaxIndexState>,
+    ) -> Option<Self> {
+        Self::gather_context_with_references_fn(
+            cursor_point,
+            buffer,
+            excerpt_options,
+            index_state,
+            references_in_excerpt,
+        )
+    }
+
+    pub fn gather_context_with_references_fn(
+        cursor_point: Point,
+        buffer: &BufferSnapshot,
+        excerpt_options: &EditPredictionExcerptOptions,
+        index_state: Option<&SyntaxIndexState>,
+        get_references: impl FnOnce(
+            &EditPredictionExcerpt,
+            &EditPredictionExcerptText,
+            &BufferSnapshot,
+        ) -> HashMap<Identifier, Vec<Reference>>,
     ) -> Option<Self> {
         let excerpt = EditPredictionExcerpt::select_from_buffer(
             cursor_point,
@@ -77,7 +98,7 @@ impl EditPredictionContext {
         let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
 
         let declarations = if let Some(index_state) = index_state {
-            let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
+            let references = get_references(&excerpt, &excerpt_text, buffer);
 
             scored_declarations(
                 &index_state,

crates/edit_prediction_context/src/reference.rs 🔗

@@ -1,5 +1,5 @@
+use collections::HashMap;
 use language::BufferSnapshot;
-use std::collections::HashMap;
 use std::ops::Range;
 use util::RangeExt;
 
@@ -8,7 +8,7 @@ use crate::{
     excerpt::{EditPredictionExcerpt, EditPredictionExcerptText},
 };
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Reference {
     pub identifier: Identifier,
     pub range: Range<usize>,
@@ -26,7 +26,7 @@ pub fn references_in_excerpt(
     excerpt_text: &EditPredictionExcerptText,
     snapshot: &BufferSnapshot,
 ) -> HashMap<Identifier, Vec<Reference>> {
-    let mut references = identifiers_in_range(
+    let mut references = references_in_range(
         excerpt.range.clone(),
         excerpt_text.body.as_str(),
         ReferenceRegion::Nearby,
@@ -38,7 +38,7 @@ pub fn references_in_excerpt(
         .iter()
         .zip(excerpt_text.parent_signatures.iter())
     {
-        references.extend(identifiers_in_range(
+        references.extend(references_in_range(
             range.clone(),
             text.as_str(),
             ReferenceRegion::Breadcrumb,
@@ -46,7 +46,7 @@ pub fn references_in_excerpt(
         ));
     }
 
-    let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::new();
+    let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::default();
     for reference in references {
         identifier_to_references
             .entry(reference.identifier.clone())
@@ -57,7 +57,7 @@ pub fn references_in_excerpt(
 }
 
 /// Finds all nodes which have a "variable" match from the highlights query within the offset range.
-pub fn identifiers_in_range(
+pub fn references_in_range(
     range: Range<usize>,
     range_text: &str,
     reference_region: ReferenceRegion,
@@ -120,7 +120,7 @@ mod test {
     use indoc::indoc;
     use language::{BufferSnapshot, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
 
-    use crate::reference::{ReferenceRegion, identifiers_in_range};
+    use crate::reference::{ReferenceRegion, references_in_range};
 
     #[gpui::test]
     fn test_identifier_node_truncated(cx: &mut TestAppContext) {
@@ -136,7 +136,7 @@ mod test {
         let buffer = create_buffer(code, cx);
 
         let range = 0..35;
-        let references = identifiers_in_range(
+        let references = references_in_range(
             range.clone(),
             &code[range],
             ReferenceRegion::Breadcrumb,

crates/edit_prediction_context/src/syntax_index.rs 🔗

@@ -229,6 +229,27 @@ impl SyntaxIndex {
         }
     }
 
+    pub fn indexed_file_paths(&self, cx: &App) -> Task<Vec<ProjectPath>> {
+        let state = self.state.clone();
+        let project = self.project.clone();
+
+        cx.spawn(async move |cx| {
+            let state = state.lock().await;
+            let Some(project) = project.upgrade() else {
+                return vec![];
+            };
+            project
+                .read_with(cx, |project, cx| {
+                    state
+                        .files
+                        .keys()
+                        .filter_map(|entry_id| project.path_for_entry(*entry_id, cx))
+                        .collect()
+                })
+                .unwrap_or_default()
+        })
+    }
+
     fn handle_worktree_store_event(
         &mut self,
         _worktree_store: Entity<WorktreeStore>,

crates/zeta_cli/Cargo.toml 🔗

@@ -30,6 +30,7 @@ language_extension.workspace = true
 language_model.workspace = true
 language_models.workspace = true
 languages = { workspace = true, features = ["load-grammars"] }
+log.workspace = true
 node_runtime.workspace = true
 paths.workspace = true
 project.workspace = true
@@ -48,3 +49,4 @@ workspace-hack.workspace = true
 zeta.workspace = true
 zeta2.workspace = true
 zlog.workspace = true
+ordered-float.workspace = true

crates/zeta_cli/src/main.rs 🔗

@@ -3,19 +3,27 @@ mod headless;
 use anyhow::{Result, anyhow};
 use clap::{Args, Parser, Subcommand};
 use cloud_llm_client::predict_edits_v3;
-use edit_prediction_context::EditPredictionExcerptOptions;
+use edit_prediction_context::{
+    Declaration, EditPredictionContext, EditPredictionExcerptOptions, Identifier, ReferenceRegion,
+    SyntaxIndex, references_in_range,
+};
 use futures::channel::mpsc;
 use futures::{FutureExt as _, StreamExt as _};
 use gpui::{AppContext, Application, AsyncApp};
 use gpui::{Entity, Task};
 use language::Bias;
-use language::Buffer;
 use language::Point;
+use language::{Buffer, OffsetRangeExt};
 use language_model::LlmApiToken;
+use ordered_float::OrderedFloat;
 use project::{Project, ProjectPath, Worktree};
 use release_channel::AppVersion;
 use reqwest_client::ReqwestClient;
 use serde_json::json;
+use std::cmp::Reverse;
+use std::collections::HashMap;
+use std::io::Write as _;
+use std::ops::Range;
 use std::path::{Path, PathBuf};
 use std::process::exit;
 use std::str::FromStr;
@@ -23,6 +31,7 @@ use std::sync::Arc;
 use std::time::Duration;
 use util::paths::PathStyle;
 use util::rel_path::RelPath;
+use util::{RangeExt, ResultExt as _};
 use zeta::{PerformPredictEditsParams, Zeta};
 
 use crate::headless::ZetaCliAppState;
@@ -49,6 +58,12 @@ enum Commands {
         #[clap(flatten)]
         context_args: Option<ContextArgs>,
     },
+    RetrievalStats {
+        #[arg(long)]
+        worktree: PathBuf,
+        #[arg(long, default_value_t = 42)]
+        file_indexing_parallelism: usize,
+    },
 }
 
 #[derive(Debug, Args)]
@@ -316,6 +331,312 @@ async fn get_context(
     }
 }
 
+pub async fn retrieval_stats(
+    worktree: PathBuf,
+    file_indexing_parallelism: usize,
+    app_state: Arc<ZetaCliAppState>,
+    cx: &mut AsyncApp,
+) -> Result<String> {
+    let worktree_path = worktree.canonicalize()?;
+
+    let project = cx.update(|cx| {
+        Project::local(
+            app_state.client.clone(),
+            app_state.node_runtime.clone(),
+            app_state.user_store.clone(),
+            app_state.languages.clone(),
+            app_state.fs.clone(),
+            None,
+            cx,
+        )
+    })?;
+
+    let worktree = project
+        .update(cx, |project, cx| {
+            project.create_worktree(&worktree_path, true, cx)
+        })?
+        .await?;
+    let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
+
+    // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
+    worktree
+        .read_with(cx, |worktree, _cx| {
+            worktree.as_local().unwrap().scan_complete()
+        })?
+        .await;
+
+    let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx))?;
+    index
+        .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
+        .await?;
+    let files = index
+        .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
+        .await;
+
+    let mut lsp_open_handles = Vec::new();
+    let mut output = std::fs::File::create("retrieval-stats.txt")?;
+    let mut results = Vec::new();
+    for (file_index, project_path) in files.iter().enumerate() {
+        println!(
+            "Processing file {} of {}: {}",
+            file_index + 1,
+            files.len(),
+            project_path.path.display(PathStyle::Posix)
+        );
+        let Some((lsp_open_handle, buffer)) =
+            open_buffer_with_language_server(&project, &worktree, &project_path.path, cx)
+                .await
+                .log_err()
+        else {
+            continue;
+        };
+        lsp_open_handles.push(lsp_open_handle);
+
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+        let full_range = 0..snapshot.len();
+        let references = references_in_range(
+            full_range,
+            &snapshot.text(),
+            ReferenceRegion::Nearby,
+            &snapshot,
+        );
+
+        let index = index.read_with(cx, |index, _cx| index.state().clone())?;
+        let index = index.lock().await;
+        for reference in references {
+            let query_point = snapshot.offset_to_point(reference.range.start);
+            let mut single_reference_map = HashMap::default();
+            single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
+            let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
+                query_point,
+                &snapshot,
+                &zeta2::DEFAULT_EXCERPT_OPTIONS,
+                Some(&index),
+                |_, _, _| single_reference_map,
+            );
+
+            let Some(edit_prediction_context) = edit_prediction_context else {
+                let result = RetrievalStatsResult {
+                    identifier: reference.identifier,
+                    point: query_point,
+                    outcome: RetrievalStatsOutcome::NoExcerpt,
+                };
+                write!(output, "{:?}\n\n", result)?;
+                results.push(result);
+                continue;
+            };
+
+            let mut retrieved_definitions = Vec::new();
+            for scored_declaration in edit_prediction_context.declarations {
+                match &scored_declaration.declaration {
+                    Declaration::File {
+                        project_entry_id,
+                        declaration,
+                    } => {
+                        let Some(path) = worktree.read_with(cx, |worktree, _cx| {
+                            worktree
+                                .entry_for_id(*project_entry_id)
+                                .map(|entry| entry.path.clone())
+                        })?
+                        else {
+                            log::error!("bug: file project entry not found");
+                            continue;
+                        };
+                        let project_path = ProjectPath {
+                            worktree_id,
+                            path: path.clone(),
+                        };
+                        let buffer = project
+                            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+                            .await?;
+                        let rope = buffer.read_with(cx, |buffer, _cx| buffer.as_rope().clone())?;
+                        retrieved_definitions.push((
+                            path,
+                            rope.offset_to_point(declaration.item_range.start)
+                                ..rope.offset_to_point(declaration.item_range.end),
+                            scored_declaration.scores.declaration,
+                            scored_declaration.scores.retrieval,
+                        ));
+                    }
+                    Declaration::Buffer {
+                        project_entry_id,
+                        rope,
+                        declaration,
+                        ..
+                    } => {
+                        let Some(path) = worktree.read_with(cx, |worktree, _cx| {
+                            worktree
+                                .entry_for_id(*project_entry_id)
+                                .map(|entry| entry.path.clone())
+                        })?
+                        else {
+                            log::error!("bug: buffer project entry not found");
+                            continue;
+                        };
+                        retrieved_definitions.push((
+                            path,
+                            rope.offset_to_point(declaration.item_range.start)
+                                ..rope.offset_to_point(declaration.item_range.end),
+                            scored_declaration.scores.declaration,
+                            scored_declaration.scores.retrieval,
+                        ));
+                    }
+                }
+            }
+            retrieved_definitions
+                .sort_by_key(|(_, _, _, retrieval_score)| Reverse(OrderedFloat(*retrieval_score)));
+
+            // TODO: Consider still checking language server in this case, or having a mode for
+            // this. For now assuming that the purpose of this is to refine the ranking rather than
+            // refining whether the definition is present at all.
+            if retrieved_definitions.is_empty() {
+                continue;
+            }
+
+            // TODO: Rename declaration to definition in edit_prediction_context?
+            let lsp_result = project
+                .update(cx, |project, cx| {
+                    project.definitions(&buffer, reference.range.start, cx)
+                })?
+                .await;
+            match lsp_result {
+                Ok(lsp_definitions) => {
+                    let lsp_definitions = lsp_definitions
+                        .unwrap_or_default()
+                        .into_iter()
+                        .filter_map(|definition| {
+                            definition
+                                .target
+                                .buffer
+                                .read_with(cx, |buffer, _cx| {
+                                    Some((
+                                        buffer.file()?.path().clone(),
+                                        definition.target.range.to_point(&buffer),
+                                    ))
+                                })
+                                .ok()?
+                        })
+                        .collect::<Vec<_>>();
+
+                    let result = RetrievalStatsResult {
+                        identifier: reference.identifier,
+                        point: query_point,
+                        outcome: RetrievalStatsOutcome::Success {
+                            matches: lsp_definitions
+                                .iter()
+                                .map(|(path, range)| {
+                                    retrieved_definitions.iter().position(
+                                        |(retrieved_path, retrieved_range, _, _)| {
+                                            path == retrieved_path
+                                                && retrieved_range.contains_inclusive(&range)
+                                        },
+                                    )
+                                })
+                                .collect(),
+                            lsp_definitions,
+                            retrieved_definitions,
+                        },
+                    };
+                    write!(output, "{:?}\n\n", result)?;
+                    results.push(result);
+                }
+                Err(err) => {
+                    let result = RetrievalStatsResult {
+                        identifier: reference.identifier,
+                        point: query_point,
+                        outcome: RetrievalStatsOutcome::LanguageServerError {
+                            message: err.to_string(),
+                        },
+                    };
+                    write!(output, "{:?}\n\n", result)?;
+                    results.push(result);
+                }
+            }
+        }
+    }
+
+    let mut no_excerpt_count = 0;
+    let mut error_count = 0;
+    let mut definitions_count = 0;
+    let mut top_match_count = 0;
+    let mut non_top_match_count = 0;
+    let mut ranking_involved_count = 0;
+    let mut ranking_involved_top_match_count = 0;
+    let mut ranking_involved_non_top_match_count = 0;
+    for result in &results {
+        match &result.outcome {
+            RetrievalStatsOutcome::NoExcerpt => no_excerpt_count += 1,
+            RetrievalStatsOutcome::LanguageServerError { .. } => error_count += 1,
+            RetrievalStatsOutcome::Success {
+                matches,
+                retrieved_definitions,
+                ..
+            } => {
+                definitions_count += 1;
+                let top_matches = matches.contains(&Some(0));
+                if top_matches {
+                    top_match_count += 1;
+                }
+                let non_top_matches = !top_matches && matches.iter().any(|index| *index != Some(0));
+                if non_top_matches {
+                    non_top_match_count += 1;
+                }
+                if retrieved_definitions.len() > 1 {
+                    ranking_involved_count += 1;
+                    if top_matches {
+                        ranking_involved_top_match_count += 1;
+                    }
+                    if non_top_matches {
+                        ranking_involved_non_top_match_count += 1;
+                    }
+                }
+            }
+        }
+    }
+
+    println!("\nStats:\n");
+    println!("No Excerpt: {}", no_excerpt_count);
+    println!("Language Server Error: {}", error_count);
+    println!("Definitions: {}", definitions_count);
+    println!("Top Match: {}", top_match_count);
+    println!("Non-Top Match: {}", non_top_match_count);
+    println!("Ranking Involved: {}", ranking_involved_count);
+    println!(
+        "Ranking Involved Top Match: {}",
+        ranking_involved_top_match_count
+    );
+    println!(
+        "Ranking Involved Non-Top Match: {}",
+        ranking_involved_non_top_match_count
+    );
+
+    Ok("".to_string())
+}
+
+#[derive(Debug)]
+struct RetrievalStatsResult {
+    #[allow(dead_code)]
+    identifier: Identifier,
+    #[allow(dead_code)]
+    point: Point,
+    outcome: RetrievalStatsOutcome,
+}
+
+#[derive(Debug)]
+enum RetrievalStatsOutcome {
+    NoExcerpt,
+    LanguageServerError {
+        #[allow(dead_code)]
+        message: String,
+    },
+    Success {
+        matches: Vec<Option<usize>>,
+        #[allow(dead_code)]
+        lsp_definitions: Vec<(Arc<RelPath>, Range<Point>)>,
+        retrieved_definitions: Vec<(Arc<RelPath>, Range<Point>, f32, f32)>,
+    },
+}
+
 pub async fn open_buffer(
     project: &Entity<Project>,
     worktree: &Entity<Worktree>,
@@ -385,6 +706,7 @@ pub fn wait_for_lang_server(
             .unwrap()
             .detach();
     }
+    let (mut added_tx, mut added_rx) = mpsc::channel(1);
 
     let subscriptions = [
         cx.subscribe(&lsp_store, {
@@ -413,6 +735,7 @@ pub fn wait_for_lang_server(
                     project
                         .update(cx, |project, cx| project.save_buffer(buffer, cx))
                         .detach();
+                    added_tx.try_send(()).ok();
                 }
                 project::Event::DiskBasedDiagnosticsFinished { .. } => {
                     tx.try_send(()).ok();
@@ -423,6 +746,16 @@ pub fn wait_for_lang_server(
     ];
 
     cx.spawn(async move |cx| {
+        if !has_lang_server {
+            // some buffers never have a language server, so this aborts quickly in that case.
+            let timeout = cx.background_executor().timer(Duration::from_secs(1));
+            futures::select! {
+                _ = added_rx.next() => {},
+                _ = timeout.fuse() => {
+                    anyhow::bail!("Waiting for language server add timed out after 1 second");
+                }
+            };
+        }
         let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
         let result = futures::select! {
             _ = rx.next() => {
@@ -504,6 +837,10 @@ fn main() {
                     })
                     .await
                 }
+                Commands::RetrievalStats {
+                    worktree,
+                    file_indexing_parallelism,
+                } => retrieval_stats(worktree, file_indexing_parallelism, app_state, cx).await,
             };
             match result {
                 Ok(output) => {