zeta1.rs

  1use std::{fmt::Write, ops::Range, sync::Arc};
  2
  3use crate::cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count};
  4use anyhow::Result;
  5use cloud_llm_client::PredictEditsBody;
  6use edit_prediction_types::PredictedCursorPosition;
  7use language::{Anchor, BufferSnapshot, Point, text_diff};
  8use text::Bias;
  9use zeta_prompt::{
 10    Event,
 11    zeta1::{
 12        CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
 13        START_OF_FILE_MARKER,
 14    },
 15};
 16
 17pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
 18pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 19
 20pub(crate) fn parse_edits(
 21    output_excerpt: &str,
 22    editable_range: Range<usize>,
 23    snapshot: &BufferSnapshot,
 24) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
 25    let content = output_excerpt.replace(CURSOR_MARKER, "");
 26
 27    let start_markers = content
 28        .match_indices(EDITABLE_REGION_START_MARKER)
 29        .collect::<Vec<_>>();
 30    anyhow::ensure!(
 31        start_markers.len() <= 1,
 32        "expected at most one start marker, found {}",
 33        start_markers.len()
 34    );
 35
 36    let end_markers = content
 37        .match_indices(EDITABLE_REGION_END_MARKER)
 38        .collect::<Vec<_>>();
 39    anyhow::ensure!(
 40        end_markers.len() <= 1,
 41        "expected at most one end marker, found {}",
 42        end_markers.len()
 43    );
 44
 45    let sof_markers = content
 46        .match_indices(START_OF_FILE_MARKER)
 47        .collect::<Vec<_>>();
 48    anyhow::ensure!(
 49        sof_markers.len() <= 1,
 50        "expected at most one start-of-file marker, found {}",
 51        sof_markers.len()
 52    );
 53
 54    let content_start = start_markers
 55        .first()
 56        .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len())
 57        .map(|start| {
 58            if content.len() > start
 59                && content.is_char_boundary(start)
 60                && content[start..].starts_with('\n')
 61            {
 62                start + 1
 63            } else {
 64                start
 65            }
 66        })
 67        .unwrap_or(0);
 68    let content_end = end_markers
 69        .first()
 70        .map(|e| {
 71            if e.0 > 0 && content.is_char_boundary(e.0 - 1) && content[e.0 - 1..].starts_with('\n')
 72            {
 73                e.0 - 1
 74            } else {
 75                e.0
 76            }
 77        })
 78        .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
 79
 80    // min to account for content_end and content_start both accounting for the same newline in the following case:
 81    // <|editable_region_start|>\n<|editable_region_end|>
 82    let new_text = &content[content_start.min(content_end)..content_end];
 83
 84    let old_text = snapshot
 85        .text_for_range(editable_range.clone())
 86        .collect::<String>();
 87
 88    Ok(compute_edits(
 89        old_text,
 90        new_text,
 91        editable_range.start,
 92        snapshot,
 93    ))
 94}
 95
 96pub fn compute_edits(
 97    old_text: String,
 98    new_text: &str,
 99    offset: usize,
100    snapshot: &BufferSnapshot,
101) -> Vec<(Range<Anchor>, Arc<str>)> {
102    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
103}
104
105pub fn compute_edits_and_cursor_position(
106    old_text: String,
107    new_text: &str,
108    offset: usize,
109    cursor_offset_in_new_text: Option<usize>,
110    snapshot: &BufferSnapshot,
111) -> (
112    Vec<(Range<Anchor>, Arc<str>)>,
113    Option<PredictedCursorPosition>,
114) {
115    let diffs = text_diff(&old_text, new_text);
116
117    // Delta represents the cumulative change in byte count from all preceding edits.
118    // new_offset = old_offset + delta, so old_offset = new_offset - delta
119    let mut delta: isize = 0;
120    let mut cursor_position: Option<PredictedCursorPosition> = None;
121
122    let edits = diffs
123        .iter()
124        .map(|(raw_old_range, new_text)| {
125            // Compute cursor position if it falls within or before this edit.
126            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
127                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
128                let edit_end_in_new = edit_start_in_new + new_text.len();
129
130                if cursor_offset < edit_start_in_new {
131                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
132                    cursor_position = Some(PredictedCursorPosition::at_anchor(
133                        snapshot.anchor_after(offset + cursor_in_old),
134                    ));
135                } else if cursor_offset < edit_end_in_new {
136                    let offset_within_insertion = cursor_offset - edit_start_in_new;
137                    cursor_position = Some(PredictedCursorPosition::new(
138                        snapshot.anchor_before(offset + raw_old_range.start),
139                        offset_within_insertion,
140                    ));
141                }
142
143                delta += new_text.len() as isize - raw_old_range.len() as isize;
144            }
145
146            // Compute the edit with prefix/suffix trimming.
147            let mut old_range = raw_old_range.clone();
148            let old_slice = &old_text[old_range.clone()];
149
150            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
151            let suffix_len = common_prefix(
152                old_slice[prefix_len..].chars().rev(),
153                new_text[prefix_len..].chars().rev(),
154            );
155
156            old_range.start += offset;
157            old_range.end += offset;
158            old_range.start += prefix_len;
159            old_range.end -= suffix_len;
160
161            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
162            let range = if old_range.is_empty() {
163                let anchor = snapshot.anchor_after(old_range.start);
164                anchor..anchor
165            } else {
166                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
167            };
168            (range, new_text)
169        })
170        .collect();
171
172    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
173        let cursor_in_old = (cursor_offset as isize - delta) as usize;
174        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
175        cursor_position = Some(PredictedCursorPosition::at_anchor(
176            snapshot.anchor_after(buffer_offset),
177        ));
178    }
179
180    (edits, cursor_position)
181}
182
/// Returns the length in bytes (UTF-8) of the longest run of equal leading chars
/// yielded by `a` and `b`. Comparison stops at the first mismatch or when either
/// iterator runs out.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
189
190pub struct GatherContextOutput {
191    pub body: PredictEditsBody,
192    pub context_range: Range<Point>,
193    pub editable_range: Range<usize>,
194    pub included_events_count: usize,
195}
196
197pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
198    prompt_for_events_impl(events, max_tokens).0
199}
200
201fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
202    let mut result = String::new();
203    for (ix, event) in events.iter().rev().enumerate() {
204        let event_string = format_event(event.as_ref());
205        let event_tokens = guess_token_count(event_string.len());
206        if event_tokens > remaining_tokens {
207            return (result, ix);
208        }
209
210        if !result.is_empty() {
211            result.insert_str(0, "\n\n");
212        }
213        result.insert_str(0, &event_string);
214        remaining_tokens -= event_tokens;
215    }
216    return (result, events.len());
217}
218
219pub fn format_event(event: &Event) -> String {
220    match event {
221        Event::BufferChange {
222            path,
223            old_path,
224            diff,
225            ..
226        } => {
227            let mut prompt = String::new();
228
229            if old_path != path {
230                writeln!(
231                    prompt,
232                    "User renamed {} to {}\n",
233                    old_path.display(),
234                    path.display()
235                )
236                .unwrap();
237            }
238
239            if !diff.is_empty() {
240                write!(
241                    prompt,
242                    "User edited {}:\n```diff\n{}\n```",
243                    path.display(),
244                    diff
245                )
246                .unwrap();
247            }
248
249            prompt
250        }
251    }
252}
253
254#[derive(Debug)]
255pub struct InputExcerpt {
256    pub context_range: Range<Point>,
257    pub editable_range: Range<Point>,
258    pub prompt: String,
259}
260
261pub fn excerpt_for_cursor_position(
262    position: Point,
263    path: &str,
264    snapshot: &BufferSnapshot,
265    editable_region_token_limit: usize,
266    context_token_limit: usize,
267) -> InputExcerpt {
268    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
269        position,
270        snapshot,
271        editable_region_token_limit,
272        context_token_limit,
273    );
274
275    let mut prompt = String::new();
276
277    writeln!(&mut prompt, "```{path}").unwrap();
278    if context_range.start == Point::zero() {
279        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
280    }
281
282    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
283        prompt.push_str(chunk.text);
284    }
285
286    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
287
288    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
289        prompt.push_str(chunk.text);
290    }
291    write!(prompt, "\n```").unwrap();
292
293    InputExcerpt {
294        context_range,
295        editable_range,
296        prompt,
297    }
298}
299
300fn push_editable_range(
301    cursor_position: Point,
302    snapshot: &BufferSnapshot,
303    editable_range: Range<Point>,
304    prompt: &mut String,
305) {
306    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
307    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
308        prompt.push_str(chunk.text);
309    }
310    prompt.push_str(CURSOR_MARKER);
311    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
312        prompt.push_str(chunk.text);
313    }
314    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;
    use text::OffsetRangeExt as _;

    /// Creates a local buffer holding `text` (no language attached) and returns
    /// its snapshot.
    fn plain_text_snapshot(text: &'static str, cx: &mut App) -> BufferSnapshot {
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        buffer.read(cx).snapshot()
    }

    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let source = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer =
            cx.new(|cx| Buffer::local(source, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();
        let cursor = Point::new(12, 5);

        // The excerpt expands to syntax boundaries.
        // With a 50-token editable limit, the region grows to enclosing syntax nodes.
        let wide = excerpt_for_cursor_position(cursor, "main.rs", &snapshot, 50, 32);
        assert_eq!(
            wide.prompt,
            indoc! {r#"
            ```main.rs

            fn bar() {
                let x = 42;
            <|editable_region_start|>
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // A smaller budget still snaps to syntax boundaries, but the region is tighter.
        let tight = excerpt_for_cursor_position(cursor, "main.rs", &snapshot, 40, 32);
        assert_eq!(
            tight.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
            <|editable_region_start|>
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
            ```"#}
        );
    }

    #[gpui::test]
    fn test_parse_edits_empty_editable_region(cx: &mut App) {
        let text = "fn foo() {\n    let x = 42;\n}\n";
        let snapshot = plain_text_snapshot(text, cx);

        let output = "<|editable_region_start|>\n<|editable_region_end|>";
        let edits = parse_edits(output, 0..text.len(), &snapshot).unwrap();

        let [(range, new_text)] = edits.as_slice() else {
            panic!("expected exactly one edit, got {edits:?}");
        };
        assert_eq!(range.to_offset(&snapshot), 0..text.len());
        assert_eq!(new_text.as_ref(), "");
    }

    #[gpui::test]
    fn test_parse_edits_multibyte_char_before_end_marker(cx: &mut App) {
        let text = "// café";
        let snapshot = plain_text_snapshot(text, cx);

        let output = "<|editable_region_start|>\n// café<|editable_region_end|>";
        let edits = parse_edits(output, 0..text.len(), &snapshot).unwrap();

        assert!(edits.is_empty(), "unexpected edits: {edits:?}");
    }

    #[gpui::test]
    fn test_parse_edits_multibyte_char_after_start_marker(cx: &mut App) {
        let text = "é is great";
        let snapshot = plain_text_snapshot(text, cx);

        let output = "<|editable_region_start|>é is great\n<|editable_region_end|>";
        let edits = parse_edits(output, 0..text.len(), &snapshot).unwrap();

        assert!(edits.is_empty(), "unexpected edits: {edits:?}");
    }
}