zeta1.rs

  1use std::{fmt::Write, ops::Range, sync::Arc};
  2
  3use crate::cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count};
  4use anyhow::Result;
  5use cloud_llm_client::PredictEditsBody;
  6use edit_prediction_types::PredictedCursorPosition;
  7use language::{Anchor, BufferSnapshot, Point, text_diff};
  8use text::Bias;
  9use zeta_prompt::{
 10    Event,
 11    zeta1::{
 12        CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
 13        START_OF_FILE_MARKER,
 14    },
 15};
 16
/// Token budget for context, crate-visible.
/// NOTE(review): not referenced in the visible part of this file; presumably passed as a
/// `context_token_limit` when building an excerpt — confirm at the call sites.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget for events, crate-visible.
/// NOTE(review): not referenced in the visible part of this file; presumably passed as
/// `max_tokens` to `prompt_for_events` — confirm at the call sites.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 19
 20pub(crate) fn parse_edits(
 21    output_excerpt: &str,
 22    editable_range: Range<usize>,
 23    snapshot: &BufferSnapshot,
 24) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
 25    let content = output_excerpt.replace(CURSOR_MARKER, "");
 26
 27    let start_markers = content
 28        .match_indices(EDITABLE_REGION_START_MARKER)
 29        .collect::<Vec<_>>();
 30    anyhow::ensure!(
 31        start_markers.len() <= 1,
 32        "expected at most one start marker, found {}",
 33        start_markers.len()
 34    );
 35
 36    let end_markers = content
 37        .match_indices(EDITABLE_REGION_END_MARKER)
 38        .collect::<Vec<_>>();
 39    anyhow::ensure!(
 40        end_markers.len() <= 1,
 41        "expected at most one end marker, found {}",
 42        end_markers.len()
 43    );
 44
 45    let sof_markers = content
 46        .match_indices(START_OF_FILE_MARKER)
 47        .collect::<Vec<_>>();
 48    anyhow::ensure!(
 49        sof_markers.len() <= 1,
 50        "expected at most one start-of-file marker, found {}",
 51        sof_markers.len()
 52    );
 53
 54    let content_start = start_markers
 55        .first()
 56        .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len())
 57        .map(|start| {
 58            if content.len() > start
 59                && content.is_char_boundary(start)
 60                && content[start..].starts_with('\n')
 61            {
 62                start + 1
 63            } else {
 64                start
 65            }
 66        })
 67        .unwrap_or(0);
 68    let content_end = end_markers
 69        .first()
 70        .map(|e| {
 71            if e.0 > 0 && content.is_char_boundary(e.0 - 1) && content[e.0 - 1..].starts_with('\n')
 72            {
 73                e.0 - 1
 74            } else {
 75                e.0
 76            }
 77        })
 78        .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
 79
 80    // min to account for content_end and content_start both accounting for the same newline in the following case:
 81    // <|editable_region_start|>\n<|editable_region_end|>
 82    let new_text = &content[content_start.min(content_end)..content_end];
 83
 84    let old_text = snapshot
 85        .text_for_range(editable_range.clone())
 86        .collect::<String>();
 87
 88    Ok(compute_edits(
 89        old_text,
 90        new_text,
 91        editable_range.start,
 92        snapshot,
 93    ))
 94}
 95
 96pub fn compute_edits(
 97    old_text: String,
 98    new_text: &str,
 99    offset: usize,
100    snapshot: &BufferSnapshot,
101) -> Vec<(Range<Anchor>, Arc<str>)> {
102    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
103}
104
105pub fn compute_edits_and_cursor_position(
106    old_text: String,
107    new_text: &str,
108    offset: usize,
109    cursor_offset_in_new_text: Option<usize>,
110    snapshot: &BufferSnapshot,
111) -> (
112    Vec<(Range<Anchor>, Arc<str>)>,
113    Option<PredictedCursorPosition>,
114) {
115    let diffs = text_diff(&old_text, new_text);
116
117    // Delta represents the cumulative change in byte count from all preceding edits.
118    // new_offset = old_offset + delta, so old_offset = new_offset - delta
119    let mut delta: isize = 0;
120    let mut cursor_position: Option<PredictedCursorPosition> = None;
121    let buffer_len = snapshot.len();
122
123    let edits = diffs
124        .iter()
125        .map(|(raw_old_range, new_text)| {
126            // Compute cursor position if it falls within or before this edit.
127            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
128                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
129                let edit_end_in_new = edit_start_in_new + new_text.len();
130
131                if cursor_offset < edit_start_in_new {
132                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
133                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
134                    cursor_position = Some(PredictedCursorPosition::at_anchor(
135                        snapshot.anchor_after(buffer_offset),
136                    ));
137                } else if cursor_offset < edit_end_in_new {
138                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
139                    let offset_within_insertion = cursor_offset - edit_start_in_new;
140                    cursor_position = Some(PredictedCursorPosition::new(
141                        snapshot.anchor_before(buffer_offset),
142                        offset_within_insertion,
143                    ));
144                }
145
146                delta += new_text.len() as isize - raw_old_range.len() as isize;
147            }
148
149            // Compute the edit with prefix/suffix trimming.
150            let mut old_range = raw_old_range.clone();
151            let old_slice = &old_text[old_range.clone()];
152
153            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
154            let suffix_len = common_prefix(
155                old_slice[prefix_len..].chars().rev(),
156                new_text[prefix_len..].chars().rev(),
157            );
158
159            old_range.start += offset;
160            old_range.end += offset;
161            old_range.start += prefix_len;
162            old_range.end -= suffix_len;
163
164            old_range.start = old_range.start.min(buffer_len);
165            old_range.end = old_range.end.min(buffer_len);
166
167            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
168            let range = if old_range.is_empty() {
169                let anchor = snapshot.anchor_after(old_range.start);
170                anchor..anchor
171            } else {
172                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
173            };
174            (range, new_text)
175        })
176        .collect();
177
178    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
179        let cursor_in_old = (cursor_offset as isize - delta) as usize;
180        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
181        cursor_position = Some(PredictedCursorPosition::at_anchor(
182            snapshot.anchor_after(buffer_offset),
183        ));
184    }
185
186    (edits, cursor_position)
187}
188
/// Returns the UTF-8 byte length of the longest run of equal leading characters
/// yielded by `a` and `b`.
///
/// Passing reversed iterators measures a common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
195
/// Bundle of values produced when gathering context for a prediction request.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads this
/// struct, so the field notes below are inferred from names and types — confirm
/// against the code that builds it.
pub struct GatherContextOutput {
    /// Request body (`cloud_llm_client::PredictEditsBody`).
    pub body: PredictEditsBody,
    /// Context range expressed as buffer points (presumably the excerpt sent as context).
    pub context_range: Range<Point>,
    /// Editable range expressed as offsets (presumably byte offsets into the buffer).
    pub editable_range: Range<usize>,
    /// Number of events included (presumably those that fit into the events prompt).
    pub included_events_count: usize,
}
202
203pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
204    prompt_for_events_impl(events, max_tokens).0
205}
206
207fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
208    let mut result = String::new();
209    for (ix, event) in events.iter().rev().enumerate() {
210        let event_string = format_event(event.as_ref());
211        let event_tokens = guess_token_count(event_string.len());
212        if event_tokens > remaining_tokens {
213            return (result, ix);
214        }
215
216        if !result.is_empty() {
217            result.insert_str(0, "\n\n");
218        }
219        result.insert_str(0, &event_string);
220        remaining_tokens -= event_tokens;
221    }
222    return (result, events.len());
223}
224
225pub fn format_event(event: &Event) -> String {
226    match event {
227        Event::BufferChange {
228            path,
229            old_path,
230            diff,
231            ..
232        } => {
233            let mut prompt = String::new();
234
235            if old_path != path {
236                writeln!(
237                    prompt,
238                    "User renamed {} to {}\n",
239                    old_path.display(),
240                    path.display()
241                )
242                .unwrap();
243            }
244
245            if !diff.is_empty() {
246                write!(
247                    prompt,
248                    "User edited {}:\n```diff\n{}\n```",
249                    path.display(),
250                    diff
251                )
252                .unwrap();
253            }
254
255            prompt
256        }
257    }
258}
259
/// Result of [`excerpt_for_cursor_position`]: the prompt text for an excerpt around
/// the cursor together with the ranges that were used to build it.
#[derive(Debug)]
pub struct InputExcerpt {
    /// Range of the buffer included in the prompt as surrounding context.
    pub context_range: Range<Point>,
    /// Range of the buffer wrapped in the editable-region markers within the prompt.
    pub editable_range: Range<Point>,
    /// The rendered excerpt: a fenced block containing the context, the
    /// editable-region markers, and the cursor marker.
    pub prompt: String,
}
266
267pub fn excerpt_for_cursor_position(
268    position: Point,
269    path: &str,
270    snapshot: &BufferSnapshot,
271    editable_region_token_limit: usize,
272    context_token_limit: usize,
273) -> InputExcerpt {
274    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
275        position,
276        snapshot,
277        editable_region_token_limit,
278        context_token_limit,
279    );
280
281    let mut prompt = String::new();
282
283    writeln!(&mut prompt, "```{path}").unwrap();
284    if context_range.start == Point::zero() {
285        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
286    }
287
288    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
289        prompt.push_str(chunk.text);
290    }
291
292    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
293
294    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
295        prompt.push_str(chunk.text);
296    }
297    write!(prompt, "\n```").unwrap();
298
299    InputExcerpt {
300        context_range,
301        editable_range,
302        prompt,
303    }
304}
305
306fn push_editable_range(
307    cursor_position: Point,
308    snapshot: &BufferSnapshot,
309    editable_range: Range<Point>,
310    prompt: &mut String,
311) {
312    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
313    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
314        prompt.push_str(chunk.text);
315    }
316    prompt.push_str(CURSOR_MARKER);
317    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
318        prompt.push_str(chunk.text);
319    }
320    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
321}
322
// Tests for prompt excerpt construction and for parsing model output into edits.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;
    use text::OffsetRangeExt as _;

    // Checks the exact prompt produced for a cursor inside `fn bar` under two
    // different editable-region token limits.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // The excerpt expands to syntax boundaries.
        // With 50 token editable limit, we get a region that expands to syntax nodes.
        // Point (12, 5) is just after the `r` of `return sum;`, which is where the
        // cursor marker appears in the expected prompt.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs

            fn bar() {
                let x = 42;
            <|editable_region_start|>
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // With smaller budget, the region expands to syntax boundaries but is tighter.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
            <|editable_region_start|>
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
            ```"#}
        );
    }

    // An output whose markers are separated only by a newline yields empty new
    // text, i.e. a single edit that deletes the whole editable range.
    #[gpui::test]
    fn test_parse_edits_empty_editable_region(cx: &mut App) {
        let text = "fn foo() {\n    let x = 42;\n}\n";
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let snapshot = buffer.read(cx).snapshot();

        let output = "<|editable_region_start|>\n<|editable_region_end|>";
        let editable_range = 0..text.len();
        let edits = parse_edits(output, editable_range, &snapshot).unwrap();
        assert_eq!(edits.len(), 1);
        let (range, new_text) = &edits[0];
        assert_eq!(range.to_offset(&snapshot), 0..text.len(),);
        assert_eq!(new_text.as_ref(), "");
    }

    // The byte before the end marker belongs to a multi-byte character (`é`);
    // parsing must succeed and, since the text is unchanged, produce no edits.
    #[gpui::test]
    fn test_parse_edits_multibyte_char_before_end_marker(cx: &mut App) {
        let text = "// café";
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let snapshot = buffer.read(cx).snapshot();

        let output = "<|editable_region_start|>\n// café<|editable_region_end|>";
        let editable_range = 0..text.len();

        let edits = parse_edits(output, editable_range, &snapshot).unwrap();
        assert_eq!(edits, vec![]);
    }

    // A multi-byte character (`é`) directly follows the start marker with no
    // newline in between; parsing must succeed and produce no edits.
    #[gpui::test]
    fn test_parse_edits_multibyte_char_after_start_marker(cx: &mut App) {
        let text = "é is great";
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let snapshot = buffer.read(cx).snapshot();

        let output = "<|editable_region_start|>é is great\n<|editable_region_end|>";
        let editable_range = 0..text.len();

        let edits = parse_edits(output, editable_range, &snapshot).unwrap();
        assert!(edits.is_empty());
    }
}