Cargo.lock
@@ -20730,7 +20730,9 @@ dependencies = [
"language_model",
"language_models",
"languages",
+ "log",
"node_runtime",
+ "ordered-float 2.10.1",
"paths",
"project",
"prompt_store",
Created by Michael Sloan and Agus.
Release Notes:
- N/A
---------
Co-authored-by: Agus <agus@zed.dev>
Cargo.lock | 2
crates/edit_prediction_context/src/declaration.rs | 7
crates/edit_prediction_context/src/declaration_scoring.rs | 9
crates/edit_prediction_context/src/edit_prediction_context.rs | 25
crates/edit_prediction_context/src/reference.rs | 16
crates/edit_prediction_context/src/syntax_index.rs | 21
crates/zeta_cli/Cargo.toml | 2
crates/zeta_cli/src/main.rs | 341 ++++
8 files changed, 408 insertions(+), 15 deletions(-)
@@ -20730,7 +20730,9 @@ dependencies = [
"language_model",
"language_models",
"languages",
+ "log",
"node_runtime",
+ "ordered-float 2.10.1",
"paths",
"project",
"prompt_store",
@@ -55,6 +55,13 @@ impl Declaration {
}
}
+ pub fn as_file(&self) -> Option<&FileDeclaration> {
+ match self {
+ Declaration::Buffer { .. } => None,
+ Declaration::File { declaration, .. } => Some(declaration),
+ }
+ }
+
pub fn project_entry_id(&self) -> ProjectEntryId {
match self {
Declaration::File {
@@ -1,9 +1,10 @@
use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
+use collections::HashMap;
use itertools::Itertools as _;
use language::BufferSnapshot;
use ordered_float::OrderedFloat;
use serde::Serialize;
-use std::{cmp::Reverse, collections::HashMap, ops::Range};
+use std::{cmp::Reverse, ops::Range};
use strum::EnumIter;
use text::{Point, ToPoint};
@@ -251,6 +252,7 @@ fn score_declaration(
pub struct DeclarationScores {
pub signature: f32,
pub declaration: f32,
+ pub retrieval: f32,
}
impl DeclarationScores {
@@ -258,7 +260,7 @@ impl DeclarationScores {
// TODO: handle truncation
// Score related to how likely this is the correct declaration, range 0 to 1
- let accuracy_score = if components.is_same_file {
+ let retrieval = if components.is_same_file {
// TODO: use declaration_line_distance_rank
1.0 / components.same_file_declaration_count as f32
} else {
@@ -274,13 +276,14 @@ impl DeclarationScores {
};
// For now instead of linear combination, the scores are just multiplied together.
- let combined_score = 10.0 * accuracy_score * distance_score;
+ let combined_score = 10.0 * retrieval * distance_score;
DeclarationScores {
signature: combined_score * components.excerpt_vs_signature_weighted_overlap,
// declaration score gets boosted both by being multiplied by 2 and by there being more
// weighted overlap.
declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap,
+ retrieval,
}
}
}
@@ -4,10 +4,11 @@ mod excerpt;
mod outline;
mod reference;
mod syntax_index;
-mod text_similarity;
+pub mod text_similarity;
use std::sync::Arc;
+use collections::HashMap;
use gpui::{App, AppContext as _, Entity, Task};
use language::BufferSnapshot;
use text::{Point, ToOffset as _};
@@ -54,6 +55,26 @@ impl EditPredictionContext {
buffer: &BufferSnapshot,
excerpt_options: &EditPredictionExcerptOptions,
index_state: Option<&SyntaxIndexState>,
+ ) -> Option<Self> {
+ Self::gather_context_with_references_fn(
+ cursor_point,
+ buffer,
+ excerpt_options,
+ index_state,
+ references_in_excerpt,
+ )
+ }
+
+ pub fn gather_context_with_references_fn(
+ cursor_point: Point,
+ buffer: &BufferSnapshot,
+ excerpt_options: &EditPredictionExcerptOptions,
+ index_state: Option<&SyntaxIndexState>,
+ get_references: impl FnOnce(
+ &EditPredictionExcerpt,
+ &EditPredictionExcerptText,
+ &BufferSnapshot,
+ ) -> HashMap<Identifier, Vec<Reference>>,
) -> Option<Self> {
let excerpt = EditPredictionExcerpt::select_from_buffer(
cursor_point,
@@ -77,7 +98,7 @@ impl EditPredictionContext {
let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
let declarations = if let Some(index_state) = index_state {
- let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
+ let references = get_references(&excerpt, &excerpt_text, buffer);
scored_declarations(
&index_state,
@@ -1,5 +1,5 @@
+use collections::HashMap;
use language::BufferSnapshot;
-use std::collections::HashMap;
use std::ops::Range;
use util::RangeExt;
@@ -8,7 +8,7 @@ use crate::{
excerpt::{EditPredictionExcerpt, EditPredictionExcerptText},
};
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Reference {
pub identifier: Identifier,
pub range: Range<usize>,
@@ -26,7 +26,7 @@ pub fn references_in_excerpt(
excerpt_text: &EditPredictionExcerptText,
snapshot: &BufferSnapshot,
) -> HashMap<Identifier, Vec<Reference>> {
- let mut references = identifiers_in_range(
+ let mut references = references_in_range(
excerpt.range.clone(),
excerpt_text.body.as_str(),
ReferenceRegion::Nearby,
@@ -38,7 +38,7 @@ pub fn references_in_excerpt(
.iter()
.zip(excerpt_text.parent_signatures.iter())
{
- references.extend(identifiers_in_range(
+ references.extend(references_in_range(
range.clone(),
text.as_str(),
ReferenceRegion::Breadcrumb,
@@ -46,7 +46,7 @@ pub fn references_in_excerpt(
));
}
- let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::new();
+ let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::default();
for reference in references {
identifier_to_references
.entry(reference.identifier.clone())
@@ -57,7 +57,7 @@ pub fn references_in_excerpt(
}
/// Finds all nodes which have a "variable" match from the highlights query within the offset range.
-pub fn identifiers_in_range(
+pub fn references_in_range(
range: Range<usize>,
range_text: &str,
reference_region: ReferenceRegion,
@@ -120,7 +120,7 @@ mod test {
use indoc::indoc;
use language::{BufferSnapshot, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
- use crate::reference::{ReferenceRegion, identifiers_in_range};
+ use crate::reference::{ReferenceRegion, references_in_range};
#[gpui::test]
fn test_identifier_node_truncated(cx: &mut TestAppContext) {
@@ -136,7 +136,7 @@ mod test {
let buffer = create_buffer(code, cx);
let range = 0..35;
- let references = identifiers_in_range(
+ let references = references_in_range(
range.clone(),
&code[range],
ReferenceRegion::Breadcrumb,
@@ -229,6 +229,27 @@ impl SyntaxIndex {
}
}
+ pub fn indexed_file_paths(&self, cx: &App) -> Task<Vec<ProjectPath>> {
+ let state = self.state.clone();
+ let project = self.project.clone();
+
+ cx.spawn(async move |cx| {
+ let state = state.lock().await;
+ let Some(project) = project.upgrade() else {
+ return vec![];
+ };
+ project
+ .read_with(cx, |project, cx| {
+ state
+ .files
+ .keys()
+ .filter_map(|entry_id| project.path_for_entry(*entry_id, cx))
+ .collect()
+ })
+ .unwrap_or_default()
+ })
+ }
+
fn handle_worktree_store_event(
&mut self,
_worktree_store: Entity<WorktreeStore>,
@@ -30,6 +30,7 @@ language_extension.workspace = true
language_model.workspace = true
language_models.workspace = true
languages = { workspace = true, features = ["load-grammars"] }
+log.workspace = true
node_runtime.workspace = true
paths.workspace = true
project.workspace = true
@@ -48,3 +49,4 @@ workspace-hack.workspace = true
zeta.workspace = true
zeta2.workspace = true
zlog.workspace = true
+ordered-float.workspace = true
@@ -3,19 +3,27 @@ mod headless;
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand};
use cloud_llm_client::predict_edits_v3;
-use edit_prediction_context::EditPredictionExcerptOptions;
+use edit_prediction_context::{
+ Declaration, EditPredictionContext, EditPredictionExcerptOptions, Identifier, ReferenceRegion,
+ SyntaxIndex, references_in_range,
+};
use futures::channel::mpsc;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, Application, AsyncApp};
use gpui::{Entity, Task};
use language::Bias;
-use language::Buffer;
use language::Point;
+use language::{Buffer, OffsetRangeExt};
use language_model::LlmApiToken;
+use ordered_float::OrderedFloat;
use project::{Project, ProjectPath, Worktree};
use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
use serde_json::json;
+use std::cmp::Reverse;
+use std::collections::HashMap;
+use std::io::Write as _;
+use std::ops::Range;
use std::path::{Path, PathBuf};
use std::process::exit;
use std::str::FromStr;
@@ -23,6 +31,7 @@ use std::sync::Arc;
use std::time::Duration;
use util::paths::PathStyle;
use util::rel_path::RelPath;
+use util::{RangeExt, ResultExt as _};
use zeta::{PerformPredictEditsParams, Zeta};
use crate::headless::ZetaCliAppState;
@@ -49,6 +58,12 @@ enum Commands {
#[clap(flatten)]
context_args: Option<ContextArgs>,
},
+ RetrievalStats {
+ #[arg(long)]
+ worktree: PathBuf,
+ #[arg(long, default_value_t = 42)]
+ file_indexing_parallelism: usize,
+ },
}
#[derive(Debug, Args)]
@@ -316,6 +331,312 @@ async fn get_context(
}
}
+pub async fn retrieval_stats(
+ worktree: PathBuf,
+ file_indexing_parallelism: usize,
+ app_state: Arc<ZetaCliAppState>,
+ cx: &mut AsyncApp,
+) -> Result<String> {
+ let worktree_path = worktree.canonicalize()?;
+
+ let project = cx.update(|cx| {
+ Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ cx,
+ )
+ })?;
+
+ let worktree = project
+ .update(cx, |project, cx| {
+ project.create_worktree(&worktree_path, true, cx)
+ })?
+ .await?;
+ let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
+
+ // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
+ worktree
+ .read_with(cx, |worktree, _cx| {
+ worktree.as_local().unwrap().scan_complete()
+ })?
+ .await;
+
+ let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx))?;
+ index
+ .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
+ .await?;
+ let files = index
+ .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
+ .await;
+
+ let mut lsp_open_handles = Vec::new();
+ let mut output = std::fs::File::create("retrieval-stats.txt")?;
+ let mut results = Vec::new();
+ for (file_index, project_path) in files.iter().enumerate() {
+ println!(
+ "Processing file {} of {}: {}",
+ file_index + 1,
+ files.len(),
+ project_path.path.display(PathStyle::Posix)
+ );
+ let Some((lsp_open_handle, buffer)) =
+ open_buffer_with_language_server(&project, &worktree, &project_path.path, cx)
+ .await
+ .log_err()
+ else {
+ continue;
+ };
+ lsp_open_handles.push(lsp_open_handle);
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+ let full_range = 0..snapshot.len();
+ let references = references_in_range(
+ full_range,
+ &snapshot.text(),
+ ReferenceRegion::Nearby,
+ &snapshot,
+ );
+
+ let index = index.read_with(cx, |index, _cx| index.state().clone())?;
+ let index = index.lock().await;
+ for reference in references {
+ let query_point = snapshot.offset_to_point(reference.range.start);
+ let mut single_reference_map = HashMap::default();
+ single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
+ let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
+ query_point,
+ &snapshot,
+ &zeta2::DEFAULT_EXCERPT_OPTIONS,
+ Some(&index),
+ |_, _, _| single_reference_map,
+ );
+
+ let Some(edit_prediction_context) = edit_prediction_context else {
+ let result = RetrievalStatsResult {
+ identifier: reference.identifier,
+ point: query_point,
+ outcome: RetrievalStatsOutcome::NoExcerpt,
+ };
+ write!(output, "{:?}\n\n", result)?;
+ results.push(result);
+ continue;
+ };
+
+ let mut retrieved_definitions = Vec::new();
+ for scored_declaration in edit_prediction_context.declarations {
+ match &scored_declaration.declaration {
+ Declaration::File {
+ project_entry_id,
+ declaration,
+ } => {
+ let Some(path) = worktree.read_with(cx, |worktree, _cx| {
+ worktree
+ .entry_for_id(*project_entry_id)
+ .map(|entry| entry.path.clone())
+ })?
+ else {
+ log::error!("bug: file project entry not found");
+ continue;
+ };
+ let project_path = ProjectPath {
+ worktree_id,
+ path: path.clone(),
+ };
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+ .await?;
+ let rope = buffer.read_with(cx, |buffer, _cx| buffer.as_rope().clone())?;
+ retrieved_definitions.push((
+ path,
+ rope.offset_to_point(declaration.item_range.start)
+ ..rope.offset_to_point(declaration.item_range.end),
+ scored_declaration.scores.declaration,
+ scored_declaration.scores.retrieval,
+ ));
+ }
+ Declaration::Buffer {
+ project_entry_id,
+ rope,
+ declaration,
+ ..
+ } => {
+ let Some(path) = worktree.read_with(cx, |worktree, _cx| {
+ worktree
+ .entry_for_id(*project_entry_id)
+ .map(|entry| entry.path.clone())
+ })?
+ else {
+ log::error!("bug: buffer project entry not found");
+ continue;
+ };
+ retrieved_definitions.push((
+ path,
+ rope.offset_to_point(declaration.item_range.start)
+ ..rope.offset_to_point(declaration.item_range.end),
+ scored_declaration.scores.declaration,
+ scored_declaration.scores.retrieval,
+ ));
+ }
+ }
+ }
+ retrieved_definitions
+ .sort_by_key(|(_, _, _, retrieval_score)| Reverse(OrderedFloat(*retrieval_score)));
+
+ // TODO: Consider still checking language server in this case, or having a mode for
+ // this. For now assuming that the purpose of this is to refine the ranking rather than
+ // refining whether the definition is present at all.
+ if retrieved_definitions.is_empty() {
+ continue;
+ }
+
+ // TODO: Rename declaration to definition in edit_prediction_context?
+ let lsp_result = project
+ .update(cx, |project, cx| {
+ project.definitions(&buffer, reference.range.start, cx)
+ })?
+ .await;
+ match lsp_result {
+ Ok(lsp_definitions) => {
+ let lsp_definitions = lsp_definitions
+ .unwrap_or_default()
+ .into_iter()
+ .filter_map(|definition| {
+ definition
+ .target
+ .buffer
+ .read_with(cx, |buffer, _cx| {
+ Some((
+ buffer.file()?.path().clone(),
+ definition.target.range.to_point(&buffer),
+ ))
+ })
+ .ok()?
+ })
+ .collect::<Vec<_>>();
+
+ let result = RetrievalStatsResult {
+ identifier: reference.identifier,
+ point: query_point,
+ outcome: RetrievalStatsOutcome::Success {
+ matches: lsp_definitions
+ .iter()
+ .map(|(path, range)| {
+ retrieved_definitions.iter().position(
+ |(retrieved_path, retrieved_range, _, _)| {
+ path == retrieved_path
+ && retrieved_range.contains_inclusive(&range)
+ },
+ )
+ })
+ .collect(),
+ lsp_definitions,
+ retrieved_definitions,
+ },
+ };
+ write!(output, "{:?}\n\n", result)?;
+ results.push(result);
+ }
+ Err(err) => {
+ let result = RetrievalStatsResult {
+ identifier: reference.identifier,
+ point: query_point,
+ outcome: RetrievalStatsOutcome::LanguageServerError {
+ message: err.to_string(),
+ },
+ };
+ write!(output, "{:?}\n\n", result)?;
+ results.push(result);
+ }
+ }
+ }
+ }
+
+ let mut no_excerpt_count = 0;
+ let mut error_count = 0;
+ let mut definitions_count = 0;
+ let mut top_match_count = 0;
+ let mut non_top_match_count = 0;
+ let mut ranking_involved_count = 0;
+ let mut ranking_involved_top_match_count = 0;
+ let mut ranking_involved_non_top_match_count = 0;
+ for result in &results {
+ match &result.outcome {
+ RetrievalStatsOutcome::NoExcerpt => no_excerpt_count += 1,
+ RetrievalStatsOutcome::LanguageServerError { .. } => error_count += 1,
+ RetrievalStatsOutcome::Success {
+ matches,
+ retrieved_definitions,
+ ..
+ } => {
+ definitions_count += 1;
+ let top_matches = matches.contains(&Some(0));
+ if top_matches {
+ top_match_count += 1;
+ }
+ let non_top_matches = !top_matches && matches.iter().any(|index| *index != Some(0));
+ if non_top_matches {
+ non_top_match_count += 1;
+ }
+ if retrieved_definitions.len() > 1 {
+ ranking_involved_count += 1;
+ if top_matches {
+ ranking_involved_top_match_count += 1;
+ }
+ if non_top_matches {
+ ranking_involved_non_top_match_count += 1;
+ }
+ }
+ }
+ }
+ }
+
+ println!("\nStats:\n");
+ println!("No Excerpt: {}", no_excerpt_count);
+ println!("Language Server Error: {}", error_count);
+ println!("Definitions: {}", definitions_count);
+ println!("Top Match: {}", top_match_count);
+ println!("Non-Top Match: {}", non_top_match_count);
+ println!("Ranking Involved: {}", ranking_involved_count);
+ println!(
+ "Ranking Involved Top Match: {}",
+ ranking_involved_top_match_count
+ );
+ println!(
+ "Ranking Involved Non-Top Match: {}",
+ ranking_involved_non_top_match_count
+ );
+
+ Ok("".to_string())
+}
+
+#[derive(Debug)]
+struct RetrievalStatsResult {
+ #[allow(dead_code)]
+ identifier: Identifier,
+ #[allow(dead_code)]
+ point: Point,
+ outcome: RetrievalStatsOutcome,
+}
+
+#[derive(Debug)]
+enum RetrievalStatsOutcome {
+ NoExcerpt,
+ LanguageServerError {
+ #[allow(dead_code)]
+ message: String,
+ },
+ Success {
+ matches: Vec<Option<usize>>,
+ #[allow(dead_code)]
+ lsp_definitions: Vec<(Arc<RelPath>, Range<Point>)>,
+ retrieved_definitions: Vec<(Arc<RelPath>, Range<Point>, f32, f32)>,
+ },
+}
+
pub async fn open_buffer(
project: &Entity<Project>,
worktree: &Entity<Worktree>,
@@ -385,6 +706,7 @@ pub fn wait_for_lang_server(
.unwrap()
.detach();
}
+ let (mut added_tx, mut added_rx) = mpsc::channel(1);
let subscriptions = [
cx.subscribe(&lsp_store, {
@@ -413,6 +735,7 @@ pub fn wait_for_lang_server(
project
.update(cx, |project, cx| project.save_buffer(buffer, cx))
.detach();
+ added_tx.try_send(()).ok();
}
project::Event::DiskBasedDiagnosticsFinished { .. } => {
tx.try_send(()).ok();
@@ -423,6 +746,16 @@ pub fn wait_for_lang_server(
];
cx.spawn(async move |cx| {
+ if !has_lang_server {
+ // some buffers never have a language server, so this aborts quickly in that case.
+ let timeout = cx.background_executor().timer(Duration::from_secs(1));
+ futures::select! {
+ _ = added_rx.next() => {},
+ _ = timeout.fuse() => {
+ anyhow::bail!("Waiting for language server add timed out after 1 second");
+ }
+ };
+ }
let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
let result = futures::select! {
_ = rx.next() => {
@@ -504,6 +837,10 @@ fn main() {
})
.await
}
+ Commands::RetrievalStats {
+ worktree,
+ file_indexing_parallelism,
+ } => retrieval_stats(worktree, file_indexing_parallelism, app_state, cx).await,
};
match result {
Ok(output) => {