From 48f61936284f3924b9eeece13af21215b2ba228e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 12 Aug 2024 11:09:07 +0200 Subject: [PATCH] Improve workflow step pruning and symbol similarity matching (#16036) This PR improves workflow step management and symbol matching. We've optimized step pruning to remove any step that intersects an edit and switched to normalized Levenshtein distance for more accurate symbol matching. Release Notes: - N/A --- crates/assistant/src/context.rs | 78 +++++++++++++++++++++++++++------ crates/language/src/outline.rs | 66 +++++++++++++++++++++++++--- 2 files changed, 124 insertions(+), 20 deletions(-) diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index b9b356490048138421c3c5f40940a836348fb81b..2faa2a2092ac141d514a9c695d58adf8d44614ba 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -33,7 +33,7 @@ use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::{ - cmp, + cmp::{self, Ordering}, fmt::Debug, iter, mem, ops::Range, @@ -618,6 +618,7 @@ pub struct Context { telemetry: Option>, language_registry: Arc, workflow_steps: Vec, + edits_since_last_workflow_step_prune: language::Subscription, project: Option>, prompt_builder: Arc, } @@ -667,6 +668,8 @@ impl Context { }); let edits_since_last_slash_command_parse = buffer.update(cx, |buffer, _| buffer.subscribe()); + let edits_since_last_workflow_step_prune = + buffer.update(cx, |buffer, _| buffer.subscribe()); let mut this = Self { id, timestamp: clock::Lamport::new(replica_id), @@ -693,6 +696,7 @@ impl Context { project, language_registry, workflow_steps: Vec::new(), + edits_since_last_workflow_step_prune, prompt_builder, }; @@ -1058,7 +1062,9 @@ impl Context { language::Event::Edited => { self.count_remaining_tokens(cx); self.reparse_slash_commands(cx); - self.prune_invalid_workflow_steps(cx); + // Use `inclusive = true` to invalidate a step when an edit occurs + // at the start/end of a parsed step. + self.prune_invalid_workflow_steps(true, cx); cx.emit(ContextEvent::MessagesEdited); } _ => {} @@ -1165,24 +1171,62 @@ impl Context { } } - fn prune_invalid_workflow_steps(&mut self, cx: &mut ModelContext) { - let buffer = self.buffer.read(cx); - let prev_len = self.workflow_steps.len(); + fn prune_invalid_workflow_steps(&mut self, inclusive: bool, cx: &mut ModelContext) { let mut removed = Vec::new(); - self.workflow_steps.retain(|step| { - if step.tagged_range.start.is_valid(buffer) && step.tagged_range.end.is_valid(buffer) { - true - } else { - removed.push(step.tagged_range.clone()); - false - } - }); - if self.workflow_steps.len() != prev_len { + + for edit_range in self.edits_since_last_workflow_step_prune.consume() { + let intersecting_range = self.find_intersecting_steps(edit_range.new, inclusive, cx); + removed.extend( + self.workflow_steps + .drain(intersecting_range) + .map(|step| step.tagged_range), + ); + } + + if !removed.is_empty() { cx.emit(ContextEvent::WorkflowStepsRemoved(removed)); cx.notify(); } } + fn find_intersecting_steps( + &self, + range: Range, + inclusive: bool, + cx: &AppContext, + ) -> Range { + let buffer = self.buffer.read(cx); + let start_ix = match self.workflow_steps.binary_search_by(|probe| { + probe + .tagged_range + .end + .to_offset(buffer) + .cmp(&range.start) + .then(if inclusive { + Ordering::Greater + } else { + Ordering::Less + }) + }) { + Ok(ix) | Err(ix) => ix, + }; + let end_ix = match self.workflow_steps.binary_search_by(|probe| { + probe + .tagged_range + .start + .to_offset(buffer) + .cmp(&range.end) + .then(if inclusive { + Ordering::Less + } else { + Ordering::Greater + }) + }) { + Ok(ix) | Err(ix) => ix, + }; + start_ix..end_ix + } + fn parse_workflow_steps_in_range( &mut self, range: Range, @@ -1248,8 +1292,12 @@ impl Context { self.workflow_steps.insert(index, step); self.resolve_workflow_step(step_range, project.clone(), cx); } + + // Delete tags, making sure we don't accidentally invalidate + // the step we just parsed. self.buffer .update(cx, |buffer, cx| buffer.edit(edits, None, cx)); + self.edits_since_last_workflow_step_prune.consume(); } pub fn resolve_workflow_step( @@ -1629,6 +1677,8 @@ impl Context { message_start_offset..message_new_end_offset }); if let Some(project) = this.project.clone() { + // Use `inclusive = false` as edits might occur at the end of a parsed step. + this.prune_invalid_workflow_steps(false, cx); this.parse_workflow_steps_in_range(message_range, project, cx); } cx.emit(ContextEvent::StreamedCompletion); diff --git a/crates/language/src/outline.rs b/crates/language/src/outline.rs index b7e4a7e14dba69f1232f5acfe3d9b597dbebb016..7bcbcc6bc32ce51a70cd768e7b0b5d7d9b307c95 100644 --- a/crates/language/src/outline.rs +++ b/crates/language/src/outline.rs @@ -84,13 +84,24 @@ impl Outline { } } - /// Find the most similar symbol to the provided query according to the Jaro-Winkler distance measure. + /// Find the most similar symbol to the provided query using normalized Levenshtein distance. pub fn find_most_similar(&self, query: &str) -> Option<&OutlineItem> { - let candidate = self.path_candidates.iter().max_by(|a, b| { - strsim::jaro_winkler(&a.string, query) - .total_cmp(&strsim::jaro_winkler(&b.string, query)) - })?; - Some(&self.items[candidate.id]) + const SIMILARITY_THRESHOLD: f64 = 0.6; + + let (item, similarity) = self + .items + .iter() + .map(|item| { + let similarity = strsim::normalized_levenshtein(&item.text, query); + (item, similarity) + }) + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())?; + + if similarity >= SIMILARITY_THRESHOLD { + Some(item) + } else { + None + } } /// Find all outline symbols according to a longest subsequence match with the query, ordered descending by match score. @@ -208,3 +219,46 @@ pub fn render_item( StyledText::new(outline_item.text.clone()).with_highlights(&text_style, highlights) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_find_most_similar_with_low_similarity() { + let outline = Outline::new(vec![ + OutlineItem { + depth: 0, + range: Point::new(0, 0)..Point::new(5, 0), + text: "fn process".to_string(), + highlight_ranges: vec![], + name_ranges: vec![3..10], + body_range: None, + annotation_range: None, + }, + OutlineItem { + depth: 0, + range: Point::new(7, 0)..Point::new(12, 0), + text: "struct DataProcessor".to_string(), + highlight_ranges: vec![], + name_ranges: vec![7..20], + body_range: None, + annotation_range: None, + }, + ]); + assert_eq!( + outline.find_most_similar("pub fn process"), + Some(&outline.items[0]) + ); + assert_eq!( + outline.find_most_similar("async fn process"), + Some(&outline.items[0]) + ); + assert_eq!( + outline.find_most_similar("struct Processor"), + Some(&outline.items[1]) + ); + assert_eq!(outline.find_most_similar("struct User"), None); + assert_eq!(outline.find_most_similar("struct"), None); + } +}