Detailed changes
@@ -5140,17 +5140,23 @@ version = "0.1.0"
dependencies = [
"anyhow",
"arrayvec",
+ "clap",
"collections",
"futures 0.3.31",
"gpui",
"indoc",
+ "itertools 0.14.0",
"language",
"log",
+ "ordered-float 2.10.1",
"pretty_assertions",
"project",
+ "regex",
+ "serde",
"serde_json",
"settings",
"slotmap",
+ "strum 0.27.1",
"text",
"tree-sitter",
"util",
@@ -11,6 +11,10 @@ workspace = true
[lib]
path = "src/edit_prediction_context.rs"
+[[example]]
+name = "zeta_context"
+path = "examples/zeta_context.rs"
+
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
@@ -19,17 +23,23 @@ gpui.workspace = true
language.workspace = true
log.workspace = true
project.workspace = true
+regex.workspace = true
+serde.workspace = true
slotmap.workspace = true
+strum.workspace = true
text.workspace = true
tree-sitter.workspace = true
util.workspace = true
workspace-hack.workspace = true
+itertools.workspace = true
[dev-dependencies]
+clap.workspace = true
futures.workspace = true
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
+ordered-float.workspace = true
pretty_assertions.workspace = true
project = {workspace= true, features = ["test-support"]}
serde_json.workspace = true
@@ -0,0 +1,289 @@
+use anyhow::{Result, anyhow};
+use clap::{Parser, Subcommand};
+use ordered_float::OrderedFloat;
+use serde_json::json;
+use std::fmt::Display;
+use std::io::Write;
+use std::path::Path;
+use std::str::FromStr;
+use std::{path::PathBuf, sync::Arc};
+
+#[derive(Parser, Debug)]
+#[command(name = "zeta_context")]
+struct Args {
+ #[command(subcommand)]
+ command: Command,
+ #[arg(long, default_value_t = FileOrStdio::Stdio)]
+ log: FileOrStdio,
+}
+
+#[derive(Subcommand, Debug)]
+enum Command {
+ ShowIndex {
+ directory: PathBuf,
+ },
+ NearbyReferences {
+ cursor_position: SourceLocation,
+ #[arg(long, default_value_t = 10)]
+ context_lines: u32,
+ },
+
+ Run {
+ directory: PathBuf,
+ cursor_position: CursorPosition,
+ #[arg(long, default_value_t = 2048)]
+ prompt_limit: usize,
+ #[arg(long)]
+ output_scores: Option<FileOrStdio>,
+ #[command(flatten)]
+ excerpt_options: ExcerptOptions,
+ },
+}
+
+#[derive(Clone, Debug)]
+enum CursorPosition {
+ Random,
+ Specific(SourceLocation),
+}
+
+impl CursorPosition {
+ fn to_source_location_within(
+ &self,
+ languages: &[Arc<Language>],
+ directory: &Path,
+ ) -> SourceLocation {
+ match self {
+ CursorPosition::Random => {
+ let entries = ignore::Walk::new(directory)
+ .filter_map(|result| result.ok())
+ .filter(|entry| language_for_file(languages, entry.path()).is_some())
+ .collect::<Vec<_>>();
+ let selected_entry_ix = rand::random_range(0..entries.len());
+ let path = entries[selected_entry_ix].path().to_path_buf();
+ let source = std::fs::read_to_string(&path).unwrap();
+ let offset = rand::random_range(0..source.len());
+ let point = point_from_offset(&source, offset);
+ let source_location = SourceLocation { path, point };
+ log::info!("Selected random cursor position: {source_location}");
+ source_location
+ }
+ CursorPosition::Specific(location) => location.clone(),
+ }
+ }
+}
+
+impl Display for CursorPosition {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ CursorPosition::Random => write!(f, "random"),
+ CursorPosition::Specific(location) => write!(f, "{}", &location),
+ }
+ }
+}
+
+impl FromStr for CursorPosition {
+ type Err = anyhow::Error;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ match s {
+ "random" => Ok(CursorPosition::Random),
+ _ => Ok(CursorPosition::Specific(SourceLocation::from_str(s)?)),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+enum FileOrStdio {
+ File(PathBuf),
+ Stdio,
+}
+
+impl FileOrStdio {
+ #[allow(dead_code)]
+ fn read_to_string(&self) -> Result<String, std::io::Error> {
+ match self {
+ FileOrStdio::File(path) => std::fs::read_to_string(path),
+ FileOrStdio::Stdio => std::io::read_to_string(std::io::stdin()),
+ }
+ }
+
+ fn write_file_or_stdout(&self) -> Result<Box<dyn Write + Send + 'static>, std::io::Error> {
+ match self {
+ FileOrStdio::File(path) => Ok(Box::new(std::fs::File::create(path)?)),
+ FileOrStdio::Stdio => Ok(Box::new(std::io::stdout())),
+ }
+ }
+
+ fn write_file_or_stderr(
+ &self,
+ ) -> Result<Box<dyn std::io::Write + Send + 'static>, std::io::Error> {
+ match self {
+ FileOrStdio::File(path) => Ok(Box::new(std::fs::File::create(path)?)),
+ FileOrStdio::Stdio => Ok(Box::new(std::io::stderr())),
+ }
+ }
+}
+
+impl Display for FileOrStdio {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ FileOrStdio::File(path) => write!(f, "{}", path.display()),
+ FileOrStdio::Stdio => write!(f, "-"),
+ }
+ }
+}
+
+impl FromStr for FileOrStdio {
+ type Err = <PathBuf as FromStr>::Err;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ match s {
+ "-" => Ok(Self::Stdio),
+ _ => Ok(Self::File(PathBuf::from_str(s)?)),
+ }
+ }
+}
+
+fn main() -> Result<()> {
+ let args = ZetaContextArgs::parse();
+ env_logger::Builder::from_default_env()
+ .target(env_logger::Target::Pipe(args.log.write_file_or_stderr()?))
+ .init();
+ let languages = load_languages();
+ match &args.command {
+ Command::ShowIndex { directory } => {
+ /*
+ let directory = directory.canonicalize()?;
+ let index = IdentifierIndex::index_path(&languages, &directory)?;
+ for ((identifier, language_name), files) in &index.identifier_to_definitions {
+ println!("\n{} ({})", identifier.0, language_name.0);
+ for (file, definitions) in files {
+ println!(" {:?}", file);
+ for definition in definitions {
+ println!(" {}", definition.path_string(&index));
+ }
+ }
+ }
+ */
+ Ok(())
+ }
+
+ Command::NearbyReferences {
+ cursor_position,
+ context_lines,
+ } => {
+ /*
+ let (language, source, tree) = parse_file(&languages, &cursor_position.path)?;
+ let start_offset = offset_from_point(
+ &source,
+ Point::new(cursor_position.point.row.saturating_sub(*context_lines), 0),
+ );
+ let end_offset = offset_from_point(
+ &source,
+ Point::new(cursor_position.point.row + context_lines, 0),
+ );
+ let references = local_identifiers(
+ ReferenceRegion::Nearby,
+ &language,
+ &tree,
+ &source,
+ start_offset..end_offset,
+ );
+ for reference in references {
+ println!(
+ "{:?} {}",
+ point_range_from_offset_range(&source, reference.range),
+ reference.identifier.0,
+ );
+ }
+ */
+ Ok(())
+ }
+
+ Command::Run {
+ directory,
+ cursor_position,
+ prompt_limit,
+ output_scores,
+ excerpt_options,
+ } => {
+ let directory = directory.canonicalize()?;
+ let index = IdentifierIndex::index_path(&languages, &directory)?;
+ let cursor_position = cursor_position.to_source_location_within(&languages, &directory);
+ let excerpt_file: Arc<Path> = cursor_position.path.as_path().into();
+ let (language, source, tree) = parse_file(&languages, &excerpt_file)?;
+ let cursor_offset = offset_from_point(&source, cursor_position.point);
+ let Some(excerpt_ranges) = ExcerptRangesInput {
+ language: &language,
+ tree: &tree,
+ source: &source,
+ cursor_offset,
+ options: excerpt_options,
+ }
+ .select() else {
+ return Err(anyhow!("line containing cursor does not fit within window"));
+ };
+ let mut snippets = gather_snippets(
+ &language,
+ &index,
+ &tree,
+ &excerpt_file,
+ &source,
+ excerpt_ranges.clone(),
+ cursor_offset,
+ );
+ let planned_prompt = PromptPlanner::populate(
+ &index,
+ snippets.clone(),
+ excerpt_file,
+ excerpt_ranges.clone(),
+ cursor_offset,
+ *prompt_limit,
+ &directory,
+ );
+ let prompt_string = planned_prompt.to_prompt_string(&index);
+ println!("{}", &prompt_string);
+
+ if let Some(output_scores) = output_scores {
+ snippets.sort_by_key(|snippet| OrderedFloat(-snippet.scores.signature));
+ let writer = output_scores.write_file_or_stdout()?;
+ serde_json::to_writer_pretty(
+ writer,
+ &snippets
+ .into_iter()
+ .map(|snippet| {
+ json!({
+ "file": snippet.definition_file,
+ "symbol_path": snippet.definition.path_string(&index),
+ "signature_score": snippet.scores.signature,
+ "definition_score": snippet.scores.definition,
+ "signature_score_density": snippet.score_density(&index, SnippetStyle::Signature),
+ "definition_score_density": snippet.score_density(&index, SnippetStyle::Definition),
+ "score_components": snippet.score_components
+ })
+ })
+ .collect::<Vec<_>>(),
+ )?;
+ }
+
+ let actual_window_size = range_size(excerpt_ranges.excerpt_range);
+ if actual_window_size > excerpt_options.window_max_bytes {
+ let exceeded_amount = actual_window_size - excerpt_options.window_max_bytes;
+ if exceeded_amount as f64 / excerpt_options.window_max_bytes as f64 > 0.05 {
+ log::error!("Exceeded max main excerpt size by {exceeded_amount} bytes");
+ }
+ }
+
+ if prompt_string.len() > *prompt_limit {
+ let exceeded_amount = prompt_string.len() - *prompt_limit;
+ if exceeded_amount as f64 / *prompt_limit as f64 > 0.1 {
+ log::error!(
+ "Exceeded max prompt size of {prompt_limit} bytes by {exceeded_amount} bytes"
+ );
+ }
+ }
+
+ Ok(())
+ }
+ }
+}
@@ -1,6 +1,7 @@
mod excerpt;
mod outline;
mod reference;
+mod scored_declaration;
mod text_similarity;
mod tree_sitter_index;
@@ -0,0 +1,311 @@
+use itertools::Itertools as _;
+use serde::Serialize;
+use std::collections::HashMap;
+use std::ops::Range;
+use std::path::Path;
+use std::sync::Arc;
+use strum::EnumIter;
+use tree_sitter::{QueryCursor, StreamingIterator, Tree};
+
+use crate::{Declaration, outline::Identifier};
+
+#[derive(Clone, Debug)]
+pub struct ScoredSnippet {
+ #[allow(dead_code)]
+ pub identifier: Identifier,
+ pub definition_file: Arc<Path>,
+ pub definition: OutlineItem,
+ pub score_components: ScoreInputs,
+ pub scores: Scores,
+}
+
+// TODO: Consider having "Concise" style corresponding to `concise_text`
+#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
+pub enum SnippetStyle {
+ Signature,
+ Definition,
+}
+
+impl ScoredSnippet {
+ /// Returns the score for this snippet with the specified style.
+ pub fn score(&self, style: SnippetStyle) -> f32 {
+ match style {
+ SnippetStyle::Signature => self.scores.signature,
+ SnippetStyle::Definition => self.scores.definition,
+ }
+ }
+
+ /// Returns the byte range for the snippet with the specified style. For `Signature` this is the
+ /// signature_range expanded to line boundaries. For `Definition` this is the item_range expanded to
+ /// line boundaries (similar to slice_at_line_boundaries).
+ pub fn line_range(
+ &self,
+ identifier_index: &IdentifierIndex,
+ style: SnippetStyle,
+ ) -> Range<usize> {
+ let source = identifier_index
+ .path_to_source
+ .get(&self.definition_file)
+ .unwrap();
+
+ let base_range = match style {
+ SnippetStyle::Signature => self.definition.signature_range.clone(),
+ SnippetStyle::Definition => self.definition.item_range.clone(),
+ };
+
+ expand_range_to_line_boundaries(source, base_range)
+ }
+
+ pub fn score_density(&self, identifier_index: &IdentifierIndex, style: SnippetStyle) -> f32 {
+ self.score(style) / range_size(self.line_range(identifier_index, style)) as f32
+ }
+}
+
+fn scored_snippets(
+ language: &Language,
+ index: &IdentifierIndex,
+ source: &str,
+ reference_file: &Path,
+ references: Vec<Reference>,
+ cursor_offset: usize,
+ excerpt_range: Range<usize>,
+) -> Vec<ScoredSnippet> {
+ let cursor = point_from_offset(source, cursor_offset);
+
+ let containing_range_identifier_occurrences =
+ IdentifierOccurrences::within_string(&source[excerpt_range.clone()]);
+
+ let start_point = Point::new(cursor.row.saturating_sub(2), 0);
+ let end_point = Point::new(cursor.row + 1, 0);
+ let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
+ &source[offset_from_point(source, start_point)..offset_from_point(source, end_point)],
+ );
+
+ let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::new();
+ for reference in references {
+ identifier_to_references
+ .entry(reference.identifier.clone())
+ .or_insert_with(Vec::new)
+ .push(reference);
+ }
+
+ identifier_to_references
+ .into_iter()
+ .flat_map(|(identifier, references)| {
+ let Some(definitions) = index
+ .identifier_to_definitions
+ .get(&(identifier.clone(), language.name.clone()))
+ else {
+ return Vec::new();
+ };
+ let definition_count = definitions.len();
+ let definition_file_count = definitions.keys().len();
+
+ definitions
+ .iter_all()
+ .flat_map(|(definition_file, file_definitions)| {
+ let same_file_definition_count = file_definitions.len();
+ let is_same_file = reference_file == definition_file.as_ref();
+ file_definitions
+ .iter()
+ .filter(|definition| {
+ !is_same_file
+ || !range_intersection(&definition.item_range, &excerpt_range)
+ .is_some()
+ })
+ .filter_map(|definition| {
+ let definition_line_distance = if is_same_file {
+ let definition_line =
+ point_from_offset(source, definition.item_range.start).row;
+ (cursor.row as i32 - definition_line as i32).abs() as u32
+ } else {
+ 0
+ };
+ Some((definition_line_distance, definition))
+ })
+ .sorted_by_key(|&(distance, _)| distance)
+ .enumerate()
+ .map(
+ |(
+ definition_line_distance_rank,
+ (definition_line_distance, definition),
+ )| {
+ score_snippet(
+ index,
+ source,
+ &identifier,
+ &references,
+ definition_file.clone(),
+ definition.clone(),
+ is_same_file,
+ definition_line_distance,
+ definition_line_distance_rank,
+ same_file_definition_count,
+ definition_count,
+ definition_file_count,
+ &containing_range_identifier_occurrences,
+ &adjacent_identifier_occurrences,
+ cursor,
+ )
+ },
+ )
+ .collect::<Vec<_>>()
+ })
+ .collect::<Vec<_>>()
+ })
+ .flatten()
+ .collect::<Vec<_>>()
+}
+
+fn score_snippet(
+ index: &IdentifierIndex,
+ reference_source: &str,
+ identifier: &Identifier,
+ references: &Vec<Reference>,
+ definition_file: Arc<Path>,
+ definition: OutlineItem,
+ is_same_file: bool,
+ definition_line_distance: u32,
+ definition_line_distance_rank: usize,
+ same_file_definition_count: usize,
+ definition_count: usize,
+ definition_file_count: usize,
+ containing_range_identifier_occurrences: &IdentifierOccurrences,
+ adjacent_identifier_occurrences: &IdentifierOccurrences,
+ cursor: Point,
+) -> Option<ScoredSnippet> {
+ let is_referenced_nearby = references
+ .iter()
+ .any(|r| r.reference_region == ReferenceRegion::Nearby);
+ let is_referenced_in_breadcrumb = references
+ .iter()
+ .any(|r| r.reference_region == ReferenceRegion::Breadcrumb);
+ let reference_count = references.len();
+ let reference_line_distance = references
+ .iter()
+ .map(|r| {
+ let reference_line = point_from_offset(reference_source, r.range.start).row as i32;
+ (cursor.row as i32 - reference_line).abs() as u32
+ })
+ .min()
+ .unwrap();
+
+ let definition_source = index.path_to_source.get(&definition_file).unwrap();
+ let item_source_occurrences =
+ IdentifierOccurrences::within_string(definition.item(&definition_source));
+ let item_signature_occurrences =
+ IdentifierOccurrences::within_string(definition.signature(&definition_source));
+ let containing_range_vs_item_jaccard = jaccard_similarity(
+ containing_range_identifier_occurrences,
+ &item_source_occurrences,
+ );
+ let containing_range_vs_signature_jaccard = jaccard_similarity(
+ containing_range_identifier_occurrences,
+ &item_signature_occurrences,
+ );
+ let adjacent_vs_item_jaccard =
+ jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
+ let adjacent_vs_signature_jaccard =
+ jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
+
+ let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
+ containing_range_identifier_occurrences,
+ &item_source_occurrences,
+ );
+ let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
+ containing_range_identifier_occurrences,
+ &item_signature_occurrences,
+ );
+ let adjacent_vs_item_weighted_overlap =
+ weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
+ let adjacent_vs_signature_weighted_overlap =
+ weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
+
+ let score_components = ScoreInputs {
+ is_same_file,
+ is_referenced_nearby,
+ is_referenced_in_breadcrumb,
+ reference_line_distance,
+ definition_line_distance,
+ definition_line_distance_rank,
+ reference_count,
+ same_file_definition_count,
+ definition_count,
+ definition_file_count,
+ containing_range_vs_item_jaccard,
+ containing_range_vs_signature_jaccard,
+ adjacent_vs_item_jaccard,
+ adjacent_vs_signature_jaccard,
+ containing_range_vs_item_weighted_overlap,
+ containing_range_vs_signature_weighted_overlap,
+ adjacent_vs_item_weighted_overlap,
+ adjacent_vs_signature_weighted_overlap,
+ };
+
+ Some(ScoredSnippet {
+ identifier: identifier.clone(),
+ definition_file,
+ definition,
+ scores: score_components.score(),
+ score_components,
+ })
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct ScoreInputs {
+ pub is_same_file: bool,
+ pub is_referenced_nearby: bool,
+ pub is_referenced_in_breadcrumb: bool,
+ pub reference_count: usize,
+ pub same_file_definition_count: usize,
+ pub definition_count: usize,
+ pub definition_file_count: usize,
+ pub reference_line_distance: u32,
+ pub definition_line_distance: u32,
+ pub definition_line_distance_rank: usize,
+ pub containing_range_vs_item_jaccard: f32,
+ pub containing_range_vs_signature_jaccard: f32,
+ pub adjacent_vs_item_jaccard: f32,
+ pub adjacent_vs_signature_jaccard: f32,
+ pub containing_range_vs_item_weighted_overlap: f32,
+ pub containing_range_vs_signature_weighted_overlap: f32,
+ pub adjacent_vs_item_weighted_overlap: f32,
+ pub adjacent_vs_signature_weighted_overlap: f32,
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct Scores {
+ pub signature: f32,
+ pub definition: f32,
+}
+
+impl ScoreInputs {
+ fn score(&self) -> Scores {
+ // Score related to how likely this is the correct definition, range 0 to 1
+ let accuracy_score = if self.is_same_file {
+ // TODO: use definition_line_distance_rank
+ (0.5 / self.same_file_definition_count as f32)
+ + (0.5 / self.definition_file_count as f32)
+ } else {
+ 1.0 / self.definition_count as f32
+ };
+
+ // Score related to the distance between the reference and cursor, range 0 to 1
+ let distance_score = if self.is_referenced_nearby {
+ 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
+ } else {
+ // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
+ 0.5
+ };
+
+ // For now instead of linear combination, the scores are just multiplied together.
+ let combined_score = 10.0 * accuracy_score * distance_score;
+
+ Scores {
+ signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
+ // definition score gets boosted both by being multipled by 2 and by there being more
+ // weighted overlap.
+ definition: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
+ }
+ }
+}