diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index b9b356490048138421c3c5f40940a836348fb81b..2faa2a2092ac141d514a9c695d58adf8d44614ba 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -33,7 +33,7 @@ use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::{ - cmp, + cmp::{self, Ordering}, fmt::Debug, iter, mem, ops::Range, @@ -618,6 +618,7 @@ pub struct Context { telemetry: Option>, language_registry: Arc, workflow_steps: Vec, + edits_since_last_workflow_step_prune: language::Subscription, project: Option>, prompt_builder: Arc, } @@ -667,6 +668,8 @@ impl Context { }); let edits_since_last_slash_command_parse = buffer.update(cx, |buffer, _| buffer.subscribe()); + let edits_since_last_workflow_step_prune = + buffer.update(cx, |buffer, _| buffer.subscribe()); let mut this = Self { id, timestamp: clock::Lamport::new(replica_id), @@ -693,6 +696,7 @@ impl Context { project, language_registry, workflow_steps: Vec::new(), + edits_since_last_workflow_step_prune, prompt_builder, }; @@ -1058,7 +1062,9 @@ impl Context { language::Event::Edited => { self.count_remaining_tokens(cx); self.reparse_slash_commands(cx); - self.prune_invalid_workflow_steps(cx); + // Use `inclusive = true` to invalidate a step when an edit occurs + // at the start/end of a parsed step. + self.prune_invalid_workflow_steps(true, cx); cx.emit(ContextEvent::MessagesEdited); } _ => {} @@ -1165,24 +1171,62 @@ impl Context { } } - fn prune_invalid_workflow_steps(&mut self, cx: &mut ModelContext) { - let buffer = self.buffer.read(cx); - let prev_len = self.workflow_steps.len(); + fn prune_invalid_workflow_steps(&mut self, inclusive: bool, cx: &mut ModelContext) { let mut removed = Vec::new(); - self.workflow_steps.retain(|step| { - if step.tagged_range.start.is_valid(buffer) && step.tagged_range.end.is_valid(buffer) { - true - } else { - removed.push(step.tagged_range.clone()); - false - } - }); - if self.workflow_steps.len() != prev_len { + + for edit_range in self.edits_since_last_workflow_step_prune.consume() { + let intersecting_range = self.find_intersecting_steps(edit_range.new, inclusive, cx); + removed.extend( + self.workflow_steps + .drain(intersecting_range) + .map(|step| step.tagged_range), + ); + } + + if !removed.is_empty() { cx.emit(ContextEvent::WorkflowStepsRemoved(removed)); cx.notify(); } } + fn find_intersecting_steps( + &self, + range: Range, + inclusive: bool, + cx: &AppContext, + ) -> Range { + let buffer = self.buffer.read(cx); + let start_ix = match self.workflow_steps.binary_search_by(|probe| { + probe + .tagged_range + .end + .to_offset(buffer) + .cmp(&range.start) + .then(if inclusive { + Ordering::Greater + } else { + Ordering::Less + }) + }) { + Ok(ix) | Err(ix) => ix, + }; + let end_ix = match self.workflow_steps.binary_search_by(|probe| { + probe + .tagged_range + .start + .to_offset(buffer) + .cmp(&range.end) + .then(if inclusive { + Ordering::Less + } else { + Ordering::Greater + }) + }) { + Ok(ix) | Err(ix) => ix, + }; + start_ix..end_ix + } + fn parse_workflow_steps_in_range( &mut self, range: Range, @@ -1248,8 +1292,12 @@ impl Context { self.workflow_steps.insert(index, step); self.resolve_workflow_step(step_range, project.clone(), cx); } + + // Delete tags, making sure we don't accidentally invalidate + // the step we just parsed. self.buffer .update(cx, |buffer, cx| buffer.edit(edits, None, cx)); + self.edits_since_last_workflow_step_prune.consume(); } pub fn resolve_workflow_step( @@ -1629,6 +1677,8 @@ impl Context { message_start_offset..message_new_end_offset }); if let Some(project) = this.project.clone() { + // Use `inclusive = false` as edits might occur at the end of a parsed step. + this.prune_invalid_workflow_steps(false, cx); this.parse_workflow_steps_in_range(message_range, project, cx); } cx.emit(ContextEvent::StreamedCompletion); diff --git a/crates/language/src/outline.rs b/crates/language/src/outline.rs index b7e4a7e14dba69f1232f5acfe3d9b597dbebb016..7bcbcc6bc32ce51a70cd768e7b0b5d7d9b307c95 100644 --- a/crates/language/src/outline.rs +++ b/crates/language/src/outline.rs @@ -84,13 +84,24 @@ impl Outline { } } - /// Find the most similar symbol to the provided query according to the Jaro-Winkler distance measure. + /// Find the most similar symbol to the provided query using normalized Levenshtein distance. pub fn find_most_similar(&self, query: &str) -> Option<&OutlineItem> { - let candidate = self.path_candidates.iter().max_by(|a, b| { - strsim::jaro_winkler(&a.string, query) - .total_cmp(&strsim::jaro_winkler(&b.string, query)) - })?; - Some(&self.items[candidate.id]) + const SIMILARITY_THRESHOLD: f64 = 0.6; + + let (item, similarity) = self + .items + .iter() + .map(|item| { + let similarity = strsim::normalized_levenshtein(&item.text, query); + (item, similarity) + }) + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())?; + + if similarity >= SIMILARITY_THRESHOLD { + Some(item) + } else { + None + } } /// Find all outline symbols according to a longest subsequence match with the query, ordered descending by match score. @@ -208,3 +219,46 @@ pub fn render_item( StyledText::new(outline_item.text.clone()).with_highlights(&text_style, highlights) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_find_most_similar_with_low_similarity() { + let outline = Outline::new(vec![ + OutlineItem { + depth: 0, + range: Point::new(0, 0)..Point::new(5, 0), + text: "fn process".to_string(), + highlight_ranges: vec![], + name_ranges: vec![3..10], + body_range: None, + annotation_range: None, + }, + OutlineItem { + depth: 0, + range: Point::new(7, 0)..Point::new(12, 0), + text: "struct DataProcessor".to_string(), + highlight_ranges: vec![], + name_ranges: vec![7..20], + body_range: None, + annotation_range: None, + }, + ]); + assert_eq!( + outline.find_most_similar("pub fn process"), + Some(&outline.items[0]) + ); + assert_eq!( + outline.find_most_similar("async fn process"), + Some(&outline.items[0]) + ); + assert_eq!( + outline.find_most_similar("struct Processor"), + Some(&outline.items[1]) + ); + assert_eq!(outline.find_most_similar("struct User"), None); + assert_eq!(outline.find_most_similar("struct"), None); + } +}