Cargo.lock 🔗
@@ -5174,6 +5174,7 @@ dependencies = [
"anyhow",
"arrayvec",
"clap",
+ "cloud_llm_client",
"collections",
"futures 0.3.31",
"gpui",
Michael Sloan and Agus created
Co-authored-by: Agus <agus@zed.dev>
Cargo.lock | 1
crates/cloud_llm_client/src/cloud_llm_client.rs | 2
crates/cloud_llm_client/src/predict_edits_v3.rs | 123 +++
crates/edit_prediction_context/Cargo.toml | 1
crates/edit_prediction_context/src/declaration.rs | 24
crates/edit_prediction_context/src/declaration_scoring.rs | 77 -
crates/edit_prediction_context/src/edit_prediction_context.rs | 168 ++++
crates/edit_prediction_context/src/excerpt.rs | 85 -
crates/edit_prediction_context/src/reference.rs | 4
crates/edit_prediction_context/src/syntax_index.rs | 62 +
crates/edit_prediction_tools/src/edit_prediction_tools.rs | 11
11 files changed, 408 insertions(+), 150 deletions(-)
@@ -5174,6 +5174,7 @@ dependencies = [
"anyhow",
"arrayvec",
"clap",
+ "cloud_llm_client",
"collections",
"futures 0.3.31",
"gpui",
@@ -1,3 +1,5 @@
+pub mod predict_edits_v3;
+
use std::str::FromStr;
use std::sync::Arc;
@@ -0,0 +1,123 @@
+use serde::{Deserialize, Serialize};
+use std::ops::Range;
+
+use crate::PredictEditsGitInfo;
+
+// TODO: snippet ordering within file / relative to excerpt
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Body {
+ pub excerpt: String,
+ /// Within `signatures`
+ pub excerpt_parent: Option<usize>,
+ pub signatures: Vec<Signature>,
+ pub referenced_declarations: Vec<ReferencedDeclaration>,
+ pub events: Vec<Event>,
+ #[serde(default)]
+ pub can_collect_data: bool,
+ #[serde(skip_serializing_if = "Option::is_none", default)]
+ pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
+ /// Info about the git repository state, only present when can_collect_data is true.
+ #[serde(skip_serializing_if = "Option::is_none", default)]
+ pub git_info: Option<PredictEditsGitInfo>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum Event {}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Signature {
+ pub text: String,
+ pub text_is_truncated: bool,
+ #[serde(skip_serializing_if = "Option::is_none", default)]
+ pub parent_index: Option<usize>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ReferencedDeclaration {
+ pub text: String,
+ pub text_is_truncated: bool,
+ /// Range within `text`
+ pub signature_range: Range<usize>,
+ /// Index within `signatures`.
+ #[serde(skip_serializing_if = "Option::is_none", default)]
+ pub parent_index: Option<usize>,
+ pub score_components: ScoreComponents,
+ pub signature_score: f32,
+ pub declaration_score: f32,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ScoreComponents {
+ pub is_same_file: bool,
+ pub is_referenced_nearby: bool,
+ pub is_referenced_in_breadcrumb: bool,
+ pub reference_count: usize,
+ pub same_file_declaration_count: usize,
+ pub declaration_count: usize,
+ pub reference_line_distance: u32,
+ pub declaration_line_distance: u32,
+ pub declaration_line_distance_rank: usize,
+ pub containing_range_vs_item_jaccard: f32,
+ pub containing_range_vs_signature_jaccard: f32,
+ pub adjacent_vs_item_jaccard: f32,
+ pub adjacent_vs_signature_jaccard: f32,
+ pub containing_range_vs_item_weighted_overlap: f32,
+ pub containing_range_vs_signature_weighted_overlap: f32,
+ pub adjacent_vs_item_weighted_overlap: f32,
+ pub adjacent_vs_signature_weighted_overlap: f32,
+}
+
+/*
+#[derive(Debug, Clone)]
+pub struct SerializedJson<T> {
+ raw: Box<RawValue>,
+ _phantom: PhantomData<T>,
+}
+
+impl<T> SerializedJson<T>
+where
+ T: Serialize + for<'de> Deserialize<'de>,
+{
+ pub fn new(value: &T) -> Result<Self, serde_json::Error> {
+ Ok(SerializedJson {
+ raw: serde_json::value::to_raw_value(value)?,
+ _phantom: PhantomData,
+ })
+ }
+
+ pub fn deserialize(&self) -> Result<T, serde_json::Error> {
+ serde_json::from_str(self.raw.get())
+ }
+
+ pub fn as_raw(&self) -> &RawValue {
+ &self.raw
+ }
+
+ pub fn into_raw(self) -> Box<RawValue> {
+ self.raw
+ }
+}
+
+impl<T> Serialize for SerializedJson<T> {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ self.raw.serialize(serializer)
+ }
+}
+
+impl<'de, T> Deserialize<'de> for SerializedJson<T> {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ let raw = Box::<RawValue>::deserialize(deserializer)?;
+ Ok(SerializedJson {
+ raw,
+ _phantom: PhantomData,
+ })
+ }
+}
+*/
@@ -14,6 +14,7 @@ path = "src/edit_prediction_context.rs"
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
+cloud_llm_client.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -41,6 +41,20 @@ impl Declaration {
}
}
+ pub fn parent(&self) -> Option<DeclarationId> {
+ match self {
+ Declaration::File { declaration, .. } => declaration.parent,
+ Declaration::Buffer { declaration, .. } => declaration.parent,
+ }
+ }
+
+ pub fn as_buffer(&self) -> Option<&BufferDeclaration> {
+ match self {
+ Declaration::File { .. } => None,
+ Declaration::Buffer { declaration, .. } => Some(declaration),
+ }
+ }
+
pub fn project_entry_id(&self) -> ProjectEntryId {
match self {
Declaration::File {
@@ -83,6 +97,16 @@ impl Declaration {
),
}
}
+
+ pub fn signature_range_in_item_text(&self) -> Range<usize> {
+ match self {
+ Declaration::File { declaration, .. } => declaration.signature_range_in_text.clone(),
+ Declaration::Buffer { declaration, .. } => {
+ declaration.signature_range.start - declaration.item_range.start
+ ..declaration.signature_range.end - declaration.item_range.start
+ }
+ }
+ }
}
fn expand_range_to_line_boundaries_and_truncate(
@@ -1,10 +1,11 @@
+use cloud_llm_client::predict_edits_v3::ScoreComponents;
use itertools::Itertools as _;
use language::BufferSnapshot;
use ordered_float::OrderedFloat;
use serde::Serialize;
use std::{collections::HashMap, ops::Range};
use strum::EnumIter;
-use text::{OffsetRangeExt, Point, ToPoint};
+use text::{Point, ToPoint};
use crate::{
Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
@@ -23,7 +24,7 @@ const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
pub struct ScoredSnippet {
pub identifier: Identifier,
pub declaration: Declaration,
- pub score_components: ScoreInputs,
+ pub score_components: ScoreComponents,
pub scores: Scores,
}
@@ -90,8 +91,8 @@ pub fn scored_snippets(
let declaration_count = declarations.len();
declarations
- .iter()
- .filter_map(|declaration| match declaration {
+ .into_iter()
+ .filter_map(|(declaration_id, declaration)| match declaration {
Declaration::Buffer {
buffer_id,
declaration: buffer_declaration,
@@ -100,24 +101,29 @@ pub fn scored_snippets(
let is_same_file = buffer_id == ¤t_buffer.remote_id();
if is_same_file {
- range_intersection(
- &buffer_declaration.item_range.to_offset(¤t_buffer),
- &excerpt.range,
- )
- .is_none()
- .then(|| {
+ let overlaps_excerpt =
+ range_intersection(&buffer_declaration.item_range, &excerpt.range)
+ .is_some();
+ if overlaps_excerpt
+ || excerpt
+ .parent_declarations
+ .iter()
+ .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
+ {
+ None
+ } else {
let declaration_line = buffer_declaration
.item_range
.start
.to_point(current_buffer)
.row;
- (
+ Some((
true,
(cursor_point.row as i32 - declaration_line as i32)
.unsigned_abs(),
declaration,
- )
- })
+ ))
+ }
} else {
Some((false, u32::MAX, declaration))
}
@@ -238,7 +244,7 @@ fn score_snippet(
let adjacent_vs_signature_weighted_overlap =
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
- let score_components = ScoreInputs {
+ let score_components = ScoreComponents {
is_same_file,
is_referenced_nearby,
is_referenced_in_breadcrumb,
@@ -261,51 +267,30 @@ fn score_snippet(
Some(ScoredSnippet {
identifier: identifier.clone(),
declaration: declaration,
- scores: score_components.score(),
+ scores: Scores::score(&score_components),
score_components,
})
}
-#[derive(Clone, Debug, Serialize)]
-pub struct ScoreInputs {
- pub is_same_file: bool,
- pub is_referenced_nearby: bool,
- pub is_referenced_in_breadcrumb: bool,
- pub reference_count: usize,
- pub same_file_declaration_count: usize,
- pub declaration_count: usize,
- pub reference_line_distance: u32,
- pub declaration_line_distance: u32,
- pub declaration_line_distance_rank: usize,
- pub containing_range_vs_item_jaccard: f32,
- pub containing_range_vs_signature_jaccard: f32,
- pub adjacent_vs_item_jaccard: f32,
- pub adjacent_vs_signature_jaccard: f32,
- pub containing_range_vs_item_weighted_overlap: f32,
- pub containing_range_vs_signature_weighted_overlap: f32,
- pub adjacent_vs_item_weighted_overlap: f32,
- pub adjacent_vs_signature_weighted_overlap: f32,
-}
-
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
pub signature: f32,
pub declaration: f32,
}
-impl ScoreInputs {
- fn score(&self) -> Scores {
+impl Scores {
+ fn score(components: &ScoreComponents) -> Scores {
// Score related to how likely this is the correct declaration, range 0 to 1
- let accuracy_score = if self.is_same_file {
+ let accuracy_score = if components.is_same_file {
// TODO: use declaration_line_distance_rank
- 1.0 / self.same_file_declaration_count as f32
+ 1.0 / components.same_file_declaration_count as f32
} else {
- 1.0 / self.declaration_count as f32
+ 1.0 / components.declaration_count as f32
};
// Score related to the distance between the reference and cursor, range 0 to 1
- let distance_score = if self.is_referenced_nearby {
- 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
+ let distance_score = if components.is_referenced_nearby {
+ 1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
} else {
// same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
0.5
@@ -315,10 +300,12 @@ impl ScoreInputs {
let combined_score = 10.0 * accuracy_score * distance_score;
Scores {
- signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
+ signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
// declaration score gets boosted both by being multiplied by 2 and by there being more
// weighted overlap.
- declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
+ declaration: 2.0
+ * combined_score
+ * components.containing_range_vs_item_weighted_overlap,
}
}
}
@@ -6,8 +6,8 @@ mod reference;
mod syntax_index;
mod text_similarity;
-use std::time::Instant;
-
+use cloud_llm_client::predict_edits_v3::{self, Signature};
+use collections::HashMap;
pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
pub use declaration_scoring::SnippetStyle;
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
@@ -18,14 +18,17 @@ pub use reference::references_in_excerpt;
pub use syntax_index::SyntaxIndex;
use text::{Point, ToOffset as _};
-use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
+use crate::{
+ declaration::DeclarationId,
+ declaration_scoring::{ScoredSnippet, scored_snippets},
+ syntax_index::SyntaxIndexState,
+};
#[derive(Debug)]
pub struct EditPredictionContext {
pub excerpt: EditPredictionExcerpt,
pub excerpt_text: EditPredictionExcerptText,
pub snippets: Vec<ScoredSnippet>,
- pub retrieval_duration: std::time::Duration,
}
impl EditPredictionContext {
@@ -36,34 +39,135 @@ impl EditPredictionContext {
syntax_index: Entity<SyntaxIndex>,
cx: &mut App,
) -> Task<Option<Self>> {
- let start = Instant::now();
let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
cx.background_spawn(async move {
let index_state = index_state.lock().await;
+ Self::gather_context(cursor_point, buffer, excerpt_options, &index_state)
+ })
+ }
- let excerpt =
- EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)?;
- let excerpt_text = excerpt.text(&buffer);
- let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
- let cursor_offset = cursor_point.to_offset(&buffer);
-
- let snippets = scored_snippets(
- &index_state,
- &excerpt,
- &excerpt_text,
- references,
- cursor_offset,
- &buffer,
- );
-
- Some(Self {
- excerpt,
- excerpt_text,
- snippets,
- retrieval_duration: start.elapsed(),
- })
+ fn gather_context(
+ cursor_point: Point,
+ buffer: BufferSnapshot,
+ excerpt_options: EditPredictionExcerptOptions,
+ index_state: &SyntaxIndexState,
+ ) -> Option<Self> {
+ let excerpt = EditPredictionExcerpt::select_from_buffer(
+ cursor_point,
+ &buffer,
+ &excerpt_options,
+ Some(index_state),
+ )?;
+ let excerpt_text = excerpt.text(&buffer);
+ let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
+ let cursor_offset = cursor_point.to_offset(&buffer);
+
+ let snippets = scored_snippets(
+ &index_state,
+ &excerpt,
+ &excerpt_text,
+ references,
+ cursor_offset,
+ &buffer,
+ );
+
+ Some(Self {
+ excerpt,
+ excerpt_text,
+ snippets,
})
}
+
+ pub fn cloud_request(
+ cursor_point: Point,
+ buffer: BufferSnapshot,
+ excerpt_options: EditPredictionExcerptOptions,
+ syntax_index: Entity<SyntaxIndex>,
+ cx: &mut App,
+ ) -> Task<Option<predict_edits_v3::Body>> {
+ let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
+ cx.background_spawn(async move {
+ let index_state = index_state.lock().await;
+ Self::gather_context(cursor_point, buffer, excerpt_options, &index_state)
+ .map(|context| context.into_cloud_request(&index_state))
+ })
+ }
+
+ pub fn into_cloud_request(self, index: &SyntaxIndexState) -> predict_edits_v3::Body {
+ let mut signatures = Vec::new();
+ let mut declaration_to_signature_index = HashMap::default();
+ let mut referenced_declarations = Vec::new();
+ let excerpt_parent = self
+ .excerpt
+ .parent_declarations
+ .last()
+ .and_then(|(parent, _)| {
+ add_signature(
+ *parent,
+ &mut declaration_to_signature_index,
+ &mut signatures,
+ index,
+ )
+ });
+ for snippet in self.snippets {
+ let parent_index = snippet.declaration.parent().and_then(|parent| {
+ add_signature(
+ parent,
+ &mut declaration_to_signature_index,
+ &mut signatures,
+ index,
+ )
+ });
+ let (text, text_is_truncated) = snippet.declaration.item_text();
+ referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
+ text: text.into(),
+ text_is_truncated,
+ signature_range: snippet.declaration.signature_range_in_item_text(),
+ parent_index,
+ score_components: snippet.score_components,
+ signature_score: snippet.scores.signature,
+ declaration_score: snippet.scores.declaration,
+ });
+ }
+ predict_edits_v3::Body {
+ excerpt: self.excerpt_text.body,
+ referenced_declarations,
+ signatures,
+ excerpt_parent,
+ // todo!
+ events: vec![],
+ can_collect_data: false,
+ diagnostic_groups: None,
+ git_info: None,
+ }
+ }
+}
+
+fn add_signature(
+ declaration_id: DeclarationId,
+ declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
+ signatures: &mut Vec<Signature>,
+ index: &SyntaxIndexState,
+) -> Option<usize> {
+ if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
+ return Some(*signature_index);
+ }
+ let Some(parent_declaration) = index.declaration(declaration_id) else {
+ log::error!("bug: missing parent declaration");
+ return None;
+ };
+ let parent_index = parent_declaration.parent().and_then(|parent| {
+ add_signature(parent, declaration_to_signature_index, signatures, index)
+ });
+ let (text, text_is_truncated) = parent_declaration.signature_text();
+ let signature_index = signatures.len();
+ signatures.push(Signature {
+ text: text.into(),
+ text_is_truncated,
+ parent_index,
+ });
+ declaration_to_signature_index.insert(declaration_id, signature_index);
+ Some(signature_index)
}
#[cfg(test)]
@@ -105,10 +209,9 @@ mod tests {
cursor_point,
buffer_snapshot,
EditPredictionExcerptOptions {
- max_bytes: 40,
+ max_bytes: 60,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
- include_parent_signatures: false,
},
index,
cx,
@@ -117,8 +220,13 @@ mod tests {
.await
.unwrap();
- assert_eq!(context.snippets.len(), 1);
- assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
+ let mut snippet_identifiers = context
+ .snippets
+ .iter()
+ .map(|snippet| snippet.identifier.name.as_ref())
+ .collect::<Vec<_>>();
+ snippet_identifiers.sort();
+ assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
drop(buffer);
}
@@ -1,9 +1,11 @@
use language::BufferSnapshot;
use std::ops::Range;
-use text::{OffsetRangeExt as _, Point, ToOffset as _, ToPoint as _};
+use text::{Point, ToOffset as _, ToPoint as _};
use tree_sitter::{Node, TreeCursor};
use util::RangeExt;
+use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState};
+
// TODO:
//
// - Test parent signatures
@@ -27,14 +29,12 @@ pub struct EditPredictionExcerptOptions {
pub min_bytes: usize,
/// Target ratio of bytes before the cursor divided by total bytes in the window.
pub target_before_cursor_over_total_bytes: f32,
- /// Whether to include parent signatures
- pub include_parent_signatures: bool,
}
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
pub range: Range<usize>,
- pub parent_signature_ranges: Vec<Range<usize>>,
+ pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
pub size: usize,
}
@@ -50,9 +50,9 @@ impl EditPredictionExcerpt {
.text_for_range(self.range.clone())
.collect::<String>();
let parent_signatures = self
- .parent_signature_ranges
+ .parent_declarations
.iter()
- .map(|range| buffer.text_for_range(range.clone()).collect::<String>())
+ .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
.collect();
EditPredictionExcerptText {
body,
@@ -62,8 +62,9 @@ impl EditPredictionExcerpt {
/// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
/// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
- /// cursor. When `include_parent_signatures` is true, the excerpt also includes the signatures
- /// of parent outline items.
+ /// cursor.
+ ///
+ /// When `index` is provided, the excerpt will include the signatures of parent outline items.
///
/// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
/// expansion.
@@ -73,6 +74,7 @@ impl EditPredictionExcerpt {
query_point: Point,
buffer: &BufferSnapshot,
options: &EditPredictionExcerptOptions,
+ syntax_index: Option<&SyntaxIndexState>,
) -> Option<Self> {
if buffer.len() <= options.max_bytes {
log::debug!(
@@ -90,17 +92,9 @@ impl EditPredictionExcerpt {
return None;
}
- // TODO: Don't compute text / annotation_range / skip converting to and from anchors.
- let outline_items = if options.include_parent_signatures {
- buffer
- .outline_items_containing(query_range.clone(), false, None)
- .into_iter()
- .flat_map(|item| {
- Some(ExcerptOutlineItem {
- item_range: item.range.to_offset(&buffer),
- signature_range: item.signature_range?.to_offset(&buffer),
- })
- })
+ let parent_declarations = if let Some(syntax_index) = syntax_index {
+ syntax_index
+ .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
.collect()
} else {
Vec::new()
@@ -109,7 +103,7 @@ impl EditPredictionExcerpt {
let excerpt_selector = ExcerptSelector {
query_offset,
query_range,
- outline_items: &outline_items,
+ parent_declarations: &parent_declarations,
buffer,
options,
};
@@ -132,15 +126,15 @@ impl EditPredictionExcerpt {
excerpt_selector.select_lines()
}
- fn new(range: Range<usize>, parent_signature_ranges: Vec<Range<usize>>) -> Self {
+ fn new(range: Range<usize>, parent_declarations: Vec<(DeclarationId, Range<usize>)>) -> Self {
let size = range.len()
- + parent_signature_ranges
+ + parent_declarations
.iter()
- .map(|r| r.len())
+ .map(|(_, range)| range.len())
.sum::<usize>();
Self {
range,
- parent_signature_ranges,
+ parent_declarations,
size,
}
}
@@ -150,20 +144,14 @@ impl EditPredictionExcerpt {
// this is an issue because parent_signature_ranges may be incorrect
log::error!("bug: with_expanded_range called with disjoint range");
}
- let mut parent_signature_ranges = Vec::with_capacity(self.parent_signature_ranges.len());
- let mut size = new_range.len();
- for range in &self.parent_signature_ranges {
- if range.contains_inclusive(&new_range) {
+ let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
+ for (declaration_id, range) in &self.parent_declarations {
+ if !range.contains_inclusive(&new_range) {
break;
}
- parent_signature_ranges.push(range.clone());
- size += range.len();
- }
- Self {
- range: new_range,
- parent_signature_ranges,
- size,
+ parent_declarations.push((*declaration_id, range.clone()));
}
+ Self::new(new_range, parent_declarations)
}
fn parent_signatures_size(&self) -> usize {
@@ -174,16 +162,11 @@ impl EditPredictionExcerpt {
struct ExcerptSelector<'a> {
query_offset: usize,
query_range: Range<usize>,
- outline_items: &'a [ExcerptOutlineItem],
+ parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
buffer: &'a BufferSnapshot,
options: &'a EditPredictionExcerptOptions,
}
-struct ExcerptOutlineItem {
- item_range: Range<usize>,
- signature_range: Range<usize>,
-}
-
impl<'a> ExcerptSelector<'a> {
/// Finds the largest node that is smaller than the window size and contains `query_range`.
fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
@@ -396,13 +379,13 @@ impl<'a> ExcerptSelector<'a> {
}
fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
- let parent_signature_ranges = self
- .outline_items
+ let parent_declarations = self
+ .parent_declarations
.iter()
- .filter(|item| item.item_range.contains_inclusive(&range))
- .map(|item| item.signature_range.clone())
+ .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
+ .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
.collect();
- EditPredictionExcerpt::new(range, parent_signature_ranges)
+ EditPredictionExcerpt::new(range, parent_declarations)
}
/// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
@@ -493,8 +476,9 @@ mod tests {
let buffer = create_buffer(&text, cx);
let cursor_point = cursor.to_point(&buffer);
- let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
- .expect("Should select an excerpt");
+ let excerpt =
+ EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
+ .expect("Should select an excerpt");
pretty_assertions::assert_eq!(
generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
generate_marked_text(&text, &[expected_excerpt], false)
@@ -517,7 +501,6 @@ fn main() {
max_bytes: 20,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
- include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -541,7 +524,6 @@ fn bar() {}"#;
max_bytes: 65,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
- include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -561,7 +543,6 @@ fn main() {
max_bytes: 50,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
- include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -583,7 +564,6 @@ fn main() {
max_bytes: 60,
min_bytes: 45,
target_before_cursor_over_total_bytes: 0.5,
- include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -608,7 +588,6 @@ fn main() {
max_bytes: 120,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.6,
- include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -33,8 +33,8 @@ pub fn references_in_excerpt(
snapshot,
);
- for (range, text) in excerpt
- .parent_signature_ranges
+ for ((_, range), text) in excerpt
+ .parent_declarations
.iter()
.zip(excerpt_text.parent_signatures.iter())
{
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
use collections::{HashMap, HashSet};
use futures::lock::Mutex;
use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity};
@@ -8,8 +6,11 @@ use project::buffer_store::{BufferStore, BufferStoreEvent};
use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
use project::{PathChange, Project, ProjectEntryId, ProjectPath};
use slotmap::SlotMap;
+use std::iter;
+use std::ops::Range;
+use std::sync::Arc;
use text::BufferId;
-use util::{debug_panic, some_or_debug_panic};
+use util::{RangeExt as _, debug_panic, some_or_debug_panic};
use crate::declaration::{
BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
@@ -432,7 +433,7 @@ impl SyntaxIndexState {
pub fn declarations_for_identifier<const N: usize>(
&self,
identifier: &Identifier,
- ) -> Vec<Declaration> {
+ ) -> Vec<(DeclarationId, &Declaration)> {
// make sure to not have a large stack allocation
assert!(N < 32);
@@ -454,7 +455,7 @@ impl SyntaxIndexState {
project_entry_id, ..
} => {
included_buffer_entry_ids.push(*project_entry_id);
- result.push(declaration.clone());
+ result.push((*declaration_id, declaration));
if result.len() == N {
return Vec::new();
}
@@ -463,19 +464,19 @@ impl SyntaxIndexState {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(&project_entry_id) {
- file_declarations.push(declaration.clone());
+ file_declarations.push((*declaration_id, declaration));
}
}
}
}
- for declaration in file_declarations {
+ for (declaration_id, declaration) in file_declarations {
match declaration {
Declaration::File {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(&project_entry_id) {
- result.push(declaration);
+ result.push((declaration_id, declaration));
if result.len() == N {
return Vec::new();
@@ -489,6 +490,35 @@ impl SyntaxIndexState {
result
}
+ pub fn buffer_declarations_containing_range(
+ &self,
+ buffer_id: BufferId,
+ range: Range<usize>,
+ ) -> impl Iterator<Item = (DeclarationId, &BufferDeclaration)> {
+ let Some(buffer_state) = self.buffers.get(&buffer_id) else {
+ return itertools::Either::Left(iter::empty());
+ };
+
+ let iter = buffer_state
+ .declarations
+ .iter()
+ .filter_map(move |declaration_id| {
+ let Some(declaration) = self
+ .declarations
+ .get(*declaration_id)
+ .and_then(|d| d.as_buffer())
+ else {
+ log::error!("bug: missing buffer outline declaration");
+ return None;
+ };
+ if declaration.item_range.contains_inclusive(&range) {
+ return Some((*declaration_id, declaration));
+ }
+ return None;
+ });
+ itertools::Either::Right(iter)
+ }
+
pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
match declaration {
Declaration::File {
@@ -553,11 +583,11 @@ mod tests {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
- let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
+ let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, main.clone());
assert_eq!(decl.item_range_in_file, 32..280);
- let decl = expect_file_decl("a.rs", &decls[1], &project, cx);
+ let decl = expect_file_decl("a.rs", &decls[1].1, &project, cx);
assert_eq!(decl.identifier, main);
assert_eq!(decl.item_range_in_file, 0..98);
});
@@ -577,7 +607,7 @@ mod tests {
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
- let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
+ let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
@@ -618,7 +648,7 @@ mod tests {
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
- let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
+ let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
@@ -676,11 +706,11 @@ mod tests {
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
- let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
+ let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, main);
assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280);
- expect_file_decl("a.rs", &decls[1], &project, cx);
+ expect_file_decl("a.rs", &decls[1].1, &project, cx);
});
}
@@ -695,8 +725,8 @@ mod tests {
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
- expect_file_decl("c.rs", &decls[0], &project, cx);
- expect_file_decl("a.rs", &decls[1], &project, cx);
+ expect_file_decl("c.rs", &decls[0].1, &project, cx);
+ expect_file_decl("a.rs", &decls[1].1, &project, cx);
});
}
@@ -4,7 +4,7 @@ use std::{
path::{Path, PathBuf},
str::FromStr,
sync::Arc,
- time::Duration,
+ time::{Duration, Instant},
};
use collections::HashMap;
@@ -195,6 +195,8 @@ impl EditPredictionTools {
.timer(Duration::from_millis(50))
.await;
+ let mut start_time = None;
+
let Ok(task) = this.update(cx, |this, cx| {
fn number_input_value<T: FromStr + Default>(
input: &Entity<SingleLineInput>,
@@ -216,10 +218,10 @@ impl EditPredictionTools {
&this.cursor_context_ratio_input,
cx,
),
- // TODO Display and add to options
- include_parent_signatures: false,
};
+ start_time = Some(Instant::now());
+
EditPredictionContext::gather(
cursor_position,
current_buffer_snapshot,
@@ -243,6 +245,7 @@ impl EditPredictionTools {
.ok();
return;
};
+ let retrieval_duration = start_time.unwrap().elapsed();
let mut languages = HashMap::default();
for snippet in context.snippets.iter() {
@@ -320,7 +323,7 @@ impl EditPredictionTools {
this.last_context = Some(ContextState {
context_editor,
- retrieval_duration: context.retrieval_duration,
+ retrieval_duration,
});
cx.notify();
})