patch.go

  1package diff
  2
  3import (
  4	"errors"
  5	"fmt"
  6	"os"
  7	"path/filepath"
  8	"strings"
  9)
 10
// ActionType identifies the kind of operation a patch performs on a file.
type ActionType string

// Supported patch operations.
const (
	ActionAdd    ActionType = "add"    // create a new file
	ActionDelete ActionType = "delete" // remove an existing file
	ActionUpdate ActionType = "update" // modify (and optionally move) an existing file
)
 18
// FileChange describes the effect of a commit on a single file.
type FileChange struct {
	Type       ActionType // kind of change applied to the file
	OldContent *string    // content before the change; set for delete and update changes
	NewContent *string    // content after the change; set for add and update changes
	MovePath   *string    // destination path when an update also moves the file; nil otherwise
}
 25
// Commit is a set of file changes keyed by file path.
type Commit struct {
	Changes map[string]FileChange
}
 29
// Chunk is one contiguous edit within an updated file: DelLines are skipped
// starting at OrigIndex of the original file and InsLines are emitted in
// their place.
type Chunk struct {
	OrigIndex int      // line index of the first line in the original file
	DelLines  []string // lines to delete
	InsLines  []string // lines to insert
}
 35
// PatchAction is the parsed form of one file section of a patch.
type PatchAction struct {
	Type     ActionType // add, delete or update
	NewFile  *string    // full content of the new file; set by add actions
	Chunks   []Chunk    // edits to apply; populated by update actions
	MovePath *string    // new path taken from a "*** Move to: " line, if any
}
 42
// Patch is a parsed patch: one PatchAction per file path.
type Patch struct {
	Actions map[string]PatchAction
}
 46
// DiffError is the error type this package reports for malformed or
// inapplicable patches.
type DiffError struct {
	message string
}

// Error implements the error interface and returns the stored message.
func (e DiffError) Error() string {
	return e.message
}
 54
// Helper functions for error handling

// NewDiffError returns a DiffError carrying the given message.
func NewDiffError(message string) DiffError {
	return DiffError{message: message}
}
 59
 60func fileError(action, reason, path string) DiffError {
 61	return NewDiffError(fmt.Sprintf("%s File Error: %s: %s", action, reason, path))
 62}
 63
 64func contextError(index int, context string, isEOF bool) DiffError {
 65	prefix := "Invalid Context"
 66	if isEOF {
 67		prefix = "Invalid EOF Context"
 68	}
 69	return NewDiffError(fmt.Sprintf("%s %d:\n%s", prefix, index, context))
 70}
 71
