Copy in experimental cli / declaration scoring code

Michael Sloan and Oleksiy created

Co-authored-by: Oleksiy <oleksiy@zed.dev>

Change summary

Cargo.lock                                                    |   6 
crates/edit_prediction_context/Cargo.toml                     |  10 
crates/edit_prediction_context/examples/zeta_context.rs       | 289 ++++
crates/edit_prediction_context/src/edit_prediction_context.rs |   1 
crates/edit_prediction_context/src/scored_declaration.rs      | 311 +++++
5 files changed, 617 insertions(+)

Detailed changes

Cargo.lock 🔗

@@ -5140,17 +5140,23 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "arrayvec",
+ "clap",
  "collections",
  "futures 0.3.31",
  "gpui",
  "indoc",
+ "itertools 0.14.0",
  "language",
  "log",
+ "ordered-float 2.10.1",
  "pretty_assertions",
  "project",
+ "regex",
+ "serde",
  "serde_json",
  "settings",
  "slotmap",
+ "strum 0.27.1",
  "text",
  "tree-sitter",
  "util",

crates/edit_prediction_context/Cargo.toml 🔗

@@ -11,6 +11,10 @@ workspace = true
 [lib]
 path = "src/edit_prediction_context.rs"
 
+[[example]]
+name = "zeta_context"
+path = "examples/zeta_context.rs"
+
 [dependencies]
 anyhow.workspace = true
 arrayvec.workspace = true
@@ -19,17 +23,23 @@ gpui.workspace = true
 language.workspace = true
 log.workspace = true
 project.workspace = true
+regex.workspace = true
+serde.workspace = true
 slotmap.workspace = true
+strum.workspace = true
 text.workspace = true
 tree-sitter.workspace = true
 util.workspace = true
 workspace-hack.workspace = true
+itertools.workspace = true
 
 [dev-dependencies]
+clap.workspace = true
 futures.workspace = true
 gpui = { workspace = true, features = ["test-support"] }
 indoc.workspace = true
 language = { workspace = true, features = ["test-support"] }
+ordered-float.workspace = true
 pretty_assertions.workspace = true
 project = {workspace= true, features = ["test-support"]}
 serde_json.workspace = true

crates/edit_prediction_context/examples/zeta_context.rs 🔗

