diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 88ca5aad35389511bd6b47659b0ee407bbee5632..7b349edcd489f55fc0b834ad437f02f7fbde63d3 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -3,6 +3,7 @@ pub mod assistant_panel; pub mod assistant_settings; mod codegen; mod completion_provider; +mod omit_ranges; mod prompts; mod saved_conversation; mod search; diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index cc87307745116c5e751c2cc8ed8576d616111624..3e60affc95e098cbe8173851e796e7bbf6b06088 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -5,6 +5,7 @@ use crate::{ ambient_context::*, assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel}, codegen::{self, Codegen, CodegenKind}, + omit_ranges::text_in_range_omitting_ranges, prompts::prompt::generate_content_prompt, search::*, slash_command::{ @@ -3556,38 +3557,15 @@ pub struct Message { impl Message { fn to_request_message(&self, buffer: &Buffer) -> LanguageModelRequestMessage { - let mut slash_command_ranges = self.slash_command_ranges.iter().peekable(); - let mut content = String::with_capacity(self.offset_range.len()); - let mut offset = self.offset_range.start; - let mut chunks = buffer.text_for_range(self.offset_range.clone()); - while let Some(chunk) = chunks.next() { - if let Some(slash_command_range) = slash_command_ranges.peek() { - match offset.cmp(&slash_command_range.start) { - Ordering::Less => { - let max_len = slash_command_range.start - offset; - if chunk.len() < max_len { - content.push_str(chunk); - offset += chunk.len(); - } else { - content.push_str(&chunk[..max_len]); - offset += max_len; - chunks.seek(slash_command_range.end); - slash_command_ranges.next(); - } - } - Ordering::Equal | Ordering::Greater => { - chunks.seek(slash_command_range.end); - offset = slash_command_range.end; - slash_command_ranges.next(); - } - } - } else { - content.push_str(chunk); - } - } + let mut content = text_in_range_omitting_ranges( + buffer.as_rope(), + self.offset_range.clone(), + &self.slash_command_ranges, + ); + content.truncate(content.trim_end().len()); LanguageModelRequestMessage { role: self.role, - content: content.trim_end().into(), + content, } } } diff --git a/crates/assistant/src/omit_ranges.rs b/crates/assistant/src/omit_ranges.rs new file mode 100644 index 0000000000000000000000000000000000000000..f4a6988e95628813b078f01258d2a700a370fcbb --- /dev/null +++ b/crates/assistant/src/omit_ranges.rs @@ -0,0 +1,101 @@ +use rope::Rope; +use std::{cmp::Ordering, ops::Range}; + +pub(crate) fn text_in_range_omitting_ranges( + rope: &Rope, + range: Range, + omit_ranges: &[Range], +) -> String { + let mut content = String::with_capacity(range.len()); + let mut omit_ranges = omit_ranges + .iter() + .skip_while(|omit_range| omit_range.end <= range.start) + .peekable(); + let mut offset = range.start; + let mut chunks = rope.chunks_in_range(range.clone()); + while let Some(chunk) = chunks.next() { + if let Some(omit_range) = omit_ranges.peek() { + match offset.cmp(&omit_range.start) { + Ordering::Less => { + let max_len = omit_range.start - offset; + if chunk.len() < max_len { + content.push_str(chunk); + offset += chunk.len(); + } else { + content.push_str(&chunk[..max_len]); + chunks.seek(omit_range.end.min(range.end)); + offset = omit_range.end; + omit_ranges.next(); + } + } + Ordering::Equal | Ordering::Greater => { + chunks.seek(omit_range.end.min(range.end)); + offset = omit_range.end; + omit_ranges.next(); + } + } + } else { + content.push_str(chunk); + offset += chunk.len(); + } + } + + content +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::{rngs::StdRng, Rng as _}; + use util::RandomCharIter; + + #[gpui::test(iterations = 100)] + fn test_text_in_range_omitting_ranges(mut rng: StdRng) { + let text = RandomCharIter::new(&mut rng).take(1024).collect::(); + let rope = Rope::from(text.as_str()); + + let mut start = rng.gen_range(0..=text.len() / 2); + let mut end = rng.gen_range(text.len() / 2..=text.len()); + while !text.is_char_boundary(start) { + start -= 1; + } + while !text.is_char_boundary(end) { + end += 1; + } + let range = start..end; + + let mut ix = 0; + let mut omit_ranges = Vec::new(); + for _ in 0..rng.gen_range(0..10) { + let mut start = rng.gen_range(ix..=text.len()); + while !text.is_char_boundary(start) { + start += 1; + } + let mut end = rng.gen_range(start..=text.len()); + while !text.is_char_boundary(end) { + end += 1; + } + omit_ranges.push(start..end); + ix = end; + if ix == text.len() { + break; + } + } + + let mut expected_text = text[range.clone()].to_string(); + for omit_range in omit_ranges.iter().rev() { + let start = omit_range + .start + .saturating_sub(range.start) + .min(range.len()); + let end = omit_range.end.saturating_sub(range.start).min(range.len()); + expected_text.replace_range(start..end, ""); + } + + assert_eq!( + text_in_range_omitting_ranges(&rope, range.clone(), &omit_ranges), + expected_text, + "text: {text:?}\nrange: {range:?}\nomit_ranges: {omit_ranges:?}" + ); + } +}