// Parser holds the state of a single pass over the lines of a patch.
type Parser struct {
	currentFiles map[string]string // current content of files, keyed by path
	lines        []string          // patch text split into lines
	index        int               // position of the next unread line in lines
	patch        Patch             // actions collected so far
	fuzz         int               // accumulated penalty for inexact anchor/context matches
}
 79
 80func NewParser(currentFiles map[string]string, lines []string) *Parser {
 81	return &Parser{
 82		currentFiles: currentFiles,
 83		lines:        lines,
 84		index:        0,
 85		patch:        Patch{Actions: make(map[string]PatchAction, len(currentFiles))},
 86		fuzz:         0,
 87	}
 88}
 89
 90func (p *Parser) isDone(prefixes []string) bool {
 91	if p.index >= len(p.lines) {
 92		return true
 93	}
 94	for _, prefix := range prefixes {
 95		if strings.HasPrefix(p.lines[p.index], prefix) {
 96			return true
 97		}
 98	}
 99	return false
100}
101
102func (p *Parser) startsWith(prefix any) bool {
103	var prefixes []string
104	switch v := prefix.(type) {
105	case string:
106		prefixes = []string{v}
107	case []string:
108		prefixes = v
109	}
110
111	for _, pfx := range prefixes {
112		if strings.HasPrefix(p.lines[p.index], pfx) {
113			return true
114		}
115	}
116	return false
117}
118
119func (p *Parser) readStr(prefix string, returnEverything bool) string {
120	if p.index >= len(p.lines) {
121		return "" // Changed from panic to return empty string for safer operation
122	}
123	if strings.HasPrefix(p.lines[p.index], prefix) {
124		var text string
125		if returnEverything {
126			text = p.lines[p.index]
127		} else {
128			text = p.lines[p.index][len(prefix):]
129		}
130		p.index++
131		return text
132	}
133	return ""
134}
135
136func (p *Parser) Parse() error {
137	endPatchPrefixes := []string{"*** End Patch"}
138
139	for !p.isDone(endPatchPrefixes) {
140		path := p.readStr("*** Update File: ", false)
141		if path != "" {
142			if _, exists := p.patch.Actions[path]; exists {
143				return fileError("Update", "Duplicate Path", path)
144			}
145			moveTo := p.readStr("*** Move to: ", false)
146			if _, exists := p.currentFiles[path]; !exists {
147				return fileError("Update", "Missing File", path)
148			}
149			text := p.currentFiles[path]
150			action, err := p.parseUpdateFile(text)
151			if err != nil {
152				return err
153			}
154			if moveTo != "" {
155				action.MovePath = &moveTo
156			}
157			p.patch.Actions[path] = action
158			continue
159		}
160
161		path = p.readStr("*** Delete File: ", false)
162		if path != "" {
163			if _, exists := p.patch.Actions[path]; exists {
164				return fileError("Delete", "Duplicate Path", path)
165			}
166			if _, exists := p.currentFiles[path]; !exists {
167				return fileError("Delete", "Missing File", path)
168			}
169			p.patch.Actions[path] = PatchAction{Type: ActionDelete, Chunks: []Chunk{}}
170			continue
171		}
172
173		path = p.readStr("*** Add File: ", false)
174		if path != "" {
175			if _, exists := p.patch.Actions[path]; exists {
176				return fileError("Add", "Duplicate Path", path)
177			}
178			if _, exists := p.currentFiles[path]; exists {
179				return fileError("Add", "File already exists", path)
180			}
181			action, err := p.parseAddFile()
182			if err != nil {
183				return err
184			}
185			p.patch.Actions[path] = action
186			continue
187		}
188
189		return NewDiffError(fmt.Sprintf("Unknown Line: %s", p.lines[p.index]))
190	}
191
192	if !p.startsWith("*** End Patch") {
193		return NewDiffError("Missing End Patch")
194	}
195	p.index++
196
197	return nil
198}
199
200func (p *Parser) parseUpdateFile(text string) (PatchAction, error) {
201	action := PatchAction{Type: ActionUpdate, Chunks: []Chunk{}}
202	fileLines := strings.Split(text, "\n")
203	index := 0
204
205	endPrefixes := []string{
206		"*** End Patch",
207		"*** Update File:",
208		"*** Delete File:",
209		"*** Add File:",
210		"*** End of File",
211	}
212
213	for !p.isDone(endPrefixes) {
214		defStr := p.readStr("@@ ", false)
215		sectionStr := ""
216		if defStr == "" && p.index < len(p.lines) && p.lines[p.index] == "@@" {
217			sectionStr = p.lines[p.index]
218			p.index++
219		}
220		if defStr == "" && sectionStr == "" && index != 0 {
221			return action, NewDiffError(fmt.Sprintf("Invalid Line:\n%s", p.lines[p.index]))
222		}
223		if strings.TrimSpace(defStr) != "" {
224			found := false
225			for i := range fileLines[:index] {
226				if fileLines[i] == defStr {
227					found = true
228					break
229				}
230			}
231
232			if !found {
233				for i := index; i < len(fileLines); i++ {
234					if fileLines[i] == defStr {
235						index = i + 1
236						found = true
237						break
238					}
239				}
240			}
241
242			if !found {
243				for i := range fileLines[:index] {
244					if strings.TrimSpace(fileLines[i]) == strings.TrimSpace(defStr) {
245						found = true
246						break
247					}
248				}
249			}
250
251			if !found {
252				for i := index; i < len(fileLines); i++ {
253					if strings.TrimSpace(fileLines[i]) == strings.TrimSpace(defStr) {
254						index = i + 1
255						p.fuzz++
256						found = true
257						break
258					}
259				}
260			}
261		}
262
263		nextChunkContext, chunks, endPatchIndex, eof := peekNextSection(p.lines, p.index)
264		newIndex, fuzz := findContext(fileLines, nextChunkContext, index, eof)
265		if newIndex == -1 {
266			ctxText := strings.Join(nextChunkContext, "\n")
267			return action, contextError(index, ctxText, eof)
268		}
269		p.fuzz += fuzz
270
271		for _, ch := range chunks {
272			ch.OrigIndex += newIndex
273			action.Chunks = append(action.Chunks, ch)
274		}
275		index = newIndex + len(nextChunkContext)
276		p.index = endPatchIndex
277	}
278	return action, nil
279}
280
281func (p *Parser) parseAddFile() (PatchAction, error) {
282	lines := make([]string, 0, 16) // Preallocate space for better performance
283	endPrefixes := []string{
284		"*** End Patch",
285		"*** Update File:",
286		"*** Delete File:",
287		"*** Add File:",
288	}
289
290	for !p.isDone(endPrefixes) {
291		s := p.readStr("", true)
292		if !strings.HasPrefix(s, "+") {
293			return PatchAction{}, NewDiffError(fmt.Sprintf("Invalid Add File Line: %s", s))
294		}
295		lines = append(lines, s[1:])
296	}
297
298	newFile := strings.Join(lines, "\n")
299	return PatchAction{
300		Type:    ActionAdd,
301		NewFile: &newFile,
302		Chunks:  []Chunk{},
303	}, nil
304}
305
306// Refactored to use a matcher function for each comparison type
307func findContextCore(lines []string, context []string, start int) (int, int) {
308	if len(context) == 0 {
309		return start, 0
310	}
311
312	// Try exact match
313	if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool {
314		return a == b
315	}); idx >= 0 {
316		return idx, fuzz
317	}
318
319	// Try trimming right whitespace
320	if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool {
321		return strings.TrimRight(a, " \t") == strings.TrimRight(b, " \t")
322	}); idx >= 0 {
323		return idx, fuzz
324	}
325
326	// Try trimming all whitespace
327	if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool {
328		return strings.TrimSpace(a) == strings.TrimSpace(b)
329	}); idx >= 0 {
330		return idx, fuzz
331	}
332
333	return -1, 0
334}
335
// tryFindMatch scans lines from start for the first position at which every
// line of context matches the corresponding line under compareFunc. It
// returns that position and a fuzz penalty inferred from how permissive
// compareFunc is: 0 when it distinguishes whitespace, 1 when it ignores
// trailing whitespace only, 100 when it also ignores leading whitespace.
// It returns (-1, 0) when there is no match.
func tryFindMatch(lines []string, context []string, start int,
	compareFunc func(string, string) bool,
) (int, int) {
	// A negative start (an EOF-anchored search for a context longer than the
	// file) would index before the beginning of lines.
	if start < 0 {
		start = 0
	}

search:
	for i := start; i < len(lines) && i+len(context) <= len(lines); i++ {
		for j, want := range context {
			if !compareFunc(lines[i+j], want) {
				continue search
			}
		}

		// Classify the comparator by probing it. Leading whitespace is tested
		// first because a comparator that ignores it also ignores trailing
		// whitespace, and it must be scored at the higher penalty.
		fuzz := 0
		switch {
		case compareFunc(" a", "a"):
			fuzz = 100
		case compareFunc("a ", "a"):
			fuzz = 1
		}
		return i, fuzz
	}
	return -1, 0
}
363
364func findContext(lines []string, context []string, start int, eof bool) (int, int) {
365	if eof {
366		newIndex, fuzz := findContextCore(lines, context, len(lines)-len(context))
367		if newIndex != -1 {
368			return newIndex, fuzz
369		}
370		newIndex, fuzz = findContextCore(lines, context, start)
371		return newIndex, fuzz + 10000
372	}
373	return findContextCore(lines, context, start)
374}
375
376func peekNextSection(lines []string, initialIndex int) ([]string, []Chunk, int, bool) {
377	index := initialIndex
378	old := make([]string, 0, 32) // Preallocate for better performance
379	delLines := make([]string, 0, 8)
380	insLines := make([]string, 0, 8)
381	chunks := make([]Chunk, 0, 4)
382	mode := "keep"
383
384	// End conditions for the section
385	endSectionConditions := func(s string) bool {
386		return strings.HasPrefix(s, "@@") ||
387			strings.HasPrefix(s, "*** End Patch") ||
388			strings.HasPrefix(s, "*** Update File:") ||
389			strings.HasPrefix(s, "*** Delete File:") ||
390			strings.HasPrefix(s, "*** Add File:") ||
391			strings.HasPrefix(s, "*** End of File") ||
392			s == "***" ||
393			strings.HasPrefix(s, "***")
394	}
395
396	for index < len(lines) {
397		s := lines[index]
398		if endSectionConditions(s) {
399			break
400		}
401		index++
402		lastMode := mode
403		line := s
404
405		if len(line) > 0 {
406			switch line[0] {
407			case '+':
408				mode = "add"
409			case '-':
410				mode = "delete"
411			case ' ':
412				mode = "keep"
413			default:
414				mode = "keep"
415				line = " " + line
416			}
417		} else {
418			mode = "keep"
419			line = " "
420		}
421
422		line = line[1:]
423		if mode == "keep" && lastMode != mode {
424			if len(insLines) > 0 || len(delLines) > 0 {
425				chunks = append(chunks, Chunk{
426					OrigIndex: len(old) - len(delLines),
427					DelLines:  delLines,
428					InsLines:  insLines,
429				})
430			}
431			delLines = make([]string, 0, 8)
432			insLines = make([]string, 0, 8)
433		}
434		switch mode {
435		case "delete":
436			delLines = append(delLines, line)
437			old = append(old, line)
438		case "add":
439			insLines = append(insLines, line)
440		default:
441			old = append(old, line)
442		}
443	}
444
445	if len(insLines) > 0 || len(delLines) > 0 {
446		chunks = append(chunks, Chunk{
447			OrigIndex: len(old) - len(delLines),
448			DelLines:  delLines,
449			InsLines:  insLines,
450		})
451	}
452
453	if index < len(lines) && lines[index] == "*** End of File" {
454		index++
455		return old, chunks, index, true
456	}
457	return old, chunks, index, false
458}
459
460func TextToPatch(text string, orig map[string]string) (Patch, int, error) {
461	text = strings.TrimSpace(text)
462	lines := strings.Split(text, "\n")
463	if len(lines) < 2 || !strings.HasPrefix(lines[0], "*** Begin Patch") || lines[len(lines)-1] != "*** End Patch" {
464		return Patch{}, 0, NewDiffError("Invalid patch text")
465	}
466	parser := NewParser(orig, lines)
467	parser.index = 1
468	if err := parser.Parse(); err != nil {
469		return Patch{}, 0, err
470	}
471	return parser.patch, parser.fuzz, nil
472}
473
// IdentifyFilesNeeded returns the paths named by "*** Update File: " and
// "*** Delete File: " lines in text. Each path is reported once, in order of
// first appearance, so the result (and any error derived from it) is the
// same on every run. The returned slice is never nil.
func IdentifyFilesNeeded(text string) []string {
	prefixes := [...]string{"*** Update File: ", "*** Delete File: "}

	seen := make(map[string]struct{})
	files := []string{}
	for _, line := range strings.Split(strings.TrimSpace(text), "\n") {
		for _, prefix := range prefixes {
			if !strings.HasPrefix(line, prefix) {
				continue
			}
			path := line[len(prefix):]
			if _, dup := seen[path]; !dup {
				seen[path] = struct{}{}
				files = append(files, path)
			}
			break
		}
	}
	return files
}
494
// IdentifyFilesAdded returns the paths named by "*** Add File: " lines in
// text. Each path is reported once, in order of first appearance, so the
// result is the same on every run. The returned slice is never nil.
func IdentifyFilesAdded(text string) []string {
	const prefix = "*** Add File: "

	seen := make(map[string]struct{})
	files := []string{}
	for _, line := range strings.Split(strings.TrimSpace(text), "\n") {
		if !strings.HasPrefix(line, prefix) {
			continue
		}
		path := line[len(prefix):]
		if _, dup := seen[path]; dup {
			continue
		}
		seen[path] = struct{}{}
		files = append(files, path)
	}
	return files
}
512
513func getUpdatedFile(text string, action PatchAction, path string) (string, error) {
514	if action.Type != ActionUpdate {
515		return "", errors.New("expected UPDATE action")
516	}
517	origLines := strings.Split(text, "\n")
518	destLines := make([]string, 0, len(origLines)) // Preallocate with capacity
519	origIndex := 0
520
521	for _, chunk := range action.Chunks {
522		if chunk.OrigIndex > len(origLines) {
523			return "", NewDiffError(fmt.Sprintf("%s: chunk.orig_index %d > len(lines) %d", path, chunk.OrigIndex, len(origLines)))
524		}
525		if origIndex > chunk.OrigIndex {
526			return "", NewDiffError(fmt.Sprintf("%s: orig_index %d > chunk.orig_index %d", path, origIndex, chunk.OrigIndex))
527		}
528		destLines = append(destLines, origLines[origIndex:chunk.OrigIndex]...)
529		delta := chunk.OrigIndex - origIndex
530		origIndex += delta
531
532		if len(chunk.InsLines) > 0 {
533			destLines = append(destLines, chunk.InsLines...)
534		}
535		origIndex += len(chunk.DelLines)
536	}
537
538	destLines = append(destLines, origLines[origIndex:]...)
539	return strings.Join(destLines, "\n"), nil
540}
541
542func PatchToCommit(patch Patch, orig map[string]string) (Commit, error) {
543	commit := Commit{Changes: make(map[string]FileChange, len(patch.Actions))}
544	for pathKey, action := range patch.Actions {
545		switch action.Type {
546		case ActionDelete:
547			oldContent := orig[pathKey]
548			commit.Changes[pathKey] = FileChange{
549				Type:       ActionDelete,
550				OldContent: &oldContent,
551			}
552		case ActionAdd:
553			commit.Changes[pathKey] = FileChange{
554				Type:       ActionAdd,
555				NewContent: action.NewFile,
556			}
557		case ActionUpdate:
558			newContent, err := getUpdatedFile(orig[pathKey], action, pathKey)
559			if err != nil {
560				return Commit{}, err
561			}
562			oldContent := orig[pathKey]
563			fileChange := FileChange{
564				Type:       ActionUpdate,
565				OldContent: &oldContent,
566				NewContent: &newContent,
567			}
568			if action.MovePath != nil {
569				fileChange.MovePath = action.MovePath
570			}
571			commit.Changes[pathKey] = fileChange
572		}
573	}
574	return commit, nil
575}
576
577func AssembleChanges(orig map[string]string, updatedFiles map[string]string) Commit {
578	commit := Commit{Changes: make(map[string]FileChange, len(updatedFiles))}
579	for p, newContent := range updatedFiles {
580		oldContent, exists := orig[p]
581		if exists && oldContent == newContent {
582			continue
583		}
584
585		if exists && newContent != "" {
586			commit.Changes[p] = FileChange{
587				Type:       ActionUpdate,
588				OldContent: &oldContent,
589				NewContent: &newContent,
590			}
591		} else if newContent != "" {
592			commit.Changes[p] = FileChange{
593				Type:       ActionAdd,
594				NewContent: &newContent,
595			}
596		} else if exists {
597			commit.Changes[p] = FileChange{
598				Type:       ActionDelete,
599				OldContent: &oldContent,
600			}
601		} else {
602			return commit // Changed from panic to simply return current commit
603		}
604	}
605	return commit
606}
607
608func LoadFiles(paths []string, openFn func(string) (string, error)) (map[string]string, error) {
609	orig := make(map[string]string, len(paths))
610	for _, p := range paths {
611		content, err := openFn(p)
612		if err != nil {
613			return nil, fileError("Open", "File not found", p)
614		}
615		orig[p] = content
616	}
617	return orig, nil
618}
619
620func ApplyCommit(commit Commit, writeFn func(string, string) error, removeFn func(string) error) error {
621	for p, change := range commit.Changes {
622		switch change.Type {
623		case ActionDelete:
624			if err := removeFn(p); err != nil {
625				return err
626			}
627		case ActionAdd:
628			if change.NewContent == nil {
629				return NewDiffError(fmt.Sprintf("Add action for %s has nil new_content", p))
630			}
631			if err := writeFn(p, *change.NewContent); err != nil {
632				return err
633			}
634		case ActionUpdate:
635			if change.NewContent == nil {
636				return NewDiffError(fmt.Sprintf("Update action for %s has nil new_content", p))
637			}
638			if change.MovePath != nil {
639				if err := writeFn(*change.MovePath, *change.NewContent); err != nil {
640					return err
641				}
642				if err := removeFn(p); err != nil {
643					return err
644				}
645			} else {
646				if err := writeFn(p, *change.NewContent); err != nil {
647					return err
648				}
649			}
650		}
651	}
652	return nil
653}
654
655func ProcessPatch(text string, openFn func(string) (string, error), writeFn func(string, string) error, removeFn func(string) error) (string, error) {
656	if !strings.HasPrefix(text, "*** Begin Patch") {
657		return "", NewDiffError("Patch must start with *** Begin Patch")
658	}
659	paths := IdentifyFilesNeeded(text)
660	orig, err := LoadFiles(paths, openFn)
661	if err != nil {
662		return "", err
663	}
664
665	patch, fuzz, err := TextToPatch(text, orig)
666	if err != nil {
667		return "", err
668	}
669
670	if fuzz > 0 {
671		return "", NewDiffError(fmt.Sprintf("Patch contains fuzzy matches (fuzz level: %d)", fuzz))
672	}
673
674	commit, err := PatchToCommit(patch, orig)
675	if err != nil {
676		return "", err
677	}
678
679	if err := ApplyCommit(commit, writeFn, removeFn); err != nil {
680		return "", err
681	}
682
683	return "Patch applied successfully", nil
684}
685
// OpenFile reads the file at p with os.ReadFile and returns its content as a
// string. On failure it returns "" and the error from os.ReadFile.
func OpenFile(p string) (string, error) {
	data, err := os.ReadFile(p)
	if err != nil {
		return "", err
	}
	return string(data), nil
}
693
694func WriteFile(p string, content string) error {
695	if filepath.IsAbs(p) {
696		return NewDiffError("We do not support absolute paths.")
697	}
698
699	dir := filepath.Dir(p)
700	if dir != "." {
701		if err := os.MkdirAll(dir, 0o755); err != nil {
702			return err
703		}
704	}
705
706	return os.WriteFile(p, []byte(content), 0o644)
707}
708
// RemoveFile deletes the file at p with os.Remove and returns its error.
func RemoveFile(p string) error {
	return os.Remove(p)
}
712
713func ValidatePatch(patchText string, files map[string]string) (bool, string, error) {
714	if !strings.HasPrefix(patchText, "*** Begin Patch") {
715		return false, "Patch must start with *** Begin Patch", nil
716	}
717
718	neededFiles := IdentifyFilesNeeded(patchText)
719	for _, filePath := range neededFiles {
720		if _, exists := files[filePath]; !exists {
721			return false, fmt.Sprintf("File not found: %s", filePath), nil
722		}
723	}
724
725	patch, fuzz, err := TextToPatch(patchText, files)
726	if err != nil {
727		return false, err.Error(), nil
728	}
729
730	if fuzz > 0 {
731		return false, fmt.Sprintf("Patch contains fuzzy matches (fuzz level: %d)", fuzz), nil
732	}
733
734	_, err = PatchToCommit(patch, files)
735	if err != nil {
736		return false, err.Error(), nil
737	}
738
739	return true, "Patch is valid", nil
740}