@@ -0,0 +1,289 @@
+use anyhow::{Result, anyhow};
+use clap::{Parser, Subcommand};
+use ordered_float::OrderedFloat;
+use serde_json::json;
+use std::fmt::Display;
+use std::io::Write;
+use std::path::Path;
+use std::str::FromStr;
+use std::{path::PathBuf, sync::Arc};
+
+#[derive(Parser, Debug)]
+#[command(name = "zeta_context")]
+struct Args {
+    #[command(subcommand)]
+    command: Command,
+    #[arg(long, default_value_t = FileOrStdio::Stdio)]
+    log: FileOrStdio,
+}
+
+#[derive(Subcommand, Debug)]
+enum Command {
+    ShowIndex {
+        directory: PathBuf,
+    },
+    NearbyReferences {
+        cursor_position: SourceLocation,
+        #[arg(long, default_value_t = 10)]
+        context_lines: u32,
+    },
+
+    Run {
+        directory: PathBuf,
+        cursor_position: CursorPosition,
+        #[arg(long, default_value_t = 2048)]
+        prompt_limit: usize,
+        #[arg(long)]
+        output_scores: Option<FileOrStdio>,
+        #[command(flatten)]
+        excerpt_options: ExcerptOptions,
+    },
+}
+
+#[derive(Clone, Debug)]
+enum CursorPosition {
+    Random,
+    Specific(SourceLocation),
+}
+
+impl CursorPosition {
+    fn to_source_location_within(
+        &self,
+        languages: &[Arc<Language>],
+        directory: &Path,
+    ) -> SourceLocation {
+        match self {
+            CursorPosition::Random => {
+                let entries = ignore::Walk::new(directory)
+                    .filter_map(|result| result.ok())
+                    .filter(|entry| language_for_file(languages, entry.path()).is_some())
+                    .collect::<Vec<_>>();
+                let selected_entry_ix = rand::random_range(0..entries.len());
+                let path = entries[selected_entry_ix].path().to_path_buf();
+                let source = std::fs::read_to_string(&path).unwrap();
+                let offset = rand::random_range(0..source.len());
+                let point = point_from_offset(&source, offset);
+                let source_location = SourceLocation { path, point };
+                log::info!("Selected random cursor position: {source_location}");
+                source_location
+            }
+            CursorPosition::Specific(location) => location.clone(),
+        }
+    }
+}
+
+impl Display for CursorPosition {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            CursorPosition::Random => write!(f, "random"),
+            CursorPosition::Specific(location) => write!(f, "{}", &location),
+        }
+    }
+}
+
+impl FromStr for CursorPosition {
+    type Err = anyhow::Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "random" => Ok(CursorPosition::Random),
+            _ => Ok(CursorPosition::Specific(SourceLocation::from_str(s)?)),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+enum FileOrStdio {
+    File(PathBuf),
+    Stdio,
+}
+
+impl FileOrStdio {
+    #[allow(dead_code)]
+    fn read_to_string(&self) -> Result<String, std::io::Error> {
+        match self {
+            FileOrStdio::File(path) => std::fs::read_to_string(path),
+            FileOrStdio::Stdio => std::io::read_to_string(std::io::stdin()),
+        }
+    }
+
+    fn write_file_or_stdout(&self) -> Result<Box<dyn Write + Send + 'static>, std::io::Error> {
+        match self {
+            FileOrStdio::File(path) => Ok(Box::new(std::fs::File::create(path)?)),
+            FileOrStdio::Stdio => Ok(Box::new(std::io::stdout())),
+        }
+    }
+
+    fn write_file_or_stderr(
+        &self,
+    ) -> Result<Box<dyn std::io::Write + Send + 'static>, std::io::Error> {
+        match self {
+            FileOrStdio::File(path) => Ok(Box::new(std::fs::File::create(path)?)),
+            FileOrStdio::Stdio => Ok(Box::new(std::io::stderr())),
+        }
+    }
+}
+
+impl Display for FileOrStdio {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            FileOrStdio::File(path) => write!(f, "{}", path.display()),
+            FileOrStdio::Stdio => write!(f, "-"),
+        }
+    }
+}
+
+impl FromStr for FileOrStdio {
+    type Err = <PathBuf as FromStr>::Err;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "-" => Ok(Self::Stdio),
+            _ => Ok(Self::File(PathBuf::from_str(s)?)),
+        }
+    }
+}
+
+fn main() -> Result<()> {
+    let args = ZetaContextArgs::parse();
+    env_logger::Builder::from_default_env()
+        .target(env_logger::Target::Pipe(args.log.write_file_or_stderr()?))
+        .init();
+    let languages = load_languages();
+    match &args.command {
+        Command::ShowIndex { directory } => {
+            /*
+            let directory = directory.canonicalize()?;
+            let index = IdentifierIndex::index_path(&languages, &directory)?;
+            for ((identifier, language_name), files) in &index.identifier_to_definitions {
+                println!("\n{} ({})", identifier.0, language_name.0);
+                for (file, definitions) in files {
+                    println!("  {:?}", file);
+                    for definition in definitions {
+                        println!("    {}", definition.path_string(&index));
+                    }
+                }
+            }
+            */
+            Ok(())
+        }
+
+        Command::NearbyReferences {
+            cursor_position,
+            context_lines,
+        } => {
+            /*
+            let (language, source, tree) = parse_file(&languages, &cursor_position.path)?;
+            let start_offset = offset_from_point(
+                &source,
+                Point::new(cursor_position.point.row.saturating_sub(*context_lines), 0),
+            );
+            let end_offset = offset_from_point(
+                &source,
+                Point::new(cursor_position.point.row + context_lines, 0),
+            );
+            let references = local_identifiers(
+                ReferenceRegion::Nearby,
+                &language,
+                &tree,
+                &source,
+                start_offset..end_offset,
+            );
+            for reference in references {
+                println!(
+                    "{:?} {}",
+                    point_range_from_offset_range(&source, reference.range),
+                    reference.identifier.0,
+                );
+            }
+            */
+            Ok(())
+        }
+
+        Command::Run {
+            directory,
+            cursor_position,
+            prompt_limit,
+            output_scores,
+            excerpt_options,
+        } => {
+            let directory = directory.canonicalize()?;
+            let index = IdentifierIndex::index_path(&languages, &directory)?;
+            let cursor_position = cursor_position.to_source_location_within(&languages, &directory);
+            let excerpt_file: Arc<Path> = cursor_position.path.as_path().into();
+            let (language, source, tree) = parse_file(&languages, &excerpt_file)?;
+            let cursor_offset = offset_from_point(&source, cursor_position.point);
+            let Some(excerpt_ranges) = ExcerptRangesInput {
+                language: &language,
+                tree: &tree,
+                source: &source,
+                cursor_offset,
+                options: excerpt_options,
+            }
+            .select() else {
+                return Err(anyhow!("line containing cursor does not fit within window"));
+            };
+            let mut snippets = gather_snippets(
+                &language,
+                &index,
+                &tree,
+                &excerpt_file,
+                &source,
+                excerpt_ranges.clone(),
+                cursor_offset,
+            );
+            let planned_prompt = PromptPlanner::populate(
+                &index,
+                snippets.clone(),
+                excerpt_file,
+                excerpt_ranges.clone(),
+                cursor_offset,
+                *prompt_limit,
+                &directory,
+            );
+            let prompt_string = planned_prompt.to_prompt_string(&index);
+            println!("{}", &prompt_string);
+
+            if let Some(output_scores) = output_scores {
+                snippets.sort_by_key(|snippet| OrderedFloat(-snippet.scores.signature));
+                let writer = output_scores.write_file_or_stdout()?;
+                serde_json::to_writer_pretty(
+                    writer,
+                    &snippets
+                        .into_iter()
+                        .map(|snippet| {
+                            json!({
+                                "file": snippet.definition_file,
+                                "symbol_path": snippet.definition.path_string(&index),
+                                "signature_score": snippet.scores.signature,
+                                "definition_score": snippet.scores.definition,
+                                "signature_score_density": snippet.score_density(&index, SnippetStyle::Signature),
+                                "definition_score_density": snippet.score_density(&index, SnippetStyle::Definition),
+                                "score_components": snippet.score_components
+                            })
+                        })
+                        .collect::<Vec<_>>(),
+                )?;
+            }
+
+            let actual_window_size = range_size(excerpt_ranges.excerpt_range);
+            if actual_window_size > excerpt_options.window_max_bytes {
+                let exceeded_amount = actual_window_size - excerpt_options.window_max_bytes;
+                if exceeded_amount as f64 / excerpt_options.window_max_bytes as f64 > 0.05 {
+                    log::error!("Exceeded max main excerpt size by {exceeded_amount} bytes");
+                }
+            }
+
+            if prompt_string.len() > *prompt_limit {
+                let exceeded_amount = prompt_string.len() - *prompt_limit;
+                if exceeded_amount as f64 / *prompt_limit as f64 > 0.1 {
+                    log::error!(
+                        "Exceeded max prompt size of {prompt_limit} bytes by {exceeded_amount} bytes"
+                    );
+                }
+            }
+
+            Ok(())
+        }
+    }
+}

