From f2a6b57909db2165eb49714794d3919706ca6c7b Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Wed, 17 Sep 2025 01:55:06 -0600 Subject: [PATCH] Copy in experimental cli / declaration scoring code Co-authored-by: Oleksiy --- Cargo.lock | 6 + crates/edit_prediction_context/Cargo.toml | 10 + .../examples/zeta_context.rs | 289 ++++++++++++++++ .../src/edit_prediction_context.rs | 1 + .../src/scored_declaration.rs | 311 ++++++++++++++++++ 5 files changed, 617 insertions(+) create mode 100644 crates/edit_prediction_context/examples/zeta_context.rs create mode 100644 crates/edit_prediction_context/src/scored_declaration.rs diff --git a/Cargo.lock b/Cargo.lock index 43a2fc4041fbf76b57e62d335e94c695ef07fc12..330375cf465a863c5095806af3a7964e138ba25c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5140,17 +5140,23 @@ version = "0.1.0" dependencies = [ "anyhow", "arrayvec", + "clap", "collections", "futures 0.3.31", "gpui", "indoc", + "itertools 0.14.0", "language", "log", + "ordered-float 2.10.1", "pretty_assertions", "project", + "regex", + "serde", "serde_json", "settings", "slotmap", + "strum 0.27.1", "text", "tree-sitter", "util", diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index ad455b0a4ecb1746debafd23f0503b4365f9a0cf..29755a3b0eecdb3eacf5f8d4b0b0dba69135d4b5 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -11,6 +11,10 @@ workspace = true [lib] path = "src/edit_prediction_context.rs" +[[example]] +name = "zeta_context" +path = "examples/zeta_context.rs" + [dependencies] anyhow.workspace = true arrayvec.workspace = true @@ -19,17 +23,23 @@ gpui.workspace = true language.workspace = true log.workspace = true project.workspace = true +regex.workspace = true +serde.workspace = true slotmap.workspace = true +strum.workspace = true text.workspace = true tree-sitter.workspace = true util.workspace = true workspace-hack.workspace = true +itertools.workspace = true 
[dev-dependencies] +clap.workspace = true futures.workspace = true gpui = { workspace = true, features = ["test-support"] } indoc.workspace = true language = { workspace = true, features = ["test-support"] } +ordered-float.workspace = true pretty_assertions.workspace = true project = {workspace= true, features = ["test-support"]} serde_json.workspace = true diff --git a/crates/edit_prediction_context/examples/zeta_context.rs b/crates/edit_prediction_context/examples/zeta_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..7984ed4ba79250bfcc6eaca57c39a4f61f1cff13 --- /dev/null +++ b/crates/edit_prediction_context/examples/zeta_context.rs @@ -0,0 +1,289 @@ +use anyhow::{Result, anyhow}; +use clap::{Parser, Subcommand}; +use ordered_float::OrderedFloat; +use serde_json::json; +use std::fmt::Display; +use std::io::Write; +use std::path::Path; +use std::str::FromStr; +use std::{path::PathBuf, sync::Arc}; + +#[derive(Parser, Debug)] +#[command(name = "zeta_context")] +struct Args { + #[command(subcommand)] + command: Command, + #[arg(long, default_value_t = FileOrStdio::Stdio)] + log: FileOrStdio, +} + +#[derive(Subcommand, Debug)] +enum Command { + ShowIndex { + directory: PathBuf, + }, + NearbyReferences { + cursor_position: SourceLocation, + #[arg(long, default_value_t = 10)] + context_lines: u32, + }, + + Run { + directory: PathBuf, + cursor_position: CursorPosition, + #[arg(long, default_value_t = 2048)] + prompt_limit: usize, + #[arg(long)] + output_scores: Option, + #[command(flatten)] + excerpt_options: ExcerptOptions, + }, +} + +#[derive(Clone, Debug)] +enum CursorPosition { + Random, + Specific(SourceLocation), +} + +impl CursorPosition { + fn to_source_location_within( + &self, + languages: &[Arc], + directory: &Path, + ) -> SourceLocation { + match self { + CursorPosition::Random => { + let entries = ignore::Walk::new(directory) + .filter_map(|result| result.ok()) + .filter(|entry| language_for_file(languages, 
entry.path()).is_some()) + .collect::>(); + let selected_entry_ix = rand::random_range(0..entries.len()); + let path = entries[selected_entry_ix].path().to_path_buf(); + let source = std::fs::read_to_string(&path).unwrap(); + let offset = rand::random_range(0..source.len()); + let point = point_from_offset(&source, offset); + let source_location = SourceLocation { path, point }; + log::info!("Selected random cursor position: {source_location}"); + source_location + } + CursorPosition::Specific(location) => location.clone(), + } + } +} + +impl Display for CursorPosition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CursorPosition::Random => write!(f, "random"), + CursorPosition::Specific(location) => write!(f, "{}", &location), + } + } +} + +impl FromStr for CursorPosition { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s { + "random" => Ok(CursorPosition::Random), + _ => Ok(CursorPosition::Specific(SourceLocation::from_str(s)?)), + } + } +} + +#[derive(Debug, Clone)] +enum FileOrStdio { + File(PathBuf), + Stdio, +} + +impl FileOrStdio { + #[allow(dead_code)] + fn read_to_string(&self) -> Result { + match self { + FileOrStdio::File(path) => std::fs::read_to_string(path), + FileOrStdio::Stdio => std::io::read_to_string(std::io::stdin()), + } + } + + fn write_file_or_stdout(&self) -> Result, std::io::Error> { + match self { + FileOrStdio::File(path) => Ok(Box::new(std::fs::File::create(path)?)), + FileOrStdio::Stdio => Ok(Box::new(std::io::stdout())), + } + } + + fn write_file_or_stderr( + &self, + ) -> Result, std::io::Error> { + match self { + FileOrStdio::File(path) => Ok(Box::new(std::fs::File::create(path)?)), + FileOrStdio::Stdio => Ok(Box::new(std::io::stderr())), + } + } +} + +impl Display for FileOrStdio { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FileOrStdio::File(path) => write!(f, "{}", path.display()), + FileOrStdio::Stdio => write!(f, 
"-"), + } + } +} + +impl FromStr for FileOrStdio { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + match s { + "-" => Ok(Self::Stdio), + _ => Ok(Self::File(PathBuf::from_str(s)?)), + } + } +} + +fn main() -> Result<()> { + let args = ZetaContextArgs::parse(); + env_logger::Builder::from_default_env() + .target(env_logger::Target::Pipe(args.log.write_file_or_stderr()?)) + .init(); + let languages = load_languages(); + match &args.command { + Command::ShowIndex { directory } => { + /* + let directory = directory.canonicalize()?; + let index = IdentifierIndex::index_path(&languages, &directory)?; + for ((identifier, language_name), files) in &index.identifier_to_definitions { + println!("\n{} ({})", identifier.0, language_name.0); + for (file, definitions) in files { + println!(" {:?}", file); + for definition in definitions { + println!(" {}", definition.path_string(&index)); + } + } + } + */ + Ok(()) + } + + Command::NearbyReferences { + cursor_position, + context_lines, + } => { + /* + let (language, source, tree) = parse_file(&languages, &cursor_position.path)?; + let start_offset = offset_from_point( + &source, + Point::new(cursor_position.point.row.saturating_sub(*context_lines), 0), + ); + let end_offset = offset_from_point( + &source, + Point::new(cursor_position.point.row + context_lines, 0), + ); + let references = local_identifiers( + ReferenceRegion::Nearby, + &language, + &tree, + &source, + start_offset..end_offset, + ); + for reference in references { + println!( + "{:?} {}", + point_range_from_offset_range(&source, reference.range), + reference.identifier.0, + ); + } + */ + Ok(()) + } + + Command::Run { + directory, + cursor_position, + prompt_limit, + output_scores, + excerpt_options, + } => { + let directory = directory.canonicalize()?; + let index = IdentifierIndex::index_path(&languages, &directory)?; + let cursor_position = cursor_position.to_source_location_within(&languages, &directory); + let excerpt_file: Arc = 
cursor_position.path.as_path().into(); + let (language, source, tree) = parse_file(&languages, &excerpt_file)?; + let cursor_offset = offset_from_point(&source, cursor_position.point); + let Some(excerpt_ranges) = ExcerptRangesInput { + language: &language, + tree: &tree, + source: &source, + cursor_offset, + options: excerpt_options, + } + .select() else { + return Err(anyhow!("line containing cursor does not fit within window")); + }; + let mut snippets = gather_snippets( + &language, + &index, + &tree, + &excerpt_file, + &source, + excerpt_ranges.clone(), + cursor_offset, + ); + let planned_prompt = PromptPlanner::populate( + &index, + snippets.clone(), + excerpt_file, + excerpt_ranges.clone(), + cursor_offset, + *prompt_limit, + &directory, + ); + let prompt_string = planned_prompt.to_prompt_string(&index); + println!("{}", &prompt_string); + + if let Some(output_scores) = output_scores { + snippets.sort_by_key(|snippet| OrderedFloat(-snippet.scores.signature)); + let writer = output_scores.write_file_or_stdout()?; + serde_json::to_writer_pretty( + writer, + &snippets + .into_iter() + .map(|snippet| { + json!({ + "file": snippet.definition_file, + "symbol_path": snippet.definition.path_string(&index), + "signature_score": snippet.scores.signature, + "definition_score": snippet.scores.definition, + "signature_score_density": snippet.score_density(&index, SnippetStyle::Signature), + "definition_score_density": snippet.score_density(&index, SnippetStyle::Definition), + "score_components": snippet.score_components + }) + }) + .collect::>(), + )?; + } + + let actual_window_size = range_size(excerpt_ranges.excerpt_range); + if actual_window_size > excerpt_options.window_max_bytes { + let exceeded_amount = actual_window_size - excerpt_options.window_max_bytes; + if exceeded_amount as f64 / excerpt_options.window_max_bytes as f64 > 0.05 { + log::error!("Exceeded max main excerpt size by {exceeded_amount} bytes"); + } + } + + if prompt_string.len() > *prompt_limit { + 
let exceeded_amount = prompt_string.len() - *prompt_limit; + if exceeded_amount as f64 / *prompt_limit as f64 > 0.1 { + log::error!( + "Exceeded max prompt size of {prompt_limit} bytes by {exceeded_amount} bytes" + ); + } + } + + Ok(()) + } + } +} diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index acb329c99ee2d9fe1ad39f662ef6618cd6194ef1..03102cc62b01888725dfdc73a40897c15cecb46b 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -1,6 +1,7 @@ mod excerpt; mod outline; mod reference; +mod scored_declaration; mod text_similarity; mod tree_sitter_index; diff --git a/crates/edit_prediction_context/src/scored_declaration.rs b/crates/edit_prediction_context/src/scored_declaration.rs new file mode 100644 index 0000000000000000000000000000000000000000..9c3394db1de6e521c2cee9bbfacc234ec7520568 --- /dev/null +++ b/crates/edit_prediction_context/src/scored_declaration.rs @@ -0,0 +1,311 @@ +use itertools::Itertools as _; +use serde::Serialize; +use std::collections::HashMap; +use std::ops::Range; +use std::path::Path; +use std::sync::Arc; +use strum::EnumIter; +use tree_sitter::{QueryCursor, StreamingIterator, Tree}; + +use crate::{Declaration, outline::Identifier}; + +#[derive(Clone, Debug)] +pub struct ScoredSnippet { + #[allow(dead_code)] + pub identifier: Identifier, + pub definition_file: Arc, + pub definition: OutlineItem, + pub score_components: ScoreInputs, + pub scores: Scores, +} + +// TODO: Consider having "Concise" style corresponding to `concise_text` +#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub enum SnippetStyle { + Signature, + Definition, +} + +impl ScoredSnippet { + /// Returns the score for this snippet with the specified style. 
+ pub fn score(&self, style: SnippetStyle) -> f32 { + match style { + SnippetStyle::Signature => self.scores.signature, + SnippetStyle::Definition => self.scores.definition, + } + } + + /// Returns the byte range for the snippet with the specified style. For `Signature` this is the + /// signature_range expanded to line boundaries. For `Definition` this is the item_range expanded to + /// line boundaries (similar to slice_at_line_boundaries). + pub fn line_range( + &self, + identifier_index: &IdentifierIndex, + style: SnippetStyle, + ) -> Range { + let source = identifier_index + .path_to_source + .get(&self.definition_file) + .unwrap(); + + let base_range = match style { + SnippetStyle::Signature => self.definition.signature_range.clone(), + SnippetStyle::Definition => self.definition.item_range.clone(), + }; + + expand_range_to_line_boundaries(source, base_range) + } + + pub fn score_density(&self, identifier_index: &IdentifierIndex, style: SnippetStyle) -> f32 { + self.score(style) / range_size(self.line_range(identifier_index, style)) as f32 + } +} + +fn scored_snippets( + language: &Language, + index: &IdentifierIndex, + source: &str, + reference_file: &Path, + references: Vec, + cursor_offset: usize, + excerpt_range: Range, +) -> Vec { + let cursor = point_from_offset(source, cursor_offset); + + let containing_range_identifier_occurrences = + IdentifierOccurrences::within_string(&source[excerpt_range.clone()]); + + let start_point = Point::new(cursor.row.saturating_sub(2), 0); + let end_point = Point::new(cursor.row + 1, 0); + let adjacent_identifier_occurrences = IdentifierOccurrences::within_string( + &source[offset_from_point(source, start_point)..offset_from_point(source, end_point)], + ); + + let mut identifier_to_references: HashMap> = HashMap::new(); + for reference in references { + identifier_to_references + .entry(reference.identifier.clone()) + .or_insert_with(Vec::new) + .push(reference); + } + + identifier_to_references + .into_iter() + 
.flat_map(|(identifier, references)| { + let Some(definitions) = index + .identifier_to_definitions + .get(&(identifier.clone(), language.name.clone())) + else { + return Vec::new(); + }; + let definition_count = definitions.len(); + let definition_file_count = definitions.keys().len(); + + definitions + .iter_all() + .flat_map(|(definition_file, file_definitions)| { + let same_file_definition_count = file_definitions.len(); + let is_same_file = reference_file == definition_file.as_ref(); + file_definitions + .iter() + .filter(|definition| { + !is_same_file + || !range_intersection(&definition.item_range, &excerpt_range) + .is_some() + }) + .filter_map(|definition| { + let definition_line_distance = if is_same_file { + let definition_line = + point_from_offset(source, definition.item_range.start).row; + (cursor.row as i32 - definition_line as i32).abs() as u32 + } else { + 0 + }; + Some((definition_line_distance, definition)) + }) + .sorted_by_key(|&(distance, _)| distance) + .enumerate() + .map( + |( + definition_line_distance_rank, + (definition_line_distance, definition), + )| { + score_snippet( + index, + source, + &identifier, + &references, + definition_file.clone(), + definition.clone(), + is_same_file, + definition_line_distance, + definition_line_distance_rank, + same_file_definition_count, + definition_count, + definition_file_count, + &containing_range_identifier_occurrences, + &adjacent_identifier_occurrences, + cursor, + ) + }, + ) + .collect::>() + }) + .collect::>() + }) + .flatten() + .collect::>() +} + +fn score_snippet( + index: &IdentifierIndex, + reference_source: &str, + identifier: &Identifier, + references: &Vec, + definition_file: Arc, + definition: OutlineItem, + is_same_file: bool, + definition_line_distance: u32, + definition_line_distance_rank: usize, + same_file_definition_count: usize, + definition_count: usize, + definition_file_count: usize, + containing_range_identifier_occurrences: &IdentifierOccurrences, + 
adjacent_identifier_occurrences: &IdentifierOccurrences, + cursor: Point, +) -> Option { + let is_referenced_nearby = references + .iter() + .any(|r| r.reference_region == ReferenceRegion::Nearby); + let is_referenced_in_breadcrumb = references + .iter() + .any(|r| r.reference_region == ReferenceRegion::Breadcrumb); + let reference_count = references.len(); + let reference_line_distance = references + .iter() + .map(|r| { + let reference_line = point_from_offset(reference_source, r.range.start).row as i32; + (cursor.row as i32 - reference_line).abs() as u32 + }) + .min() + .unwrap(); + + let definition_source = index.path_to_source.get(&definition_file).unwrap(); + let item_source_occurrences = + IdentifierOccurrences::within_string(definition.item(&definition_source)); + let item_signature_occurrences = + IdentifierOccurrences::within_string(definition.signature(&definition_source)); + let containing_range_vs_item_jaccard = jaccard_similarity( + containing_range_identifier_occurrences, + &item_source_occurrences, + ); + let containing_range_vs_signature_jaccard = jaccard_similarity( + containing_range_identifier_occurrences, + &item_signature_occurrences, + ); + let adjacent_vs_item_jaccard = + jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences); + let adjacent_vs_signature_jaccard = + jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences); + + let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient( + containing_range_identifier_occurrences, + &item_source_occurrences, + ); + let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient( + containing_range_identifier_occurrences, + &item_signature_occurrences, + ); + let adjacent_vs_item_weighted_overlap = + weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences); + let adjacent_vs_signature_weighted_overlap = + weighted_overlap_coefficient(adjacent_identifier_occurrences, 
&item_signature_occurrences); + + let score_components = ScoreInputs { + is_same_file, + is_referenced_nearby, + is_referenced_in_breadcrumb, + reference_line_distance, + definition_line_distance, + definition_line_distance_rank, + reference_count, + same_file_definition_count, + definition_count, + definition_file_count, + containing_range_vs_item_jaccard, + containing_range_vs_signature_jaccard, + adjacent_vs_item_jaccard, + adjacent_vs_signature_jaccard, + containing_range_vs_item_weighted_overlap, + containing_range_vs_signature_weighted_overlap, + adjacent_vs_item_weighted_overlap, + adjacent_vs_signature_weighted_overlap, + }; + + Some(ScoredSnippet { + identifier: identifier.clone(), + definition_file, + definition, + scores: score_components.score(), + score_components, + }) +} + +#[derive(Clone, Debug, Serialize)] +pub struct ScoreInputs { + pub is_same_file: bool, + pub is_referenced_nearby: bool, + pub is_referenced_in_breadcrumb: bool, + pub reference_count: usize, + pub same_file_definition_count: usize, + pub definition_count: usize, + pub definition_file_count: usize, + pub reference_line_distance: u32, + pub definition_line_distance: u32, + pub definition_line_distance_rank: usize, + pub containing_range_vs_item_jaccard: f32, + pub containing_range_vs_signature_jaccard: f32, + pub adjacent_vs_item_jaccard: f32, + pub adjacent_vs_signature_jaccard: f32, + pub containing_range_vs_item_weighted_overlap: f32, + pub containing_range_vs_signature_weighted_overlap: f32, + pub adjacent_vs_item_weighted_overlap: f32, + pub adjacent_vs_signature_weighted_overlap: f32, +} + +#[derive(Clone, Debug, Serialize)] +pub struct Scores { + pub signature: f32, + pub definition: f32, +} + +impl ScoreInputs { + fn score(&self) -> Scores { + // Score related to how likely this is the correct definition, range 0 to 1 + let accuracy_score = if self.is_same_file { + // TODO: use definition_line_distance_rank + (0.5 / self.same_file_definition_count as f32) + + (0.5 / 
self.definition_file_count as f32) + } else { + 1.0 / self.definition_count as f32 + }; + + // Score related to the distance between the reference and cursor, range 0 to 1 + let distance_score = if self.is_referenced_nearby { + 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0) + } else { + // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures + 0.5 + }; + + // For now instead of linear combination, the scores are just multiplied together. + let combined_score = 10.0 * accuracy_score * distance_score; + + Scores { + signature: combined_score * self.containing_range_vs_signature_weighted_overlap, + // definition score gets boosted both by being multiplied by 2 and by there being more + // weighted overlap. + definition: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap, + } + } +}