edit_action.rs

  1use util::ResultExt;
  2
  3/// Represents an edit action to be performed on a file.
  4#[derive(Debug, Clone, PartialEq, Eq)]
  5pub enum EditAction {
  6    /// Replace specific content in a file with new content
  7    Replace {
  8        file_path: String,
  9        old: String,
 10        new: String,
 11    },
 12    /// Write content to a file (create or overwrite)
 13    Write { file_path: String, content: String },
 14}
 15
 16impl EditAction {
 17    pub fn file_path(&self) -> &str {
 18        match self {
 19            EditAction::Replace { file_path, .. } => file_path,
 20            EditAction::Write { file_path, .. } => file_path,
 21        }
 22    }
 23}
 24
 25/// Parses edit actions from an LLM response.
 26/// See system.md for more details on the format.
 27#[derive(Debug)]
 28pub struct EditActionParser {
 29    state: State,
 30    pre_fence_line: Vec<u8>,
 31    marker_ix: usize,
 32    line: usize,
 33    column: usize,
 34    old_bytes: Vec<u8>,
 35    new_bytes: Vec<u8>,
 36    errors: Vec<ParseError>,
 37}
 38
 39#[derive(Debug, PartialEq, Eq)]
 40enum State {
 41    /// Anywhere outside an action
 42    Default,
 43    /// After opening ```, in optional language tag
 44    OpenFence,
 45    /// In SEARCH marker
 46    SearchMarker,
 47    /// In search block or divider
 48    SearchBlock,
 49    /// In replace block or REPLACE marker
 50    ReplaceBlock,
 51    /// In closing ```
 52    CloseFence,
 53}
 54
 55impl EditActionParser {
 56    /// Creates a new `EditActionParser`
 57    pub fn new() -> Self {
 58        Self {
 59            state: State::Default,
 60            pre_fence_line: Vec::new(),
 61            marker_ix: 0,
 62            line: 1,
 63            column: 0,
 64            old_bytes: Vec::new(),
 65            new_bytes: Vec::new(),
 66            errors: Vec::new(),
 67        }
 68    }
 69
 70    /// Processes a chunk of input text and returns any completed edit actions.
 71    ///
 72    /// This method can be called repeatedly with fragments of input. The parser
 73    /// maintains its state between calls, allowing you to process streaming input
 74    /// as it becomes available. Actions are only inserted once they are fully parsed.
 75    ///
 76    /// If a block fails to parse, it will simply be skipped and an error will be recorded.
 77    /// All errors can be accessed through the `EditActionsParser::errors` method.
 78    pub fn parse_chunk(&mut self, input: &str) -> Vec<EditAction> {
 79        use State::*;
 80
 81        const FENCE: &[u8] = b"\n```";
 82        const SEARCH_MARKER: &[u8] = b"<<<<<<< SEARCH\n";
 83        const DIVIDER: &[u8] = b"=======\n";
 84        const NL_DIVIDER: &[u8] = b"\n=======\n";
 85        const REPLACE_MARKER: &[u8] = b">>>>>>> REPLACE";
 86        const NL_REPLACE_MARKER: &[u8] = b"\n>>>>>>> REPLACE";
 87
 88        let mut actions = Vec::new();
 89
 90        for byte in input.bytes() {
 91            // Update line and column tracking
 92            if byte == b'\n' {
 93                self.line += 1;
 94                self.column = 0;
 95            } else {
 96                self.column += 1;
 97            }
 98
 99            match self.state {
100                Default => match match_marker(byte, FENCE, &mut self.marker_ix) {
101                    MarkerMatch::Complete => {
102                        self.to_state(OpenFence);
103                    }
104                    MarkerMatch::Partial => {}
105                    MarkerMatch::None => {
106                        if self.marker_ix > 0 {
107                            self.marker_ix = 0;
108                            self.pre_fence_line.clear();
109                        }
110
111                        if byte != b'\n' {
112                            self.pre_fence_line.push(byte);
113                        }
114                    }
115                },
116                OpenFence => {
117                    // skip language tag
118                    if byte == b'\n' {
119                        self.to_state(SearchMarker);
120                    }
121                }
122                SearchMarker => {
123                    if self.expect_marker(byte, SEARCH_MARKER) {
124                        self.to_state(SearchBlock);
125                    }
126                }
127                SearchBlock => {
128                    if collect_until_marker(
129                        byte,
130                        DIVIDER,
131                        NL_DIVIDER,
132                        &mut self.marker_ix,
133                        &mut self.old_bytes,
134                    ) {
135                        self.to_state(ReplaceBlock);
136                    }
137                }
138                ReplaceBlock => {
139                    if collect_until_marker(
140                        byte,
141                        REPLACE_MARKER,
142                        NL_REPLACE_MARKER,
143                        &mut self.marker_ix,
144                        &mut self.new_bytes,
145                    ) {
146                        self.to_state(CloseFence);
147                    }
148                }
149                CloseFence => {
150                    if self.expect_marker(byte, FENCE) {
151                        if let Some(action) = self.action() {
152                            actions.push(action);
153                        }
154                        self.reset();
155                    }
156                }
157            };
158        }
159
160        actions
161    }
162
163    /// Returns a reference to the errors encountered during parsing.
164    pub fn errors(&self) -> &[ParseError] {
165        &self.errors
166    }
167
168    fn action(&mut self) -> Option<EditAction> {
169        if self.old_bytes.is_empty() && self.new_bytes.is_empty() {
170            self.push_error(ParseErrorKind::NoOp);
171            return None;
172        }
173
174        let file_path = String::from_utf8(std::mem::take(&mut self.pre_fence_line)).log_err()?;
175        let content = String::from_utf8(std::mem::take(&mut self.new_bytes)).log_err()?;
176
177        if self.old_bytes.is_empty() {
178            Some(EditAction::Write { file_path, content })
179        } else {
180            let old = String::from_utf8(std::mem::take(&mut self.old_bytes)).log_err()?;
181
182            Some(EditAction::Replace {
183                file_path,
184                old,
185                new: content,
186            })
187        }
188    }
189
190    fn expect_marker(&mut self, byte: u8, marker: &'static [u8]) -> bool {
191        match match_marker(byte, marker, &mut self.marker_ix) {
192            MarkerMatch::Complete => true,
193            MarkerMatch::Partial => false,
194            MarkerMatch::None => {
195                self.push_error(ParseErrorKind::ExpectedMarker {
196                    expected: marker,
197                    found: byte,
198                });
199                self.reset();
200                false
201            }
202        }
203    }
204
205    fn to_state(&mut self, state: State) {
206        self.state = state;
207        self.marker_ix = 0;
208    }
209
210    fn reset(&mut self) {
211        self.pre_fence_line.clear();
212        self.old_bytes.clear();
213        self.new_bytes.clear();
214        self.to_state(State::Default);
215    }
216
217    fn push_error(&mut self, kind: ParseErrorKind) {
218        self.errors.push(ParseError {
219            line: self.line,
220            column: self.column,
221            kind,
222        });
223    }
224}
225
/// Result of feeding a single byte to [`match_marker`].
#[derive(Debug)]
enum MarkerMatch {
    /// The byte does not continue the marker; the match index is left as it was.
    None,
    /// The byte continues the marker, but more bytes are still required.
    Partial,
    /// The byte was the last one of the marker.
    Complete,
}
232
233fn match_marker(byte: u8, marker: &[u8], marker_ix: &mut usize) -> MarkerMatch {
234    if byte == marker[*marker_ix] {
235        *marker_ix += 1;
236
237        if *marker_ix >= marker.len() {
238            MarkerMatch::Complete
239        } else {
240            MarkerMatch::Partial
241        }
242    } else {
243        MarkerMatch::None
244    }
245}
246
247fn collect_until_marker(
248    byte: u8,
249    marker: &[u8],
250    nl_marker: &[u8],
251    marker_ix: &mut usize,
252    buf: &mut Vec<u8>,
253) -> bool {
254    let marker = if buf.is_empty() {
255        // do not require another newline if block is empty
256        marker
257    } else {
258        nl_marker
259    };
260
261    match match_marker(byte, marker, marker_ix) {
262        MarkerMatch::Complete => true,
263        MarkerMatch::Partial => false,
264        MarkerMatch::None => {
265            if *marker_ix > 0 {
266                buf.extend_from_slice(&marker[..*marker_ix]);
267                *marker_ix = 0;
268
269                // The beginning of marker might match current byte
270                match match_marker(byte, marker, marker_ix) {
271                    MarkerMatch::Complete => return true,
272                    MarkerMatch::Partial => return false,
273                    MarkerMatch::None => { /* no match, keep collecting */ }
274                }
275            }
276
277            buf.push(byte);
278
279            false
280        }
281    }
282}
283
284#[derive(Debug, PartialEq, Eq)]
285pub struct ParseError {
286    line: usize,
287    column: usize,
288    kind: ParseErrorKind,
289}
290
/// Why a block was rejected.
#[derive(Debug, PartialEq, Eq)]
pub enum ParseErrorKind {
    /// A fixed marker was required, but a different byte was read.
    ExpectedMarker {
        /// The complete marker that was being matched.
        expected: &'static [u8],
        /// The byte that broke the match.
        found: u8,
    },
    /// Both the SEARCH and the REPLACE sections were empty.
    NoOp,
}
296
297impl std::fmt::Display for ParseErrorKind {
298    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299        match self {
300            ParseErrorKind::ExpectedMarker { expected, found } => {
301                write!(
302                    f,
303                    "Expected marker {:?}, found {:?}",
304                    String::from_utf8_lossy(expected),
305                    *found as char
306                )
307            }
308            ParseErrorKind::NoOp => {
309                write!(f, "No search or replace")
310            }
311        }
312    }
313}
314
315impl std::fmt::Display for ParseError {
316    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317        write!(f, "input:{}:{}: {}", self.line, self.column, self.kind)
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[test]
    fn test_simple_edit_action() {
        let input = r#"src/main.rs
```
<<<<<<< SEARCH
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {}".to_string(),
                new: "fn replacement() {}".to_string(),
            }
        );
    }

    #[test]
    fn test_with_language_tag() {
        let input = r#"src/main.rs
```rust
<<<<<<< SEARCH
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {}".to_string(),
                new: "fn replacement() {}".to_string(),
            }
        );
    }

    #[test]
    fn test_with_surrounding_text() {
        let input = r#"Here's a modification I'd like to make to the file:

src/main.rs
```rust
<<<<<<< SEARCH
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```

This change makes the function better.
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {}".to_string(),
                new: "fn replacement() {}".to_string(),
            }
        );
    }

    #[test]
    fn test_multiple_edit_actions() {
        let input = r#"First change:
src/main.rs
```
<<<<<<< SEARCH
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```

Second change:
src/utils.rs
```rust
<<<<<<< SEARCH
fn old_util() -> bool { false }
=======
fn new_util() -> bool { true }
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 2);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {}".to_string(),
                new: "fn replacement() {}".to_string(),
            }
        );
        assert_eq!(
            actions[1],
            EditAction::Replace {
                file_path: "src/utils.rs".to_string(),
                old: "fn old_util() -> bool { false }".to_string(),
                new: "fn new_util() -> bool { true }".to_string(),
            }
        );
    }

    #[test]
    fn test_multiline() {
        let input = r#"src/main.rs
```rust
<<<<<<< SEARCH
fn original() {
    println!("This is the original function");
    let x = 42;
    if x > 0 {
        println!("Positive number");
    }
}
=======
fn replacement() {
    println!("This is the replacement function");
    let x = 100;
    if x > 50 {
        println!("Large number");
    } else {
        println!("Small number");
    }
}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {\n    println!(\"This is the original function\");\n    let x = 42;\n    if x > 0 {\n        println!(\"Positive number\");\n    }\n}".to_string(),
                new: "fn replacement() {\n    println!(\"This is the replacement function\");\n    let x = 100;\n    if x > 50 {\n        println!(\"Large number\");\n    } else {\n        println!(\"Small number\");\n    }\n}".to_string(),
            }
        );
    }

    #[test]
    fn test_write_action() {
        let input = r#"Create a new main.rs file:

src/main.rs
```rust
<<<<<<< SEARCH
=======
fn new_function() {
    println!("This function is being added");
}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Write {
                file_path: "src/main.rs".to_string(),
                content: "fn new_function() {\n    println!(\"This function is being added\");\n}"
                    .to_string(),
            }
        );
    }

    #[test]
    fn test_empty_replace() {
        let input = r#"src/main.rs
```rust
<<<<<<< SEARCH
fn this_will_be_deleted() {
    println!("Deleting this function");
}
=======
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn this_will_be_deleted() {\n    println!(\"Deleting this function\");\n}"
                    .to_string(),
                new: "".to_string(),
            }
        );
    }

    #[test]
    fn test_empty_both() {
        let input = r#"src/main.rs
```rust
<<<<<<< SEARCH
=======
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        // Should not create an action when both sections are empty
        assert_eq!(actions.len(), 0);

        // Check that the NoOp error was added
        assert_eq!(parser.errors().len(), 1);
        match parser.errors()[0].kind {
            ParseErrorKind::NoOp => {}
            _ => panic!("Expected NoOp error"),
        }
    }

    #[test]
    fn test_resumability() {
        let input_part1 = r#"src/main.rs
```rust
<<<<<<< SEARCH
fn ori"#;

        let input_part2 = r#"ginal() {}
=======
fn replacement() {}"#;

        let input_part3 = r#"
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions1 = parser.parse_chunk(input_part1);
        assert_eq!(actions1.len(), 0);

        let actions2 = parser.parse_chunk(input_part2);
        // No actions should be complete yet
        assert_eq!(actions2.len(), 0);

        let actions3 = parser.parse_chunk(input_part3);
        // The third chunk should complete the action
        assert_eq!(actions3.len(), 1);
        assert_eq!(
            actions3[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {}".to_string(),
                new: "fn replacement() {}".to_string(),
            }
        );
    }

    #[test]
    fn test_parser_state_preservation() {
        let mut parser = EditActionParser::new();
        let actions1 = parser.parse_chunk("src/main.rs\n```rust\n<<<<<<< SEARCH\n");

        // Check parser is in the correct state
        assert_eq!(parser.state, State::SearchBlock);
        assert_eq!(parser.pre_fence_line, b"src/main.rs");

        // Continue parsing
        let actions2 = parser.parse_chunk("original code\n=======\n");
        assert_eq!(parser.state, State::ReplaceBlock);
        assert_eq!(parser.old_bytes, b"original code");

        let actions3 = parser.parse_chunk("replacement code\n>>>>>>> REPLACE\n```\n");

        // After complete parsing, state should reset
        assert_eq!(parser.state, State::Default);
        assert!(parser.pre_fence_line.is_empty());
        assert!(parser.old_bytes.is_empty());
        assert!(parser.new_bytes.is_empty());

        assert_eq!(actions1.len(), 0);
        assert_eq!(actions2.len(), 0);
        assert_eq!(actions3.len(), 1);
    }

    #[test]
    fn test_invalid_search_marker() {
        let input = r#"src/main.rs
```rust
<<<<<<< WRONG_MARKER
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);
        assert_eq!(actions.len(), 0);

        assert_eq!(parser.errors().len(), 1);
        let error = &parser.errors()[0];

        assert_eq!(error.line, 3);
        assert_eq!(error.column, 9);
        assert_eq!(
            error.kind,
            ParseErrorKind::ExpectedMarker {
                expected: b"<<<<<<< SEARCH\n",
                found: b'W'
            }
        );
    }

    #[test]
    fn test_missing_closing_fence() {
        let input = r#"src/main.rs
```rust
<<<<<<< SEARCH
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
<!-- Missing closing fence -->

src/utils.rs
```rust
<<<<<<< SEARCH
fn utils_func() {}
=======
fn new_utils_func() {}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        // Only the second block should be parsed
        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/utils.rs".to_string(),
                old: "fn utils_func() {}".to_string(),
                new: "fn new_utils_func() {}".to_string(),
            }
        );

        // The parser should continue after an error
        assert_eq!(parser.state, State::Default);
    }

    #[test]
    fn test_file_path_accessor() {
        let replace = EditAction::Replace {
            file_path: "src/a.rs".to_string(),
            old: "x".to_string(),
            new: "y".to_string(),
        };
        let write = EditAction::Write {
            file_path: "src/b.rs".to_string(),
            content: "z".to_string(),
        };

        assert_eq!(replace.file_path(), "src/a.rs");
        assert_eq!(write.file_path(), "src/b.rs");
    }

    #[test]
    fn test_non_ascii_content() {
        // Content is collected byte by byte; multi-byte characters must survive intact.
        let input = "src/main.rs\n```\n<<<<<<< SEARCH\nlet s = \"ä\";\n=======\nlet s = \"ö\";\n>>>>>>> REPLACE\n```\n";

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(
            actions,
            vec![EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "let s = \"ä\";".to_string(),
                new: "let s = \"ö\";".to_string(),
            }]
        );
        assert!(parser.errors().is_empty());
    }

    const SYSTEM_PROMPT: &str = include_str!("./edit_prompt.md");

    #[test]
    fn test_parse_examples_in_system_prompt() {
        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(SYSTEM_PROMPT);
        assert_examples_in_system_prompt(&actions, parser.errors());
    }

    #[gpui::test(iterations = 10)]
    fn test_random_chunking_of_system_prompt(mut rng: StdRng) {
        let mut parser = EditActionParser::new();
        let mut remaining = SYSTEM_PROMPT;
        let mut actions = Vec::with_capacity(5);

        while !remaining.is_empty() {
            let mut chunk_size = rng.gen_range(1..=std::cmp::min(remaining.len(), 100));
            // `str::split_at` panics unless the index falls on a UTF-8 character
            // boundary, so advance to the next one in case the prompt contains
            // multi-byte characters. `remaining.len()` is always a boundary.
            while !remaining.is_char_boundary(chunk_size) {
                chunk_size += 1;
            }

            let (chunk, rest) = remaining.split_at(chunk_size);

            actions.extend(parser.parse_chunk(chunk));
            remaining = rest;
        }

        assert_examples_in_system_prompt(&actions, parser.errors());
    }

    fn assert_examples_in_system_prompt(actions: &[EditAction], errors: &[ParseError]) {
        assert_eq!(actions.len(), 5);

        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "mathweb/flask/app.py".to_string(),
                old: "from flask import Flask".to_string(),
                new: "import math\nfrom flask import Flask".to_string(),
            }
        );

        assert_eq!(
            actions[1],
            EditAction::Replace {
                file_path: "mathweb/flask/app.py".to_string(),
                old: "def factorial(n):\n    \"compute factorial\"\n\n    if n == 0:\n        return 1\n    else:\n        return n * factorial(n-1)\n".to_string(),
                new: "".to_string(),
            }
        );

        assert_eq!(
            actions[2],
            EditAction::Replace {
                file_path: "mathweb/flask/app.py".to_string(),
                old: "    return str(factorial(n))".to_string(),
                new: "    return str(math.factorial(n))".to_string(),
            }
        );

        assert_eq!(
            actions[3],
            EditAction::Write {
                file_path: "hello.py".to_string(),
                content: "def hello():\n    \"print a greeting\"\n\n    print(\"hello\")"
                    .to_string(),
            }
        );

        assert_eq!(
            actions[4],
            EditAction::Replace {
                file_path: "main.py".to_string(),
                old: "def hello():\n    \"print a greeting\"\n\n    print(\"hello\")".to_string(),
                new: "from hello import hello".to_string(),
            }
        );

        // Ensure we have no parsing errors
        assert!(errors.is_empty(), "Parsing errors found: {:?}", errors);
    }

    #[test]
    fn test_print_error() {
        let input = r#"src/main.rs
```rust
<<<<<<< WRONG_MARKER
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        parser.parse_chunk(input);

        assert_eq!(parser.errors().len(), 1);
        let error = &parser.errors()[0];
        let expected_error = r#"input:3:9: Expected marker "<<<<<<< SEARCH\n", found 'W'"#;

        assert_eq!(format!("{}", error), expected_error);
    }
}