crates/edit_prediction_context/src/scored_declaration.rs 🔗

@@ -0,0 +1,311 @@
+use itertools::Itertools as _;
+use serde::Serialize;
+use std::collections::HashMap;
+use std::ops::Range;
+use std::path::Path;
+use std::sync::Arc;
+use strum::EnumIter;
+use tree_sitter::{QueryCursor, StreamingIterator, Tree};
+
+use crate::{Declaration, outline::Identifier};
+
+#[derive(Clone, Debug)]
+pub struct ScoredSnippet {
+    #[allow(dead_code)]
+    pub identifier: Identifier,
+    pub definition_file: Arc<Path>,
+    pub definition: OutlineItem,
+    pub score_components: ScoreInputs,
+    pub scores: Scores,
+}
+
+// TODO: Consider having "Concise" style corresponding to `concise_text`
+#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
+pub enum SnippetStyle {
+    Signature,
+    Definition,
+}
+
+impl ScoredSnippet {
+    /// Returns the score for this snippet with the specified style.
+    pub fn score(&self, style: SnippetStyle) -> f32 {
+        match style {
+            SnippetStyle::Signature => self.scores.signature,
+            SnippetStyle::Definition => self.scores.definition,
+        }
+    }
+
+    /// Returns the byte range for the snippet with the specified style. For `Signature` this is the
+    /// signature_range expanded to line boundaries. For `Definition` this is the item_range expanded to
+    /// line boundaries (similar to slice_at_line_boundaries).
+    pub fn line_range(
+        &self,
+        identifier_index: &IdentifierIndex,
+        style: SnippetStyle,
+    ) -> Range<usize> {
+        let source = identifier_index
+            .path_to_source
+            .get(&self.definition_file)
+            .unwrap();
+
+        let base_range = match style {
+            SnippetStyle::Signature => self.definition.signature_range.clone(),
+            SnippetStyle::Definition => self.definition.item_range.clone(),
+        };
+
+        expand_range_to_line_boundaries(source, base_range)
+    }
+
+    pub fn score_density(&self, identifier_index: &IdentifierIndex, style: SnippetStyle) -> f32 {
+        self.score(style) / range_size(self.line_range(identifier_index, style)) as f32
+    }
+}
+
+fn scored_snippets(
+    language: &Language,
+    index: &IdentifierIndex,
+    source: &str,
+    reference_file: &Path,
+    references: Vec<Reference>,
+    cursor_offset: usize,
+    excerpt_range: Range<usize>,
+) -> Vec<ScoredSnippet> {
+    let cursor = point_from_offset(source, cursor_offset);
+
+    let containing_range_identifier_occurrences =
+        IdentifierOccurrences::within_string(&source[excerpt_range.clone()]);
+
+    let start_point = Point::new(cursor.row.saturating_sub(2), 0);
+    let end_point = Point::new(cursor.row + 1, 0);
+    let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
+        &source[offset_from_point(source, start_point)..offset_from_point(source, end_point)],
+    );
+
+    let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::new();
+    for reference in references {
+        identifier_to_references
+            .entry(reference.identifier.clone())
+            .or_insert_with(Vec::new)
+            .push(reference);
+    }
+
+    identifier_to_references
+        .into_iter()
+        .flat_map(|(identifier, references)| {
+            let Some(definitions) = index
+                .identifier_to_definitions
+                .get(&(identifier.clone(), language.name.clone()))
+            else {
+                return Vec::new();
+            };
+            let definition_count = definitions.len();
+            let definition_file_count = definitions.keys().len();
+
+            definitions
+                .iter_all()
+                .flat_map(|(definition_file, file_definitions)| {
+                    let same_file_definition_count = file_definitions.len();
+                    let is_same_file = reference_file == definition_file.as_ref();
+                    file_definitions
+                        .iter()
+                        .filter(|definition| {
+                            !is_same_file
+                                || !range_intersection(&definition.item_range, &excerpt_range)
+                                    .is_some()
+                        })
+                        .filter_map(|definition| {
+                            let definition_line_distance = if is_same_file {
+                                let definition_line =
+                                    point_from_offset(source, definition.item_range.start).row;
+                                (cursor.row as i32 - definition_line as i32).abs() as u32
+                            } else {
+                                0
+                            };
+                            Some((definition_line_distance, definition))
+                        })
+                        .sorted_by_key(|&(distance, _)| distance)
+                        .enumerate()
+                        .map(
+                            |(
+                                definition_line_distance_rank,
+                                (definition_line_distance, definition),
+                            )| {
+                                score_snippet(
+                                    index,
+                                    source,
+                                    &identifier,
+                                    &references,
+                                    definition_file.clone(),
+                                    definition.clone(),
+                                    is_same_file,
+                                    definition_line_distance,
+                                    definition_line_distance_rank,
+                                    same_file_definition_count,
+                                    definition_count,
+                                    definition_file_count,
+                                    &containing_range_identifier_occurrences,
+                                    &adjacent_identifier_occurrences,
+                                    cursor,
+                                )
+                            },
+                        )
+                        .collect::<Vec<_>>()
+                })
+                .collect::<Vec<_>>()
+        })
+        .flatten()
+        .collect::<Vec<_>>()
+}
+
+fn score_snippet(
+    index: &IdentifierIndex,
+    reference_source: &str,
+    identifier: &Identifier,
+    references: &Vec<Reference>,
+    definition_file: Arc<Path>,
+    definition: OutlineItem,
+    is_same_file: bool,
+    definition_line_distance: u32,
+    definition_line_distance_rank: usize,
+    same_file_definition_count: usize,
+    definition_count: usize,
+    definition_file_count: usize,
+    containing_range_identifier_occurrences: &IdentifierOccurrences,
+    adjacent_identifier_occurrences: &IdentifierOccurrences,
+    cursor: Point,
+) -> Option<ScoredSnippet> {
+    let is_referenced_nearby = references
+        .iter()
+        .any(|r| r.reference_region == ReferenceRegion::Nearby);
+    let is_referenced_in_breadcrumb = references
+        .iter()
+        .any(|r| r.reference_region == ReferenceRegion::Breadcrumb);
+    let reference_count = references.len();
+    let reference_line_distance = references
+        .iter()
+        .map(|r| {
+            let reference_line = point_from_offset(reference_source, r.range.start).row as i32;
+            (cursor.row as i32 - reference_line).abs() as u32
+        })
+        .min()
+        .unwrap();
+
+    let definition_source = index.path_to_source.get(&definition_file).unwrap();
+    let item_source_occurrences =
+        IdentifierOccurrences::within_string(definition.item(&definition_source));
+    let item_signature_occurrences =
+        IdentifierOccurrences::within_string(definition.signature(&definition_source));
+    let containing_range_vs_item_jaccard = jaccard_similarity(
+        containing_range_identifier_occurrences,
+        &item_source_occurrences,
+    );
+    let containing_range_vs_signature_jaccard = jaccard_similarity(
+        containing_range_identifier_occurrences,
+        &item_signature_occurrences,
+    );
+    let adjacent_vs_item_jaccard =
+        jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
+    let adjacent_vs_signature_jaccard =
+        jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
+
+    let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
+        containing_range_identifier_occurrences,
+        &item_source_occurrences,
+    );
+    let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
+        containing_range_identifier_occurrences,
+        &item_signature_occurrences,
+    );
+    let adjacent_vs_item_weighted_overlap =
+        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
+    let adjacent_vs_signature_weighted_overlap =
+        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
+
+    let score_components = ScoreInputs {
+        is_same_file,
+        is_referenced_nearby,
+        is_referenced_in_breadcrumb,
+        reference_line_distance,
+        definition_line_distance,
+        definition_line_distance_rank,
+        reference_count,
+        same_file_definition_count,
+        definition_count,
+        definition_file_count,
+        containing_range_vs_item_jaccard,
+        containing_range_vs_signature_jaccard,
+        adjacent_vs_item_jaccard,
+        adjacent_vs_signature_jaccard,
+        containing_range_vs_item_weighted_overlap,
+        containing_range_vs_signature_weighted_overlap,
+        adjacent_vs_item_weighted_overlap,
+        adjacent_vs_signature_weighted_overlap,
+    };
+
+    Some(ScoredSnippet {
+        identifier: identifier.clone(),
+        definition_file,
+        definition,
+        scores: score_components.score(),
+        score_components,
+    })
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct ScoreInputs {
+    pub is_same_file: bool,
+    pub is_referenced_nearby: bool,
+    pub is_referenced_in_breadcrumb: bool,
+    pub reference_count: usize,
+    pub same_file_definition_count: usize,
+    pub definition_count: usize,
+    pub definition_file_count: usize,
+    pub reference_line_distance: u32,
+    pub definition_line_distance: u32,
+    pub definition_line_distance_rank: usize,
+    pub containing_range_vs_item_jaccard: f32,
+    pub containing_range_vs_signature_jaccard: f32,
+    pub adjacent_vs_item_jaccard: f32,
+    pub adjacent_vs_signature_jaccard: f32,
+    pub containing_range_vs_item_weighted_overlap: f32,
+    pub containing_range_vs_signature_weighted_overlap: f32,
+    pub adjacent_vs_item_weighted_overlap: f32,
+    pub adjacent_vs_signature_weighted_overlap: f32,
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct Scores {
+    pub signature: f32,
+    pub definition: f32,
+}
+
+impl ScoreInputs {
+    fn score(&self) -> Scores {
+        // Score related to how likely this is the correct definition, range 0 to 1
+        let accuracy_score = if self.is_same_file {
+            // TODO: use definition_line_distance_rank
+            (0.5 / self.same_file_definition_count as f32)
+                + (0.5 / self.definition_file_count as f32)
+        } else {
+            1.0 / self.definition_count as f32
+        };
+
+        // Score related to the distance between the reference and cursor, range 0 to 1
+        let distance_score = if self.is_referenced_nearby {
+            1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
+        } else {
+            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
+            0.5
+        };
+
+        // For now instead of linear combination, the scores are just multiplied together.
+        let combined_score = 10.0 * accuracy_score * distance_score;
+
+        Scores {
+            signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
+            // definition score gets boosted both by being multipled by 2 and by there being more
+            // weighted overlap.
+            definition: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
+        }
+    }
+}