Progress preparing new cloud request + using index in excerpt selection

Michael Sloan and Agus created

Co-authored-by: Agus <agus@zed.dev>

Change summary

Cargo.lock                                                    |   1 
crates/cloud_llm_client/src/cloud_llm_client.rs               |   2 
crates/cloud_llm_client/src/predict_edits_v3.rs               | 123 +++
crates/edit_prediction_context/Cargo.toml                     |   1 
crates/edit_prediction_context/src/declaration.rs             |  24 
crates/edit_prediction_context/src/declaration_scoring.rs     |  77 -
crates/edit_prediction_context/src/edit_prediction_context.rs | 168 ++++
crates/edit_prediction_context/src/excerpt.rs                 |  85 -
crates/edit_prediction_context/src/reference.rs               |   4 
crates/edit_prediction_context/src/syntax_index.rs            |  62 +
crates/edit_prediction_tools/src/edit_prediction_tools.rs     |  11 
11 files changed, 408 insertions(+), 150 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5174,6 +5174,7 @@ dependencies = [
  "anyhow",
  "arrayvec",
  "clap",
+ "cloud_llm_client",
  "collections",
  "futures 0.3.31",
  "gpui",

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -0,0 +1,123 @@
+use serde::{Deserialize, Serialize};
+use std::ops::Range;
+
+use crate::PredictEditsGitInfo;
+
+// TODO: snippet ordering within file / relative to excerpt
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Body {
+    pub excerpt: String,
+    /// Within `signatures`
+    pub excerpt_parent: Option<usize>,
+    pub signatures: Vec<Signature>,
+    pub referenced_declarations: Vec<ReferencedDeclaration>,
+    pub events: Vec<Event>,
+    #[serde(default)]
+    pub can_collect_data: bool,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
+    /// Info about the git repository state, only present when can_collect_data is true.
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub git_info: Option<PredictEditsGitInfo>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum Event {}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Signature {
+    pub text: String,
+    pub text_is_truncated: bool,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub parent_index: Option<usize>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ReferencedDeclaration {
+    pub text: String,
+    pub text_is_truncated: bool,
+    /// Range within `text`
+    pub signature_range: Range<usize>,
+    /// Index within `signatures`.
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub parent_index: Option<usize>,
+    pub score_components: ScoreComponents,
+    pub signature_score: f32,
+    pub declaration_score: f32,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ScoreComponents {
+    pub is_same_file: bool,
+    pub is_referenced_nearby: bool,
+    pub is_referenced_in_breadcrumb: bool,
+    pub reference_count: usize,
+    pub same_file_declaration_count: usize,
+    pub declaration_count: usize,
+    pub reference_line_distance: u32,
+    pub declaration_line_distance: u32,
+    pub declaration_line_distance_rank: usize,
+    pub containing_range_vs_item_jaccard: f32,
+    pub containing_range_vs_signature_jaccard: f32,
+    pub adjacent_vs_item_jaccard: f32,
+    pub adjacent_vs_signature_jaccard: f32,
+    pub containing_range_vs_item_weighted_overlap: f32,
+    pub containing_range_vs_signature_weighted_overlap: f32,
+    pub adjacent_vs_item_weighted_overlap: f32,
+    pub adjacent_vs_signature_weighted_overlap: f32,
+}
+
+/*
+#[derive(Debug, Clone)]
+pub struct SerializedJson<T> {
+    raw: Box<RawValue>,
+    _phantom: PhantomData<T>,
+}
+
+impl<T> SerializedJson<T>
+where
+    T: Serialize + for<'de> Deserialize<'de>,
+{
+    pub fn new(value: &T) -> Result<Self, serde_json::Error> {
+        Ok(SerializedJson {
+            raw: serde_json::value::to_raw_value(value)?,
+            _phantom: PhantomData,
+        })
+    }
+
+    pub fn deserialize(&self) -> Result<T, serde_json::Error> {
+        serde_json::from_str(self.raw.get())
+    }
+
+    pub fn as_raw(&self) -> &RawValue {
+        &self.raw
+    }
+
+    pub fn into_raw(self) -> Box<RawValue> {
+        self.raw
+    }
+}
+
+impl<T> Serialize for SerializedJson<T> {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.raw.serialize(serializer)
+    }
+}
+
+impl<'de, T> Deserialize<'de> for SerializedJson<T> {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let raw = Box::<RawValue>::deserialize(deserializer)?;
+        Ok(SerializedJson {
+            raw,
+            _phantom: PhantomData,
+        })
+    }
+}
+*/

crates/edit_prediction_context/Cargo.toml 🔗

@@ -14,6 +14,7 @@ path = "src/edit_prediction_context.rs"
 [dependencies]
 anyhow.workspace = true
 arrayvec.workspace = true
+cloud_llm_client.workspace = true
 collections.workspace = true
 futures.workspace = true
 gpui.workspace = true

crates/edit_prediction_context/src/declaration.rs 🔗

@@ -41,6 +41,20 @@ impl Declaration {
         }
     }
 
+    pub fn parent(&self) -> Option<DeclarationId> {
+        match self {
+            Declaration::File { declaration, .. } => declaration.parent,
+            Declaration::Buffer { declaration, .. } => declaration.parent,
+        }
+    }
+
+    pub fn as_buffer(&self) -> Option<&BufferDeclaration> {
+        match self {
+            Declaration::File { .. } => None,
+            Declaration::Buffer { declaration, .. } => Some(declaration),
+        }
+    }
+
     pub fn project_entry_id(&self) -> ProjectEntryId {
         match self {
             Declaration::File {
@@ -83,6 +97,16 @@ impl Declaration {
             ),
         }
     }
+
+    pub fn signature_range_in_item_text(&self) -> Range<usize> {
+        match self {
+            Declaration::File { declaration, .. } => declaration.signature_range_in_text.clone(),
+            Declaration::Buffer { declaration, .. } => {
+                declaration.signature_range.start - declaration.item_range.start
+                    ..declaration.signature_range.end - declaration.item_range.start
+            }
+        }
+    }
 }
 
 fn expand_range_to_line_boundaries_and_truncate(

crates/edit_prediction_context/src/declaration_scoring.rs 🔗

@@ -1,10 +1,11 @@
+use cloud_llm_client::predict_edits_v3::ScoreComponents;
 use itertools::Itertools as _;
 use language::BufferSnapshot;
 use ordered_float::OrderedFloat;
 use serde::Serialize;
 use std::{collections::HashMap, ops::Range};
 use strum::EnumIter;
-use text::{OffsetRangeExt, Point, ToPoint};
+use text::{Point, ToPoint};
 
 use crate::{
     Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
@@ -23,7 +24,7 @@ const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 pub struct ScoredSnippet {
     pub identifier: Identifier,
     pub declaration: Declaration,
-    pub score_components: ScoreInputs,
+    pub score_components: ScoreComponents,
     pub scores: Scores,
 }
 
@@ -90,8 +91,8 @@ pub fn scored_snippets(
             let declaration_count = declarations.len();
 
             declarations
-                .iter()
-                .filter_map(|declaration| match declaration {
+                .into_iter()
+                .filter_map(|(declaration_id, declaration)| match declaration {
                     Declaration::Buffer {
                         buffer_id,
                         declaration: buffer_declaration,
@@ -100,24 +101,29 @@ pub fn scored_snippets(
                         let is_same_file = buffer_id == &current_buffer.remote_id();
 
                         if is_same_file {
-                            range_intersection(
-                                &buffer_declaration.item_range.to_offset(&current_buffer),
-                                &excerpt.range,
-                            )
-                            .is_none()
-                            .then(|| {
+                            let overlaps_excerpt =
+                                range_intersection(&buffer_declaration.item_range, &excerpt.range)
+                                    .is_some();
+                            if overlaps_excerpt
+                                || excerpt
+                                    .parent_declarations
+                                    .iter()
+                                    .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
+                            {
+                                None
+                            } else {
                                 let declaration_line = buffer_declaration
                                     .item_range
                                     .start
                                     .to_point(current_buffer)
                                     .row;
-                                (
+                                Some((
                                     true,
                                     (cursor_point.row as i32 - declaration_line as i32)
                                         .unsigned_abs(),
                                     declaration,
-                                )
-                            })
+                                ))
+                            }
                         } else {
                             Some((false, u32::MAX, declaration))
                         }
@@ -238,7 +244,7 @@ fn score_snippet(
     let adjacent_vs_signature_weighted_overlap =
         weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
 
-    let score_components = ScoreInputs {
+    let score_components = ScoreComponents {
         is_same_file,
         is_referenced_nearby,
         is_referenced_in_breadcrumb,
@@ -261,51 +267,30 @@ fn score_snippet(
     Some(ScoredSnippet {
         identifier: identifier.clone(),
         declaration: declaration,
-        scores: score_components.score(),
+        scores: Scores::score(&score_components),
         score_components,
     })
 }
 
-#[derive(Clone, Debug, Serialize)]
-pub struct ScoreInputs {
-    pub is_same_file: bool,
-    pub is_referenced_nearby: bool,
-    pub is_referenced_in_breadcrumb: bool,
-    pub reference_count: usize,
-    pub same_file_declaration_count: usize,
-    pub declaration_count: usize,
-    pub reference_line_distance: u32,
-    pub declaration_line_distance: u32,
-    pub declaration_line_distance_rank: usize,
-    pub containing_range_vs_item_jaccard: f32,
-    pub containing_range_vs_signature_jaccard: f32,
-    pub adjacent_vs_item_jaccard: f32,
-    pub adjacent_vs_signature_jaccard: f32,
-    pub containing_range_vs_item_weighted_overlap: f32,
-    pub containing_range_vs_signature_weighted_overlap: f32,
-    pub adjacent_vs_item_weighted_overlap: f32,
-    pub adjacent_vs_signature_weighted_overlap: f32,
-}
-
 #[derive(Clone, Debug, Serialize)]
 pub struct Scores {
     pub signature: f32,
     pub declaration: f32,
 }
 
-impl ScoreInputs {
-    fn score(&self) -> Scores {
+impl Scores {
+    fn score(components: &ScoreComponents) -> Scores {
         // Score related to how likely this is the correct declaration, range 0 to 1
-        let accuracy_score = if self.is_same_file {
+        let accuracy_score = if components.is_same_file {
             // TODO: use declaration_line_distance_rank
-            1.0 / self.same_file_declaration_count as f32
+            1.0 / components.same_file_declaration_count as f32
         } else {
-            1.0 / self.declaration_count as f32
+            1.0 / components.declaration_count as f32
         };
 
         // Score related to the distance between the reference and cursor, range 0 to 1
-        let distance_score = if self.is_referenced_nearby {
-            1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
+        let distance_score = if components.is_referenced_nearby {
+            1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
         } else {
             // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
             0.5
@@ -315,10 +300,12 @@ impl ScoreInputs {
         let combined_score = 10.0 * accuracy_score * distance_score;
 
         Scores {
-            signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
+            signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
             // declaration score gets boosted both by being multiplied by 2 and by there being more
             // weighted overlap.
-            declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
+            declaration: 2.0
+                * combined_score
+                * components.containing_range_vs_item_weighted_overlap,
         }
     }
 }

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -6,8 +6,8 @@ mod reference;
 mod syntax_index;
 mod text_similarity;
 
-use std::time::Instant;
-
+use cloud_llm_client::predict_edits_v3::{self, Signature};
+use collections::HashMap;
 pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
 pub use declaration_scoring::SnippetStyle;
 pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
@@ -18,14 +18,17 @@ pub use reference::references_in_excerpt;
 pub use syntax_index::SyntaxIndex;
 use text::{Point, ToOffset as _};
 
-use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
+use crate::{
+    declaration::DeclarationId,
+    declaration_scoring::{ScoredSnippet, scored_snippets},
+    syntax_index::SyntaxIndexState,
+};
 
 #[derive(Debug)]
 pub struct EditPredictionContext {
     pub excerpt: EditPredictionExcerpt,
     pub excerpt_text: EditPredictionExcerptText,
     pub snippets: Vec<ScoredSnippet>,
-    pub retrieval_duration: std::time::Duration,
 }
 
 impl EditPredictionContext {
@@ -36,34 +39,135 @@ impl EditPredictionContext {
         syntax_index: Entity<SyntaxIndex>,
         cx: &mut App,
     ) -> Task<Option<Self>> {
-        let start = Instant::now();
         let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
         cx.background_spawn(async move {
             let index_state = index_state.lock().await;
+            Self::gather_context(cursor_point, buffer, excerpt_options, &index_state)
+        })
+    }
 
-            let excerpt =
-                EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)?;
-            let excerpt_text = excerpt.text(&buffer);
-            let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
-            let cursor_offset = cursor_point.to_offset(&buffer);
-
-            let snippets = scored_snippets(
-                &index_state,
-                &excerpt,
-                &excerpt_text,
-                references,
-                cursor_offset,
-                &buffer,
-            );
-
-            Some(Self {
-                excerpt,
-                excerpt_text,
-                snippets,
-                retrieval_duration: start.elapsed(),
-            })
+    fn gather_context(
+        cursor_point: Point,
+        buffer: BufferSnapshot,
+        excerpt_options: EditPredictionExcerptOptions,
+        index_state: &SyntaxIndexState,
+    ) -> Option<Self> {
+        let excerpt = EditPredictionExcerpt::select_from_buffer(
+            cursor_point,
+            &buffer,
+            &excerpt_options,
+            Some(index_state),
+        )?;
+        let excerpt_text = excerpt.text(&buffer);
+        let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
+        let cursor_offset = cursor_point.to_offset(&buffer);
+
+        let snippets = scored_snippets(
+            &index_state,
+            &excerpt,
+            &excerpt_text,
+            references,
+            cursor_offset,
+            &buffer,
+        );
+
+        Some(Self {
+            excerpt,
+            excerpt_text,
+            snippets,
         })
     }
+
+    pub fn cloud_request(
+        cursor_point: Point,
+        buffer: BufferSnapshot,
+        excerpt_options: EditPredictionExcerptOptions,
+        syntax_index: Entity<SyntaxIndex>,
+        cx: &mut App,
+    ) -> Task<Option<predict_edits_v3::Body>> {
+        let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
+        cx.background_spawn(async move {
+            let index_state = index_state.lock().await;
+            Self::gather_context(cursor_point, buffer, excerpt_options, &index_state)
+                .map(|context| context.into_cloud_request(&index_state))
+        })
+    }
+
+    pub fn into_cloud_request(self, index: &SyntaxIndexState) -> predict_edits_v3::Body {
+        let mut signatures = Vec::new();
+        let mut declaration_to_signature_index = HashMap::default();
+        let mut referenced_declarations = Vec::new();
+        let excerpt_parent = self
+            .excerpt
+            .parent_declarations
+            .last()
+            .and_then(|(parent, _)| {
+                add_signature(
+                    *parent,
+                    &mut declaration_to_signature_index,
+                    &mut signatures,
+                    index,
+                )
+            });
+        for snippet in self.snippets {
+            let parent_index = snippet.declaration.parent().and_then(|parent| {
+                add_signature(
+                    parent,
+                    &mut declaration_to_signature_index,
+                    &mut signatures,
+                    index,
+                )
+            });
+            let (text, text_is_truncated) = snippet.declaration.item_text();
+            referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
+                text: text.into(),
+                text_is_truncated,
+                signature_range: snippet.declaration.signature_range_in_item_text(),
+                parent_index,
+                score_components: snippet.score_components,
+                signature_score: snippet.scores.signature,
+                declaration_score: snippet.scores.declaration,
+            });
+        }
+        predict_edits_v3::Body {
+            excerpt: self.excerpt_text.body,
+            referenced_declarations,
+            signatures,
+            excerpt_parent,
+            // todo!
+            events: vec![],
+            can_collect_data: false,
+            diagnostic_groups: None,
+            git_info: None,
+        }
+    }
+}
+
+fn add_signature(
+    declaration_id: DeclarationId,
+    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
+    signatures: &mut Vec<Signature>,
+    index: &SyntaxIndexState,
+) -> Option<usize> {
+    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
+        return Some(*signature_index);
+    }
+    let Some(parent_declaration) = index.declaration(declaration_id) else {
+        log::error!("bug: missing parent declaration");
+        return None;
+    };
+    let parent_index = parent_declaration.parent().and_then(|parent| {
+        add_signature(parent, declaration_to_signature_index, signatures, index)
+    });
+    let (text, text_is_truncated) = parent_declaration.signature_text();
+    let signature_index = signatures.len();
+    signatures.push(Signature {
+        text: text.into(),
+        text_is_truncated,
+        parent_index,
+    });
+    declaration_to_signature_index.insert(declaration_id, signature_index);
+    Some(signature_index)
 }
 
 #[cfg(test)]
@@ -105,10 +209,9 @@ mod tests {
                     cursor_point,
                     buffer_snapshot,
                     EditPredictionExcerptOptions {
-                        max_bytes: 40,
+                        max_bytes: 60,
                         min_bytes: 10,
                         target_before_cursor_over_total_bytes: 0.5,
-                        include_parent_signatures: false,
                     },
                     index,
                     cx,
@@ -117,8 +220,13 @@ mod tests {
             .await
             .unwrap();
 
-        assert_eq!(context.snippets.len(), 1);
-        assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
+        let mut snippet_identifiers = context
+            .snippets
+            .iter()
+            .map(|snippet| snippet.identifier.name.as_ref())
+            .collect::<Vec<_>>();
+        snippet_identifiers.sort();
+        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
         drop(buffer);
     }
 

crates/edit_prediction_context/src/excerpt.rs 🔗

@@ -1,9 +1,11 @@
 use language::BufferSnapshot;
 use std::ops::Range;
-use text::{OffsetRangeExt as _, Point, ToOffset as _, ToPoint as _};
+use text::{Point, ToOffset as _, ToPoint as _};
 use tree_sitter::{Node, TreeCursor};
 use util::RangeExt;
 
+use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState};
+
 // TODO:
 //
 // - Test parent signatures
@@ -27,14 +29,12 @@ pub struct EditPredictionExcerptOptions {
     pub min_bytes: usize,
     /// Target ratio of bytes before the cursor divided by total bytes in the window.
     pub target_before_cursor_over_total_bytes: f32,
-    /// Whether to include parent signatures
-    pub include_parent_signatures: bool,
 }
 
 #[derive(Debug, Clone)]
 pub struct EditPredictionExcerpt {
     pub range: Range<usize>,
-    pub parent_signature_ranges: Vec<Range<usize>>,
+    pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
     pub size: usize,
 }
 
@@ -50,9 +50,9 @@ impl EditPredictionExcerpt {
             .text_for_range(self.range.clone())
             .collect::<String>();
         let parent_signatures = self
-            .parent_signature_ranges
+            .parent_declarations
             .iter()
-            .map(|range| buffer.text_for_range(range.clone()).collect::<String>())
+            .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
             .collect();
         EditPredictionExcerptText {
             body,
@@ -62,8 +62,9 @@ impl EditPredictionExcerpt {
 
     /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
     /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
-    /// cursor. When `include_parent_signatures` is true, the excerpt also includes the signatures
-    /// of parent outline items.
+    /// cursor.
+    ///
+    /// When `index` is provided, the excerpt will include the signatures of parent outline items.
     ///
     /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
     /// expansion.
@@ -73,6 +74,7 @@ impl EditPredictionExcerpt {
         query_point: Point,
         buffer: &BufferSnapshot,
         options: &EditPredictionExcerptOptions,
+        syntax_index: Option<&SyntaxIndexState>,
     ) -> Option<Self> {
         if buffer.len() <= options.max_bytes {
             log::debug!(
@@ -90,17 +92,9 @@ impl EditPredictionExcerpt {
             return None;
         }
 
-        // TODO: Don't compute text / annotation_range / skip converting to and from anchors.
-        let outline_items = if options.include_parent_signatures {
-            buffer
-                .outline_items_containing(query_range.clone(), false, None)
-                .into_iter()
-                .flat_map(|item| {
-                    Some(ExcerptOutlineItem {
-                        item_range: item.range.to_offset(&buffer),
-                        signature_range: item.signature_range?.to_offset(&buffer),
-                    })
-                })
+        let parent_declarations = if let Some(syntax_index) = syntax_index {
+            syntax_index
+                .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
                 .collect()
         } else {
             Vec::new()
@@ -109,7 +103,7 @@ impl EditPredictionExcerpt {
         let excerpt_selector = ExcerptSelector {
             query_offset,
             query_range,
-            outline_items: &outline_items,
+            parent_declarations: &parent_declarations,
             buffer,
             options,
         };
@@ -132,15 +126,15 @@ impl EditPredictionExcerpt {
         excerpt_selector.select_lines()
     }
 
-    fn new(range: Range<usize>, parent_signature_ranges: Vec<Range<usize>>) -> Self {
+    fn new(range: Range<usize>, parent_declarations: Vec<(DeclarationId, Range<usize>)>) -> Self {
         let size = range.len()
-            + parent_signature_ranges
+            + parent_declarations
                 .iter()
-                .map(|r| r.len())
+                .map(|(_, range)| range.len())
                 .sum::<usize>();
         Self {
             range,
-            parent_signature_ranges,
+            parent_declarations,
             size,
         }
     }
@@ -150,20 +144,14 @@ impl EditPredictionExcerpt {
             // this is an issue because parent_signature_ranges may be incorrect
             log::error!("bug: with_expanded_range called with disjoint range");
         }
-        let mut parent_signature_ranges = Vec::with_capacity(self.parent_signature_ranges.len());
-        let mut size = new_range.len();
-        for range in &self.parent_signature_ranges {
-            if range.contains_inclusive(&new_range) {
+        let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
+        for (declaration_id, range) in &self.parent_declarations {
+            if !range.contains_inclusive(&new_range) {
                 break;
             }
-            parent_signature_ranges.push(range.clone());
-            size += range.len();
-        }
-        Self {
-            range: new_range,
-            parent_signature_ranges,
-            size,
+            parent_declarations.push((*declaration_id, range.clone()));
         }
+        Self::new(new_range, parent_declarations)
     }
 
     fn parent_signatures_size(&self) -> usize {
@@ -174,16 +162,11 @@ impl EditPredictionExcerpt {
 struct ExcerptSelector<'a> {
     query_offset: usize,
     query_range: Range<usize>,
-    outline_items: &'a [ExcerptOutlineItem],
+    parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
     buffer: &'a BufferSnapshot,
     options: &'a EditPredictionExcerptOptions,
 }
 
-struct ExcerptOutlineItem {
-    item_range: Range<usize>,
-    signature_range: Range<usize>,
-}
-
 impl<'a> ExcerptSelector<'a> {
     /// Finds the largest node that is smaller than the window size and contains `query_range`.
     fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
@@ -396,13 +379,13 @@ impl<'a> ExcerptSelector<'a> {
     }
 
     fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
-        let parent_signature_ranges = self
-            .outline_items
+        let parent_declarations = self
+            .parent_declarations
             .iter()
-            .filter(|item| item.item_range.contains_inclusive(&range))
-            .map(|item| item.signature_range.clone())
+            .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
+            .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
             .collect();
-        EditPredictionExcerpt::new(range, parent_signature_ranges)
+        EditPredictionExcerpt::new(range, parent_declarations)
     }
 
     /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
@@ -493,8 +476,9 @@ mod tests {
         let buffer = create_buffer(&text, cx);
         let cursor_point = cursor.to_point(&buffer);
 
-        let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
-            .expect("Should select an excerpt");
+        let excerpt =
+            EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
+                .expect("Should select an excerpt");
         pretty_assertions::assert_eq!(
             generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
             generate_marked_text(&text, &[expected_excerpt], false)
@@ -517,7 +501,6 @@ fn main() {
             max_bytes: 20,
             min_bytes: 10,
             target_before_cursor_over_total_bytes: 0.5,
-            include_parent_signatures: false,
         };
 
         check_example(options, text, cx);
@@ -541,7 +524,6 @@ fn bar() {}"#;
             max_bytes: 65,
             min_bytes: 10,
             target_before_cursor_over_total_bytes: 0.5,
-            include_parent_signatures: false,
         };
 
         check_example(options, text, cx);
@@ -561,7 +543,6 @@ fn main() {
             max_bytes: 50,
             min_bytes: 10,
             target_before_cursor_over_total_bytes: 0.5,
-            include_parent_signatures: false,
         };
 
         check_example(options, text, cx);
@@ -583,7 +564,6 @@ fn main() {
             max_bytes: 60,
             min_bytes: 45,
             target_before_cursor_over_total_bytes: 0.5,
-            include_parent_signatures: false,
         };
 
         check_example(options, text, cx);
@@ -608,7 +588,6 @@ fn main() {
             max_bytes: 120,
             min_bytes: 10,
             target_before_cursor_over_total_bytes: 0.6,
-            include_parent_signatures: false,
         };
 
         check_example(options, text, cx);

crates/edit_prediction_context/src/reference.rs 🔗

@@ -33,8 +33,8 @@ pub fn references_in_excerpt(
         snapshot,
     );
 
-    for (range, text) in excerpt
-        .parent_signature_ranges
+    for ((_, range), text) in excerpt
+        .parent_declarations
         .iter()
         .zip(excerpt_text.parent_signatures.iter())
     {

crates/edit_prediction_context/src/syntax_index.rs 🔗

@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
 use collections::{HashMap, HashSet};
 use futures::lock::Mutex;
 use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity};
@@ -8,8 +6,11 @@ use project::buffer_store::{BufferStore, BufferStoreEvent};
 use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
 use project::{PathChange, Project, ProjectEntryId, ProjectPath};
 use slotmap::SlotMap;
+use std::iter;
+use std::ops::Range;
+use std::sync::Arc;
 use text::BufferId;
-use util::{debug_panic, some_or_debug_panic};
+use util::{RangeExt as _, debug_panic, some_or_debug_panic};
 
 use crate::declaration::{
     BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
@@ -432,7 +433,7 @@ impl SyntaxIndexState {
     pub fn declarations_for_identifier<const N: usize>(
         &self,
         identifier: &Identifier,
-    ) -> Vec<Declaration> {
+    ) -> Vec<(DeclarationId, &Declaration)> {
         // make sure to not have a large stack allocation
         assert!(N < 32);
 
@@ -454,7 +455,7 @@ impl SyntaxIndexState {
                     project_entry_id, ..
                 } => {
                     included_buffer_entry_ids.push(*project_entry_id);
-                    result.push(declaration.clone());
+                    result.push((*declaration_id, declaration));
                     if result.len() == N {
                         return Vec::new();
                     }
@@ -463,19 +464,19 @@ impl SyntaxIndexState {
                     project_entry_id, ..
                 } => {
                     if !included_buffer_entry_ids.contains(&project_entry_id) {
-                        file_declarations.push(declaration.clone());
+                        file_declarations.push((*declaration_id, declaration));
                     }
                 }
             }
         }
 
-        for declaration in file_declarations {
+        for (declaration_id, declaration) in file_declarations {
             match declaration {
                 Declaration::File {
                     project_entry_id, ..
                 } => {
                     if !included_buffer_entry_ids.contains(&project_entry_id) {
-                        result.push(declaration);
+                        result.push((declaration_id, declaration));
 
                         if result.len() == N {
                             return Vec::new();
@@ -489,6 +490,35 @@ impl SyntaxIndexState {
         result
     }
 
+    pub fn buffer_declarations_containing_range(
+        &self,
+        buffer_id: BufferId,
+        range: Range<usize>,
+    ) -> impl Iterator<Item = (DeclarationId, &BufferDeclaration)> {
+        let Some(buffer_state) = self.buffers.get(&buffer_id) else {
+            return itertools::Either::Left(iter::empty());
+        };
+
+        let iter = buffer_state
+            .declarations
+            .iter()
+            .filter_map(move |declaration_id| {
+                let Some(declaration) = self
+                    .declarations
+                    .get(*declaration_id)
+                    .and_then(|d| d.as_buffer())
+                else {
+                    log::error!("bug: missing buffer outline declaration");
+                    return None;
+                };
+                if declaration.item_range.contains_inclusive(&range) {
+                    return Some((*declaration_id, declaration));
+                }
+                return None;
+            });
+        itertools::Either::Right(iter)
+    }
+
     pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
         match declaration {
             Declaration::File {
@@ -553,11 +583,11 @@ mod tests {
             let decls = index_state.declarations_for_identifier::<8>(&main);
             assert_eq!(decls.len(), 2);
 
-            let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
+            let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
             assert_eq!(decl.identifier, main.clone());
             assert_eq!(decl.item_range_in_file, 32..280);
 
-            let decl = expect_file_decl("a.rs", &decls[1], &project, cx);
+            let decl = expect_file_decl("a.rs", &decls[1].1, &project, cx);
             assert_eq!(decl.identifier, main);
             assert_eq!(decl.item_range_in_file, 0..98);
         });
@@ -577,7 +607,7 @@ mod tests {
             let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
             assert_eq!(decls.len(), 1);
 
-            let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
+            let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
             assert_eq!(decl.identifier, test_process_data);
 
             let parent_id = decl.parent.unwrap();
@@ -618,7 +648,7 @@ mod tests {
             let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
             assert_eq!(decls.len(), 1);
 
-            let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
+            let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
             assert_eq!(decl.identifier, test_process_data);
 
             let parent_id = decl.parent.unwrap();
@@ -676,11 +706,11 @@ mod tests {
             cx.update(|cx| {
                 let decls = index_state.declarations_for_identifier::<8>(&main);
                 assert_eq!(decls.len(), 2);
-                let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
+                let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
                 assert_eq!(decl.identifier, main);
                 assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280);
 
-                expect_file_decl("a.rs", &decls[1], &project, cx);
+                expect_file_decl("a.rs", &decls[1].1, &project, cx);
             });
         }
 
@@ -695,8 +725,8 @@ mod tests {
         cx.update(|cx| {
             let decls = index_state.declarations_for_identifier::<8>(&main);
             assert_eq!(decls.len(), 2);
-            expect_file_decl("c.rs", &decls[0], &project, cx);
-            expect_file_decl("a.rs", &decls[1], &project, cx);
+            expect_file_decl("c.rs", &decls[0].1, &project, cx);
+            expect_file_decl("a.rs", &decls[1].1, &project, cx);
         });
     }
 

crates/edit_prediction_tools/src/edit_prediction_tools.rs 🔗

@@ -4,7 +4,7 @@ use std::{
     path::{Path, PathBuf},
     str::FromStr,
     sync::Arc,
-    time::Duration,
+    time::{Duration, Instant},
 };
 
 use collections::HashMap;
@@ -195,6 +195,8 @@ impl EditPredictionTools {
                     .timer(Duration::from_millis(50))
                     .await;
 
+                let mut start_time = None;
+
                 let Ok(task) = this.update(cx, |this, cx| {
                     fn number_input_value<T: FromStr + Default>(
                         input: &Entity<SingleLineInput>,
@@ -216,10 +218,10 @@ impl EditPredictionTools {
                             &this.cursor_context_ratio_input,
                             cx,
                         ),
-                        // TODO Display and add to options
-                        include_parent_signatures: false,
                     };
 
+                    start_time = Some(Instant::now());
+
                     EditPredictionContext::gather(
                         cursor_position,
                         current_buffer_snapshot,
@@ -243,6 +245,7 @@ impl EditPredictionTools {
                     .ok();
                     return;
                 };
+                let retrieval_duration = start_time.unwrap().elapsed();
 
                 let mut languages = HashMap::default();
                 for snippet in context.snippets.iter() {
@@ -320,7 +323,7 @@ impl EditPredictionTools {
 
                     this.last_context = Some(ContextState {
                         context_editor,
-                        retrieval_duration: context.retrieval_duration,
+                        retrieval_duration,
                     });
                     cx.notify();
                 })