edit_action.rs

  1use util::ResultExt;
  2
  3/// Represents an edit action to be performed on a file.
  4#[derive(Debug, Clone, PartialEq, Eq)]
  5pub enum EditAction {
  6    /// Replace specific content in a file with new content
  7    Replace {
  8        file_path: String,
  9        old: String,
 10        new: String,
 11    },
 12    /// Write content to a file (create or overwrite)
 13    Write { file_path: String, content: String },
 14}
 15
 16impl EditAction {
 17    pub fn file_path(&self) -> &str {
 18        match self {
 19            EditAction::Replace { file_path, .. } => file_path,
 20            EditAction::Write { file_path, .. } => file_path,
 21        }
 22    }
 23}
 24
 25/// Parses edit actions from an LLM response.
 26/// See system.md for more details on the format.
 27#[derive(Debug)]
 28pub struct EditActionParser {
 29    state: State,
 30    pre_fence_line: Vec<u8>,
 31    marker_ix: usize,
 32    line: usize,
 33    column: usize,
 34    old_bytes: Vec<u8>,
 35    new_bytes: Vec<u8>,
 36    errors: Vec<ParseError>,
 37}
 38
 39#[derive(Debug, PartialEq, Eq)]
 40enum State {
 41    /// Anywhere outside an action
 42    Default,
 43    /// After opening ```, in optional language tag
 44    OpenFence,
 45    /// In SEARCH marker
 46    SearchMarker,
 47    /// In search block or divider
 48    SearchBlock,
 49    /// In replace block or REPLACE marker
 50    ReplaceBlock,
 51    /// In closing ```
 52    CloseFence,
 53}
 54
 55impl EditActionParser {
 56    /// Creates a new `EditActionParser`
 57    pub fn new() -> Self {
 58        Self {
 59            state: State::Default,
 60            pre_fence_line: Vec::new(),
 61            marker_ix: 0,
 62            line: 1,
 63            column: 0,
 64            old_bytes: Vec::new(),
 65            new_bytes: Vec::new(),
 66            errors: Vec::new(),
 67        }
 68    }
 69
 70    /// Processes a chunk of input text and returns any completed edit actions.
 71    ///
 72    /// This method can be called repeatedly with fragments of input. The parser
 73    /// maintains its state between calls, allowing you to process streaming input
 74    /// as it becomes available. Actions are only inserted once they are fully parsed.
 75    ///
 76    /// If a block fails to parse, it will simply be skipped and an error will be recorded.
 77    /// All errors can be accessed through the `EditActionsParser::errors` method.
 78    pub fn parse_chunk(&mut self, input: &str) -> Vec<EditAction> {
 79        use State::*;
 80
 81        const FENCE: &[u8] = b"```";
 82        const SEARCH_MARKER: &[u8] = b"<<<<<<< SEARCH";
 83        const DIVIDER: &[u8] = b"=======";
 84        const NL_DIVIDER: &[u8] = b"\n=======";
 85        const REPLACE_MARKER: &[u8] = b">>>>>>> REPLACE";
 86        const NL_REPLACE_MARKER: &[u8] = b"\n>>>>>>> REPLACE";
 87
 88        let mut actions = Vec::new();
 89
 90        for byte in input.bytes() {
 91            // Update line and column tracking
 92            if byte == b'\n' {
 93                self.line += 1;
 94                self.column = 0;
 95            } else {
 96                self.column += 1;
 97            }
 98
 99            match &self.state {
100                Default => match match_marker(byte, FENCE, false, &mut self.marker_ix) {
101                    MarkerMatch::Complete => {
102                        self.to_state(OpenFence);
103                    }
104                    MarkerMatch::Partial => {}
105                    MarkerMatch::None => {
106                        if self.marker_ix > 0 {
107                            self.marker_ix = 0;
108                        } else if self.pre_fence_line.ends_with(b"\n") {
109                            self.pre_fence_line.clear();
110                        }
111
112                        self.pre_fence_line.push(byte);
113                    }
114                },
115                OpenFence => {
116                    // skip language tag
117                    if byte == b'\n' {
118                        self.to_state(SearchMarker);
119                    }
120                }
121                SearchMarker => {
122                    if self.expect_marker(byte, SEARCH_MARKER, true) {
123                        self.to_state(SearchBlock);
124                    }
125                }
126                SearchBlock => {
127                    if collect_until_marker(
128                        byte,
129                        DIVIDER,
130                        NL_DIVIDER,
131                        true,
132                        &mut self.marker_ix,
133                        &mut self.old_bytes,
134                    ) {
135                        self.to_state(ReplaceBlock);
136                    }
137                }
138                ReplaceBlock => {
139                    if collect_until_marker(
140                        byte,
141                        REPLACE_MARKER,
142                        NL_REPLACE_MARKER,
143                        true,
144                        &mut self.marker_ix,
145                        &mut self.new_bytes,
146                    ) {
147                        self.to_state(CloseFence);
148                    }
149                }
150                CloseFence => {
151                    if self.expect_marker(byte, FENCE, false) {
152                        if let Some(action) = self.action() {
153                            actions.push(action);
154                        }
155                        self.errors();
156                        self.reset();
157                    }
158                }
159            };
160        }
161
162        actions
163    }
164
165    /// Returns a reference to the errors encountered during parsing.
166    pub fn errors(&self) -> &[ParseError] {
167        &self.errors
168    }
169
170    fn action(&mut self) -> Option<EditAction> {
171        if self.old_bytes.is_empty() && self.new_bytes.is_empty() {
172            self.push_error(ParseErrorKind::NoOp);
173            return None;
174        }
175
176        let mut pre_fence_line = std::mem::take(&mut self.pre_fence_line);
177
178        if pre_fence_line.ends_with(b"\n") {
179            pre_fence_line.pop();
180
181            if pre_fence_line.ends_with(b"\r") {
182                pre_fence_line.pop();
183            }
184        }
185
186        let file_path = String::from_utf8(pre_fence_line).log_err()?;
187        let content = String::from_utf8(std::mem::take(&mut self.new_bytes)).log_err()?;
188
189        if self.old_bytes.is_empty() {
190            Some(EditAction::Write { file_path, content })
191        } else {
192            let old = String::from_utf8(std::mem::take(&mut self.old_bytes)).log_err()?;
193
194            Some(EditAction::Replace {
195                file_path,
196                old,
197                new: content,
198            })
199        }
200    }
201
202    fn expect_marker(&mut self, byte: u8, marker: &'static [u8], trailing_newline: bool) -> bool {
203        match match_marker(byte, marker, trailing_newline, &mut self.marker_ix) {
204            MarkerMatch::Complete => true,
205            MarkerMatch::Partial => false,
206            MarkerMatch::None => {
207                self.push_error(ParseErrorKind::ExpectedMarker {
208                    expected: marker,
209                    found: byte,
210                });
211                self.reset();
212                false
213            }
214        }
215    }
216
217    fn to_state(&mut self, state: State) {
218        self.state = state;
219        self.marker_ix = 0;
220    }
221
222    fn reset(&mut self) {
223        self.pre_fence_line.clear();
224        self.old_bytes.clear();
225        self.new_bytes.clear();
226        self.to_state(State::Default);
227    }
228
229    fn push_error(&mut self, kind: ParseErrorKind) {
230        self.errors.push(ParseError {
231            line: self.line,
232            column: self.column,
233            kind,
234        });
235    }
236}
237
/// Outcome of feeding a single byte to [`match_marker`].
#[derive(Debug)]
enum MarkerMatch {
    /// The byte does not continue the marker.
    None,
    /// The byte continues the marker, but more input is needed.
    Partial,
    /// The marker, plus its trailing line ending when one is required, is complete.
    Complete,
}
244
245fn match_marker(
246    byte: u8,
247    marker: &[u8],
248    trailing_newline: bool,
249    marker_ix: &mut usize,
250) -> MarkerMatch {
251    if trailing_newline && *marker_ix >= marker.len() {
252        if byte == b'\n' {
253            MarkerMatch::Complete
254        } else if byte == b'\r' {
255            MarkerMatch::Complete
256        } else {
257            MarkerMatch::None
258        }
259    } else if byte == marker[*marker_ix] {
260        *marker_ix += 1;
261
262        if *marker_ix < marker.len() || trailing_newline {
263            MarkerMatch::Partial
264        } else {
265            MarkerMatch::Complete
266        }
267    } else {
268        MarkerMatch::None
269    }
270}
271
272fn collect_until_marker(
273    byte: u8,
274    marker: &[u8],
275    nl_marker: &[u8],
276    trailing_newline: bool,
277    marker_ix: &mut usize,
278    buf: &mut Vec<u8>,
279) -> bool {
280    let marker = if buf.is_empty() {
281        // do not require another newline if block is empty
282        marker
283    } else {
284        nl_marker
285    };
286
287    match match_marker(byte, marker, trailing_newline, marker_ix) {
288        MarkerMatch::Complete => true,
289        MarkerMatch::Partial => false,
290        MarkerMatch::None => {
291            if *marker_ix > 0 {
292                buf.extend_from_slice(&marker[..*marker_ix]);
293                *marker_ix = 0;
294
295                // The beginning of marker might match current byte
296                match match_marker(byte, marker, trailing_newline, marker_ix) {
297                    MarkerMatch::Complete => return true,
298                    MarkerMatch::Partial => return false,
299                    MarkerMatch::None => { /* no match, keep collecting */ }
300                }
301            }
302
303            buf.push(byte);
304
305            false
306        }
307    }
308}
309
310#[derive(Debug, PartialEq, Eq)]
311pub struct ParseError {
312    line: usize,
313    column: usize,
314    kind: ParseErrorKind,
315}
316
/// The kinds of problems the edit-action parser reports.
#[derive(Debug, PartialEq, Eq)]
pub enum ParseErrorKind {
    /// A fixed marker was being matched, and `found` did not continue it.
    ExpectedMarker {
        /// The marker that was being matched.
        expected: &'static [u8],
        /// The byte that broke the match.
        found: u8,
    },
    /// The block had neither search nor replace content, so there is nothing to do.
    NoOp,
}
322
323impl std::fmt::Display for ParseErrorKind {
324    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        match self {
326            ParseErrorKind::ExpectedMarker { expected, found } => {
327                write!(
328                    f,
329                    "Expected marker {:?}, found {:?}",
330                    String::from_utf8_lossy(expected),
331                    *found as char
332                )
333            }
334            ParseErrorKind::NoOp => {
335                write!(f, "No search or replace")
336            }
337        }
338    }
339}
340
341impl std::fmt::Display for ParseError {
342    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343        write!(f, "input:{}:{}: {}", self.line, self.column, self.kind)
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// A single block with no language tag parses into a `Replace`.
    #[test]
    fn test_simple_edit_action() {
        let input = r#"src/main.rs
```
<<<<<<< SEARCH
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {}".to_string(),
                new: "fn replacement() {}".to_string(),
            }
        );
    }

    /// A language tag after the opening fence is ignored.
    #[test]
    fn test_with_language_tag() {
        let input = r#"src/main.rs
```rust
<<<<<<< SEARCH
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {}".to_string(),
                new: "fn replacement() {}".to_string(),
            }
        );
    }

    /// Prose before and after a block does not leak into the action.
    #[test]
    fn test_with_surrounding_text() {
        let input = r#"Here's a modification I'd like to make to the file:

src/main.rs
```rust
<<<<<<< SEARCH
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```

This change makes the function better.
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {}".to_string(),
                new: "fn replacement() {}".to_string(),
            }
        );
    }

    /// Several blocks in one input are returned in order.
    #[test]
    fn test_multiple_edit_actions() {
        let input = r#"First change:
src/main.rs
```
<<<<<<< SEARCH
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```

Second change:
src/utils.rs
```rust
<<<<<<< SEARCH
fn old_util() -> bool { false }
=======
fn new_util() -> bool { true }
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 2);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {}".to_string(),
                new: "fn replacement() {}".to_string(),
            }
        );
        assert_eq!(
            actions[1],
            EditAction::Replace {
                file_path: "src/utils.rs".to_string(),
                old: "fn old_util() -> bool { false }".to_string(),
                new: "fn new_util() -> bool { true }".to_string(),
            }
        );
    }

    /// Multi-line sections keep their internal newlines and indentation.
    #[test]
    fn test_multiline() {
        let input = r#"src/main.rs
```rust
<<<<<<< SEARCH
fn original() {
    println!("This is the original function");
    let x = 42;
    if x > 0 {
        println!("Positive number");
    }
}
=======
fn replacement() {
    println!("This is the replacement function");
    let x = 100;
    if x > 50 {
        println!("Large number");
    } else {
        println!("Small number");
    }
}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {\n    println!(\"This is the original function\");\n    let x = 42;\n    if x > 0 {\n        println!(\"Positive number\");\n    }\n}".to_string(),
                new: "fn replacement() {\n    println!(\"This is the replacement function\");\n    let x = 100;\n    if x > 50 {\n        println!(\"Large number\");\n    } else {\n        println!(\"Small number\");\n    }\n}".to_string(),
            }
        );
    }

    /// An empty SEARCH section produces a `Write` action.
    #[test]
    fn test_write_action() {
        let input = r#"Create a new main.rs file:

src/main.rs
```rust
<<<<<<< SEARCH
=======
fn new_function() {
    println!("This function is being added");
}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Write {
                file_path: "src/main.rs".to_string(),
                content: "fn new_function() {\n    println!(\"This function is being added\");\n}"
                    .to_string(),
            }
        );
    }

    /// An empty REPLACE section is a deletion: a `Replace` with empty `new`.
    #[test]
    fn test_empty_replace() {
        let input = r#"src/main.rs
```rust
<<<<<<< SEARCH
fn this_will_be_deleted() {
    println!("Deleting this function");
}
=======
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn this_will_be_deleted() {\n    println!(\"Deleting this function\");\n}"
                    .to_string(),
                new: "".to_string(),
            }
        );
    }

    /// A block with both sections empty yields no action and a `NoOp` error.
    #[test]
    fn test_empty_both() {
        let input = r#"src/main.rs
```rust
<<<<<<< SEARCH
=======
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        // Should not create an action when both sections are empty
        assert_eq!(actions.len(), 0);

        // Check that the NoOp error was added
        assert_eq!(parser.errors().len(), 1);
        assert_eq!(parser.errors()[0].kind, ParseErrorKind::NoOp);
    }

    /// Input split mid-token across calls still produces the action once complete.
    #[test]
    fn test_resumability() {
        let input_part1 = r#"src/main.rs
```rust
<<<<<<< SEARCH
fn ori"#;

        let input_part2 = r#"ginal() {}
=======
fn replacement() {}"#;

        let input_part3 = r#"
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions1 = parser.parse_chunk(input_part1);
        assert_eq!(actions1.len(), 0);

        let actions2 = parser.parse_chunk(input_part2);
        // No actions should be complete yet
        assert_eq!(actions2.len(), 0);

        let actions3 = parser.parse_chunk(input_part3);
        // The third chunk should complete the action
        assert_eq!(actions3.len(), 1);
        assert_eq!(
            actions3[0],
            EditAction::Replace {
                file_path: "src/main.rs".to_string(),
                old: "fn original() {}".to_string(),
                new: "fn replacement() {}".to_string(),
            }
        );
    }

    /// Internal state is carried across calls and reset after a completed block.
    #[test]
    fn test_parser_state_preservation() {
        let mut parser = EditActionParser::new();
        let actions1 = parser.parse_chunk("src/main.rs\n```rust\n<<<<<<< SEARCH\n");

        // Check parser is in the correct state
        assert_eq!(parser.state, State::SearchBlock);
        assert_eq!(parser.pre_fence_line, b"src/main.rs\n");

        // Continue parsing
        let actions2 = parser.parse_chunk("original code\n=======\n");
        assert_eq!(parser.state, State::ReplaceBlock);
        assert_eq!(parser.old_bytes, b"original code");

        let actions3 = parser.parse_chunk("replacement code\n>>>>>>> REPLACE\n```\n");

        // After complete parsing, state should reset
        assert_eq!(parser.state, State::Default);
        assert_eq!(parser.pre_fence_line, b"\n");
        assert!(parser.old_bytes.is_empty());
        assert!(parser.new_bytes.is_empty());

        assert_eq!(actions1.len(), 0);
        assert_eq!(actions2.len(), 0);
        assert_eq!(actions3.len(), 1);
    }

    /// A wrong SEARCH marker is reported at the first mismatching byte.
    #[test]
    fn test_invalid_search_marker() {
        let input = r#"src/main.rs
```rust
<<<<<<< WRONG_MARKER
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);
        assert_eq!(actions.len(), 0);

        assert_eq!(parser.errors().len(), 1);
        let error = &parser.errors()[0];

        assert_eq!(error.line, 3);
        assert_eq!(error.column, 9);
        assert_eq!(
            error.kind,
            ParseErrorKind::ExpectedMarker {
                expected: b"<<<<<<< SEARCH",
                found: b'W'
            }
        );
    }

    /// A block missing its closing fence is dropped, and parsing recovers for the next one.
    #[test]
    fn test_missing_closing_fence() {
        let input = r#"src/main.rs
```rust
<<<<<<< SEARCH
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
<!-- Missing closing fence -->

src/utils.rs
```rust
<<<<<<< SEARCH
fn utils_func() {}
=======
fn new_utils_func() {}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(input);

        // Only the second block should be parsed
        assert_eq!(actions.len(), 1);
        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "src/utils.rs".to_string(),
                old: "fn utils_func() {}".to_string(),
                new: "fn new_utils_func() {}".to_string(),
            }
        );

        // The parser should continue after an error
        assert_eq!(parser.state, State::Default);
    }

    const SYSTEM_PROMPT: &str = include_str!("./edit_prompt.md");

    #[test]
    fn test_parse_examples_in_system_prompt() {
        let mut parser = EditActionParser::new();
        let actions = parser.parse_chunk(SYSTEM_PROMPT);
        assert_examples_in_system_prompt(&actions, parser.errors());
    }

    /// Feeding the prompt in random-sized chunks must give the same result as one call.
    #[gpui::test(iterations = 10)]
    fn test_random_chunking_of_system_prompt(mut rng: StdRng) {
        let mut parser = EditActionParser::new();
        let mut remaining = SYSTEM_PROMPT;
        let mut actions = Vec::with_capacity(5);

        while !remaining.is_empty() {
            let mut chunk_size = rng.gen_range(1..=std::cmp::min(remaining.len(), 100));

            // `str::split_at` panics off a char boundary. Nudge forward so a
            // multi-byte character in the prompt is never split; this terminates
            // because `remaining.len()` is always a boundary.
            while !remaining.is_char_boundary(chunk_size) {
                chunk_size += 1;
            }

            let (chunk, rest) = remaining.split_at(chunk_size);

            actions.extend(parser.parse_chunk(chunk));
            remaining = rest;
        }

        assert_examples_in_system_prompt(&actions, parser.errors());
    }

    /// Checks the actions and errors expected from parsing the edit prompt's examples.
    fn assert_examples_in_system_prompt(actions: &[EditAction], errors: &[ParseError]) {
        assert_eq!(actions.len(), 5);

        assert_eq!(
            actions[0],
            EditAction::Replace {
                file_path: "mathweb/flask/app.py".to_string(),
                old: "from flask import Flask".to_string(),
                new: "import math\nfrom flask import Flask".to_string(),
            }
        );

        assert_eq!(
            actions[1],
            EditAction::Replace {
                file_path: "mathweb/flask/app.py".to_string(),
                old: "def factorial(n):\n    \"compute factorial\"\n\n    if n == 0:\n        return 1\n    else:\n        return n * factorial(n-1)\n".to_string(),
                new: "".to_string(),
            }
        );

        assert_eq!(
            actions[2],
            EditAction::Replace {
                file_path: "mathweb/flask/app.py".to_string(),
                old: "    return str(factorial(n))".to_string(),
                new: "    return str(math.factorial(n))".to_string(),
            }
        );

        assert_eq!(
            actions[3],
            EditAction::Write {
                file_path: "hello.py".to_string(),
                content: "def hello():\n    \"print a greeting\"\n\n    print(\"hello\")"
                    .to_string(),
            }
        );

        assert_eq!(
            actions[4],
            EditAction::Replace {
                file_path: "main.py".to_string(),
                old: "def hello():\n    \"print a greeting\"\n\n    print(\"hello\")".to_string(),
                new: "from hello import hello".to_string(),
            }
        );

        // The system prompt includes some text that would produce errors
        assert!(
            errors.len() >= 2,
            "expected at least two errors, got {errors:?}"
        );
        assert_eq!(
            errors[0].to_string(),
            "input:102:1: Expected marker \"<<<<<<< SEARCH\", found '3'"
        );
        assert_eq!(
            errors[1].to_string(),
            "input:109:0: Expected marker \"<<<<<<< SEARCH\", found '\\n'"
        );
    }

    /// Errors render as `input:<line>:<column>: <description>`.
    #[test]
    fn test_print_error() {
        let input = r#"src/main.rs
```rust
<<<<<<< WRONG_MARKER
fn original() {}
=======
fn replacement() {}
>>>>>>> REPLACE
```
"#;

        let mut parser = EditActionParser::new();
        parser.parse_chunk(input);

        assert_eq!(parser.errors().len(), 1);
        let error = &parser.errors()[0];
        let expected_error = r#"input:3:9: Expected marker "<<<<<<< SEARCH", found 'W'"#;

        assert_eq!(format!("{}", error), expected_error);
    }
}