From 6af385235d2dbd145391861556704edb1523ebf3 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Tue, 30 Sep 2025 02:06:31 -0600 Subject: [PATCH] zeta_cli: Add retrieval-stats command for comparing with language server symbol resolution (#39164) Release Notes: - N/A --------- Co-authored-by: Agus --- Cargo.lock | 2 + .../src/declaration.rs | 7 + .../src/declaration_scoring.rs | 9 +- .../src/edit_prediction_context.rs | 25 +- .../edit_prediction_context/src/reference.rs | 16 +- .../src/syntax_index.rs | 21 ++ crates/zeta_cli/Cargo.toml | 2 + crates/zeta_cli/src/main.rs | 341 +++++++++++++++++- 8 files changed, 408 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index aeb82b50beaf83353a12acee782d7541ffb0f6c8..489560796f8b70d1289dc92773653b533d7b12fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20730,7 +20730,9 @@ dependencies = [ "language_model", "language_models", "languages", + "log", "node_runtime", + "ordered-float 2.10.1", "paths", "project", "prompt_store", diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs index 910835534af80ba97b99b8fc560c27bf13c4acda..a6efe63fc606580311d6e7653bb5ee98a80fb9d3 100644 --- a/crates/edit_prediction_context/src/declaration.rs +++ b/crates/edit_prediction_context/src/declaration.rs @@ -55,6 +55,13 @@ impl Declaration { } } + pub fn as_file(&self) -> Option<&FileDeclaration> { + match self { + Declaration::Buffer { .. } => None, + Declaration::File { declaration, .. } => Some(declaration), + } + } + pub fn project_entry_id(&self) -> ProjectEntryId { match self { Declaration::File { diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs index 363e61cd21e6cf0432d23a0a50619cf420777fd9..6f027ed1f63cdd2688cd149edcc19f7a8fbc704f 100644 --- a/crates/edit_prediction_context/src/declaration_scoring.rs +++ b/crates/edit_prediction_context/src/declaration_scoring.rs @@ -1,9 +1,10 @@ use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents; +use collections::HashMap; use itertools::Itertools as _; use language::BufferSnapshot; use ordered_float::OrderedFloat; use serde::Serialize; -use std::{cmp::Reverse, collections::HashMap, ops::Range}; +use std::{cmp::Reverse, ops::Range}; use strum::EnumIter; use text::{Point, ToPoint}; @@ -251,6 +252,7 @@ fn score_declaration( pub struct DeclarationScores { pub signature: f32, pub declaration: f32, + pub retrieval: f32, } impl DeclarationScores { @@ -258,7 +260,7 @@ impl DeclarationScores { // TODO: handle truncation // Score related to how likely this is the correct declaration, range 0 to 1 - let accuracy_score = if components.is_same_file { + let retrieval = if components.is_same_file { // TODO: use declaration_line_distance_rank 1.0 / components.same_file_declaration_count as f32 } else { @@ -274,13 +276,14 @@ impl DeclarationScores { }; // For now instead of linear combination, the scores are just multiplied together. - let combined_score = 10.0 * accuracy_score * distance_score; + let combined_score = 10.0 * retrieval * distance_score; DeclarationScores { signature: combined_score * components.excerpt_vs_signature_weighted_overlap, // declaration score gets boosted both by being multiplied by 2 and by there being more // weighted overlap. declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap, + retrieval, } } } diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index 46fc11bb7b8c67663c8d73cf3ac51b3d614aaa5f..c994caf7546fdb22539e9d60ff976d4379ed2cc8 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -4,10 +4,11 @@ mod excerpt; mod outline; mod reference; mod syntax_index; -mod text_similarity; +pub mod text_similarity; use std::sync::Arc; +use collections::HashMap; use gpui::{App, AppContext as _, Entity, Task}; use language::BufferSnapshot; use text::{Point, ToOffset as _}; @@ -54,6 +55,26 @@ impl EditPredictionContext { buffer: &BufferSnapshot, excerpt_options: &EditPredictionExcerptOptions, index_state: Option<&SyntaxIndexState>, + ) -> Option { + Self::gather_context_with_references_fn( + cursor_point, + buffer, + excerpt_options, + index_state, + references_in_excerpt, + ) + } + + pub fn gather_context_with_references_fn( + cursor_point: Point, + buffer: &BufferSnapshot, + excerpt_options: &EditPredictionExcerptOptions, + index_state: Option<&SyntaxIndexState>, + get_references: impl FnOnce( + &EditPredictionExcerpt, + &EditPredictionExcerptText, + &BufferSnapshot, + ) -> HashMap>, ) -> Option { let excerpt = EditPredictionExcerpt::select_from_buffer( cursor_point, @@ -77,7 +98,7 @@ impl EditPredictionContext { let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start); let declarations = if let Some(index_state) = index_state { - let references = references_in_excerpt(&excerpt, &excerpt_text, buffer); + let references = get_references(&excerpt, &excerpt_text, buffer); scored_declarations( &index_state, diff --git a/crates/edit_prediction_context/src/reference.rs b/crates/edit_prediction_context/src/reference.rs index 268f8c39ef84ba29593f502aff7e818e931cc873..699adf1d8036802a7a4b9e34ca8e8094e4f97458 100644 --- a/crates/edit_prediction_context/src/reference.rs +++ b/crates/edit_prediction_context/src/reference.rs @@ -1,5 +1,5 @@ +use collections::HashMap; use language::BufferSnapshot; -use std::collections::HashMap; use std::ops::Range; use util::RangeExt; @@ -8,7 +8,7 @@ use crate::{ excerpt::{EditPredictionExcerpt, EditPredictionExcerptText}, }; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Reference { pub identifier: Identifier, pub range: Range, @@ -26,7 +26,7 @@ pub fn references_in_excerpt( excerpt_text: &EditPredictionExcerptText, snapshot: &BufferSnapshot, ) -> HashMap> { - let mut references = identifiers_in_range( + let mut references = references_in_range( excerpt.range.clone(), excerpt_text.body.as_str(), ReferenceRegion::Nearby, @@ -38,7 +38,7 @@ pub fn references_in_excerpt( .iter() .zip(excerpt_text.parent_signatures.iter()) { - references.extend(identifiers_in_range( + references.extend(references_in_range( range.clone(), text.as_str(), ReferenceRegion::Breadcrumb, @@ -46,7 +46,7 @@ pub fn references_in_excerpt( )); } - let mut identifier_to_references: HashMap> = HashMap::new(); + let mut identifier_to_references: HashMap> = HashMap::default(); for reference in references { identifier_to_references .entry(reference.identifier.clone()) @@ -57,7 +57,7 @@ pub fn references_in_excerpt( } /// Finds all nodes which have a "variable" match from the highlights query within the offset range. -pub fn identifiers_in_range( +pub fn references_in_range( range: Range, range_text: &str, reference_region: ReferenceRegion, @@ -120,7 +120,7 @@ mod test { use indoc::indoc; use language::{BufferSnapshot, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; - use crate::reference::{ReferenceRegion, identifiers_in_range}; + use crate::reference::{ReferenceRegion, references_in_range}; #[gpui::test] fn test_identifier_node_truncated(cx: &mut TestAppContext) { @@ -136,7 +136,7 @@ mod test { let buffer = create_buffer(code, cx); let range = 0..35; - let references = identifiers_in_range( + let references = references_in_range( range.clone(), &code[range], ReferenceRegion::Breadcrumb, diff --git a/crates/edit_prediction_context/src/syntax_index.rs b/crates/edit_prediction_context/src/syntax_index.rs index c2b9b0540484071bcc145572b1a703baf889b7e9..5ebd53316d97064754ec688516593584f508797b 100644 --- a/crates/edit_prediction_context/src/syntax_index.rs +++ b/crates/edit_prediction_context/src/syntax_index.rs @@ -229,6 +229,27 @@ impl SyntaxIndex { } } + pub fn indexed_file_paths(&self, cx: &App) -> Task> { + let state = self.state.clone(); + let project = self.project.clone(); + + cx.spawn(async move |cx| { + let state = state.lock().await; + let Some(project) = project.upgrade() else { + return vec![]; + }; + project + .read_with(cx, |project, cx| { + state + .files + .keys() + .filter_map(|entry_id| project.path_for_entry(*entry_id, cx)) + .collect() + }) + .unwrap_or_default() + }) + } + fn handle_worktree_store_event( &mut self, _worktree_store: Entity, diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml index 87e8d10ea56eb8dd378a1d56defcb6cd952436d8..660de610c14ae3926b787e136d3aa9779156c279 100644 --- a/crates/zeta_cli/Cargo.toml +++ b/crates/zeta_cli/Cargo.toml @@ -30,6 +30,7 @@ language_extension.workspace = true language_model.workspace = true language_models.workspace = true languages = { workspace = true, features = ["load-grammars"] } +log.workspace = true node_runtime.workspace = true paths.workspace = true project.workspace = true @@ -48,3 +49,4 @@ workspace-hack.workspace = true zeta.workspace = true zeta2.workspace = true zlog.workspace = true +ordered-float.workspace = true diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 599b025659f15e5349cfe2e9d2821342a6c0e72a..ebe59fc7a374202ee5611965d82a1971abb30354 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -3,19 +3,27 @@ mod headless; use anyhow::{Result, anyhow}; use clap::{Args, Parser, Subcommand}; use cloud_llm_client::predict_edits_v3; -use edit_prediction_context::EditPredictionExcerptOptions; +use edit_prediction_context::{ + Declaration, EditPredictionContext, EditPredictionExcerptOptions, Identifier, ReferenceRegion, + SyntaxIndex, references_in_range, +}; use futures::channel::mpsc; use futures::{FutureExt as _, StreamExt as _}; use gpui::{AppContext, Application, AsyncApp}; use gpui::{Entity, Task}; use language::Bias; -use language::Buffer; use language::Point; +use language::{Buffer, OffsetRangeExt}; use language_model::LlmApiToken; +use ordered_float::OrderedFloat; use project::{Project, ProjectPath, Worktree}; use release_channel::AppVersion; use reqwest_client::ReqwestClient; use serde_json::json; +use std::cmp::Reverse; +use std::collections::HashMap; +use std::io::Write as _; +use std::ops::Range; use std::path::{Path, PathBuf}; use std::process::exit; use std::str::FromStr; @@ -23,6 +31,7 @@ use std::sync::Arc; use std::time::Duration; use util::paths::PathStyle; use util::rel_path::RelPath; +use util::{RangeExt, ResultExt as _}; use zeta::{PerformPredictEditsParams, Zeta}; use crate::headless::ZetaCliAppState; @@ -49,6 +58,12 @@ enum Commands { #[clap(flatten)] context_args: Option, }, + RetrievalStats { + #[arg(long)] + worktree: PathBuf, + #[arg(long, default_value_t = 42)] + file_indexing_parallelism: usize, + }, } #[derive(Debug, Args)] @@ -316,6 +331,312 @@ async fn get_context( } } +pub async fn retrieval_stats( + worktree: PathBuf, + file_indexing_parallelism: usize, + app_state: Arc, + cx: &mut AsyncApp, +) -> Result { + let worktree_path = worktree.canonicalize()?; + + let project = cx.update(|cx| { + Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ) + })?; + + let worktree = project + .update(cx, |project, cx| { + project.create_worktree(&worktree_path, true, cx) + })? + .await?; + let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?; + + // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree. + worktree + .read_with(cx, |worktree, _cx| { + worktree.as_local().unwrap().scan_complete() + })? + .await; + + let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx))?; + index + .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))? + .await?; + let files = index + .read_with(cx, |index, cx| index.indexed_file_paths(cx))? + .await; + + let mut lsp_open_handles = Vec::new(); + let mut output = std::fs::File::create("retrieval-stats.txt")?; + let mut results = Vec::new(); + for (file_index, project_path) in files.iter().enumerate() { + println!( + "Processing file {} of {}: {}", + file_index + 1, + files.len(), + project_path.path.display(PathStyle::Posix) + ); + let Some((lsp_open_handle, buffer)) = + open_buffer_with_language_server(&project, &worktree, &project_path.path, cx) + .await + .log_err() + else { + continue; + }; + lsp_open_handles.push(lsp_open_handle); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let full_range = 0..snapshot.len(); + let references = references_in_range( + full_range, + &snapshot.text(), + ReferenceRegion::Nearby, + &snapshot, + ); + + let index = index.read_with(cx, |index, _cx| index.state().clone())?; + let index = index.lock().await; + for reference in references { + let query_point = snapshot.offset_to_point(reference.range.start); + let mut single_reference_map = HashMap::default(); + single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]); + let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn( + query_point, + &snapshot, + &zeta2::DEFAULT_EXCERPT_OPTIONS, + Some(&index), + |_, _, _| single_reference_map, + ); + + let Some(edit_prediction_context) = edit_prediction_context else { + let result = RetrievalStatsResult { + identifier: reference.identifier, + point: query_point, + outcome: RetrievalStatsOutcome::NoExcerpt, + }; + write!(output, "{:?}\n\n", result)?; + results.push(result); + continue; + }; + + let mut retrieved_definitions = Vec::new(); + for scored_declaration in edit_prediction_context.declarations { + match &scored_declaration.declaration { + Declaration::File { + project_entry_id, + declaration, + } => { + let Some(path) = worktree.read_with(cx, |worktree, _cx| { + worktree + .entry_for_id(*project_entry_id) + .map(|entry| entry.path.clone()) + })? + else { + log::error!("bug: file project entry not found"); + continue; + }; + let project_path = ProjectPath { + worktree_id, + path: path.clone(), + }; + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx))? + .await?; + let rope = buffer.read_with(cx, |buffer, _cx| buffer.as_rope().clone())?; + retrieved_definitions.push(( + path, + rope.offset_to_point(declaration.item_range.start) + ..rope.offset_to_point(declaration.item_range.end), + scored_declaration.scores.declaration, + scored_declaration.scores.retrieval, + )); + } + Declaration::Buffer { + project_entry_id, + rope, + declaration, + .. + } => { + let Some(path) = worktree.read_with(cx, |worktree, _cx| { + worktree + .entry_for_id(*project_entry_id) + .map(|entry| entry.path.clone()) + })? + else { + log::error!("bug: buffer project entry not found"); + continue; + }; + retrieved_definitions.push(( + path, + rope.offset_to_point(declaration.item_range.start) + ..rope.offset_to_point(declaration.item_range.end), + scored_declaration.scores.declaration, + scored_declaration.scores.retrieval, + )); + } + } + } + retrieved_definitions + .sort_by_key(|(_, _, _, retrieval_score)| Reverse(OrderedFloat(*retrieval_score))); + + // TODO: Consider still checking language server in this case, or having a mode for + // this. For now assuming that the purpose of this is to refine the ranking rather than + // refining whether the definition is present at all. + if retrieved_definitions.is_empty() { + continue; + } + + // TODO: Rename declaration to definition in edit_prediction_context? + let lsp_result = project + .update(cx, |project, cx| { + project.definitions(&buffer, reference.range.start, cx) + })? + .await; + match lsp_result { + Ok(lsp_definitions) => { + let lsp_definitions = lsp_definitions + .unwrap_or_default() + .into_iter() + .filter_map(|definition| { + definition + .target + .buffer + .read_with(cx, |buffer, _cx| { + Some(( + buffer.file()?.path().clone(), + definition.target.range.to_point(&buffer), + )) + }) + .ok()? + }) + .collect::>(); + + let result = RetrievalStatsResult { + identifier: reference.identifier, + point: query_point, + outcome: RetrievalStatsOutcome::Success { + matches: lsp_definitions + .iter() + .map(|(path, range)| { + retrieved_definitions.iter().position( + |(retrieved_path, retrieved_range, _, _)| { + path == retrieved_path + && retrieved_range.contains_inclusive(&range) + }, + ) + }) + .collect(), + lsp_definitions, + retrieved_definitions, + }, + }; + write!(output, "{:?}\n\n", result)?; + results.push(result); + } + Err(err) => { + let result = RetrievalStatsResult { + identifier: reference.identifier, + point: query_point, + outcome: RetrievalStatsOutcome::LanguageServerError { + message: err.to_string(), + }, + }; + write!(output, "{:?}\n\n", result)?; + results.push(result); + } + } + } + } + + let mut no_excerpt_count = 0; + let mut error_count = 0; + let mut definitions_count = 0; + let mut top_match_count = 0; + let mut non_top_match_count = 0; + let mut ranking_involved_count = 0; + let mut ranking_involved_top_match_count = 0; + let mut ranking_involved_non_top_match_count = 0; + for result in &results { + match &result.outcome { + RetrievalStatsOutcome::NoExcerpt => no_excerpt_count += 1, + RetrievalStatsOutcome::LanguageServerError { .. } => error_count += 1, + RetrievalStatsOutcome::Success { + matches, + retrieved_definitions, + .. + } => { + definitions_count += 1; + let top_matches = matches.contains(&Some(0)); + if top_matches { + top_match_count += 1; + } + let non_top_matches = !top_matches && matches.iter().any(|index| *index != Some(0)); + if non_top_matches { + non_top_match_count += 1; + } + if retrieved_definitions.len() > 1 { + ranking_involved_count += 1; + if top_matches { + ranking_involved_top_match_count += 1; + } + if non_top_matches { + ranking_involved_non_top_match_count += 1; + } + } + } + } + } + + println!("\nStats:\n"); + println!("No Excerpt: {}", no_excerpt_count); + println!("Language Server Error: {}", error_count); + println!("Definitions: {}", definitions_count); + println!("Top Match: {}", top_match_count); + println!("Non-Top Match: {}", non_top_match_count); + println!("Ranking Involved: {}", ranking_involved_count); + println!( + "Ranking Involved Top Match: {}", + ranking_involved_top_match_count + ); + println!( + "Ranking Involved Non-Top Match: {}", + ranking_involved_non_top_match_count + ); + + Ok("".to_string()) +} + +#[derive(Debug)] +struct RetrievalStatsResult { + #[allow(dead_code)] + identifier: Identifier, + #[allow(dead_code)] + point: Point, + outcome: RetrievalStatsOutcome, +} + +#[derive(Debug)] +enum RetrievalStatsOutcome { + NoExcerpt, + LanguageServerError { + #[allow(dead_code)] + message: String, + }, + Success { + matches: Vec>, + #[allow(dead_code)] + lsp_definitions: Vec<(Arc, Range)>, + retrieved_definitions: Vec<(Arc, Range, f32, f32)>, + }, +} + pub async fn open_buffer( project: &Entity, worktree: &Entity, @@ -385,6 +706,7 @@ pub fn wait_for_lang_server( .unwrap() .detach(); } + let (mut added_tx, mut added_rx) = mpsc::channel(1); let subscriptions = [ cx.subscribe(&lsp_store, { @@ -413,6 +735,7 @@ pub fn wait_for_lang_server( project .update(cx, |project, cx| project.save_buffer(buffer, cx)) .detach(); + added_tx.try_send(()).ok(); } project::Event::DiskBasedDiagnosticsFinished { .. } => { tx.try_send(()).ok(); @@ -423,6 +746,16 @@ pub fn wait_for_lang_server( ]; cx.spawn(async move |cx| { + if !has_lang_server { + // some buffers never have a language server, so this aborts quickly in that case. + let timeout = cx.background_executor().timer(Duration::from_secs(1)); + futures::select! { + _ = added_rx.next() => {}, + _ = timeout.fuse() => { + anyhow::bail!("Waiting for language server add timed out after 1 second"); + } + }; + } let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5)); let result = futures::select! { _ = rx.next() => { @@ -504,6 +837,10 @@ fn main() { }) .await } + Commands::RetrievalStats { + worktree, + file_indexing_parallelism, + } => retrieval_stats(worktree, file_indexing_parallelism, app_state, cx).await, }; match result { Ok(output) => {