patch.go

  1package diff
  2
  3import (
  4	"errors"
  5	"fmt"
  6	"os"
  7	"path/filepath"
  8	"strings"
  9)
 10
// ActionType names the kind of operation a patch performs on a single file.
type ActionType string

// The action kinds carried by PatchAction and FileChange.
const (
	// ActionAdd marks a file that the patch creates.
	ActionAdd    ActionType = "add"
	// ActionDelete marks a file that the patch removes.
	ActionDelete ActionType = "delete"
	// ActionUpdate marks a file whose content the patch edits.
	ActionUpdate ActionType = "update"
)
 18
// FileChange describes the effect a Commit has on one file.
type FileChange struct {
	// Type is the kind of change.
	Type       ActionType
	// OldContent is the content before the change. The constructors in this
	// file set it for deletes and updates and leave it nil for adds.
	OldContent *string
	// NewContent is the content after the change. The constructors in this
	// file set it for adds and updates and leave it nil for deletes.
	NewContent *string
	// MovePath, when non-nil on an update, is the path ApplyCommit writes the
	// new content to instead of the original path.
	MovePath   *string
}
 25
// Commit is a set of file changes ready to be applied by ApplyCommit.
type Commit struct {
	// Changes maps a file path to the change recorded for it.
	Changes map[string]FileChange
}
 29
// Chunk is one contiguous edit inside an update: DelLines are replaced by
// InsLines starting at OrigIndex.
type Chunk struct {
	OrigIndex int      // line index of the first line in the original file
	DelLines  []string // lines to delete
	InsLines  []string // lines to insert
}
 35
// PatchAction is the parsed form of a single per-file section of a patch.
type PatchAction struct {
	// Type is the kind of action.
	Type     ActionType
	// NewFile is the full content of an added file; parseAddFile sets it.
	NewFile  *string
	// Chunks holds the edits of an update; it is empty for adds and deletes.
	Chunks   []Chunk
	// MovePath is set by Parse when an update section carries a
	// "*** Move to: " line.
	MovePath *string
}
 42
// Patch is the result of parsing a patch text.
type Patch struct {
	// Actions maps a file path to the action parsed for it.
	Actions map[string]PatchAction
}
 46
// DiffError is the error value produced by the patch parsing and applying
// code in this file. It carries only a human-readable message.
type DiffError struct {
	message string
}

// Error returns the stored message, satisfying the error interface.
func (e DiffError) Error() string {
	return e.message
}
 54
 55// Helper functions for error handling
 56func NewDiffError(message string) DiffError {
 57	return DiffError{message: message}
 58}
 59
 60func fileError(action, reason, path string) DiffError {
 61	return NewDiffError(fmt.Sprintf("%s File Error: %s: %s", action, reason, path))
 62}
 63
 64func contextError(index int, context string, isEOF bool) DiffError {
 65	prefix := "Invalid Context"
 66	if isEOF {
 67		prefix = "Invalid EOF Context"
 68	}
 69	return NewDiffError(fmt.Sprintf("%s %d:\n%s", prefix, index, context))
 70}
 71
// Parser walks the lines of a patch text and accumulates a Patch.
type Parser struct {
	// currentFiles maps a path to the current content of that file; update
	// and delete sections must refer to a key present here.
	currentFiles map[string]string
	// lines is the patch text split into lines.
	lines        []string
	// index is the position in lines of the next line to consume.
	index        int
	// patch collects the actions parsed so far.
	patch        Patch
	// fuzz accumulates the fuzz reported while locating anchors and context.
	fuzz         int
}
 79
 80func NewParser(currentFiles map[string]string, lines []string) *Parser {
 81	return &Parser{
 82		currentFiles: currentFiles,
 83		lines:        lines,
 84		index:        0,
 85		patch:        Patch{Actions: make(map[string]PatchAction, len(currentFiles))},
 86		fuzz:         0,
 87	}
 88}
 89
 90func (p *Parser) isDone(prefixes []string) bool {
 91	if p.index >= len(p.lines) {
 92		return true
 93	}
 94	if prefixes != nil {
 95		for _, prefix := range prefixes {
 96			if strings.HasPrefix(p.lines[p.index], prefix) {
 97				return true
 98			}
 99		}
100	}
101	return false
102}
103
104func (p *Parser) startsWith(prefix any) bool {
105	var prefixes []string
106	switch v := prefix.(type) {
107	case string:
108		prefixes = []string{v}
109	case []string:
110		prefixes = v
111	}
112
113	for _, pfx := range prefixes {
114		if strings.HasPrefix(p.lines[p.index], pfx) {
115			return true
116		}
117	}
118	return false
119}
120
121func (p *Parser) readStr(prefix string, returnEverything bool) string {
122	if p.index >= len(p.lines) {
123		return "" // Changed from panic to return empty string for safer operation
124	}
125	if strings.HasPrefix(p.lines[p.index], prefix) {
126		var text string
127		if returnEverything {
128			text = p.lines[p.index]
129		} else {
130			text = p.lines[p.index][len(prefix):]
131		}
132		p.index++
133		return text
134	}
135	return ""
136}
137
138func (p *Parser) Parse() error {
139	endPatchPrefixes := []string{"*** End Patch"}
140
141	for !p.isDone(endPatchPrefixes) {
142		path := p.readStr("*** Update File: ", false)
143		if path != "" {
144			if _, exists := p.patch.Actions[path]; exists {
145				return fileError("Update", "Duplicate Path", path)
146			}
147			moveTo := p.readStr("*** Move to: ", false)
148			if _, exists := p.currentFiles[path]; !exists {
149				return fileError("Update", "Missing File", path)
150			}
151			text := p.currentFiles[path]
152			action, err := p.parseUpdateFile(text)
153			if err != nil {
154				return err
155			}
156			if moveTo != "" {
157				action.MovePath = &moveTo
158			}
159			p.patch.Actions[path] = action
160			continue
161		}
162
163		path = p.readStr("*** Delete File: ", false)
164		if path != "" {
165			if _, exists := p.patch.Actions[path]; exists {
166				return fileError("Delete", "Duplicate Path", path)
167			}
168			if _, exists := p.currentFiles[path]; !exists {
169				return fileError("Delete", "Missing File", path)
170			}
171			p.patch.Actions[path] = PatchAction{Type: ActionDelete, Chunks: []Chunk{}}
172			continue
173		}
174
175		path = p.readStr("*** Add File: ", false)
176		if path != "" {
177			if _, exists := p.patch.Actions[path]; exists {
178				return fileError("Add", "Duplicate Path", path)
179			}
180			if _, exists := p.currentFiles[path]; exists {
181				return fileError("Add", "File already exists", path)
182			}
183			action, err := p.parseAddFile()
184			if err != nil {
185				return err
186			}
187			p.patch.Actions[path] = action
188			continue
189		}
190
191		return NewDiffError(fmt.Sprintf("Unknown Line: %s", p.lines[p.index]))
192	}
193
194	if !p.startsWith("*** End Patch") {
195		return NewDiffError("Missing End Patch")
196	}
197	p.index++
198
199	return nil
200}
201
202func (p *Parser) parseUpdateFile(text string) (PatchAction, error) {
203	action := PatchAction{Type: ActionUpdate, Chunks: []Chunk{}}
204	fileLines := strings.Split(text, "\n")
205	index := 0
206
207	endPrefixes := []string{
208		"*** End Patch",
209		"*** Update File:",
210		"*** Delete File:",
211		"*** Add File:",
212		"*** End of File",
213	}
214
215	for !p.isDone(endPrefixes) {
216		defStr := p.readStr("@@ ", false)
217		sectionStr := ""
218		if defStr == "" && p.index < len(p.lines) && p.lines[p.index] == "@@" {
219			sectionStr = p.lines[p.index]
220			p.index++
221		}
222		if !(defStr != "" || sectionStr != "" || index == 0) {
223			return action, NewDiffError(fmt.Sprintf("Invalid Line:\n%s", p.lines[p.index]))
224		}
225		if strings.TrimSpace(defStr) != "" {
226			found := false
227			for i := range fileLines[:index] {
228				if fileLines[i] == defStr {
229					found = true
230					break
231				}
232			}
233
234			if !found {
235				for i := index; i < len(fileLines); i++ {
236					if fileLines[i] == defStr {
237						index = i + 1
238						found = true
239						break
240					}
241				}
242			}
243
244			if !found {
245				for i := range fileLines[:index] {
246					if strings.TrimSpace(fileLines[i]) == strings.TrimSpace(defStr) {
247						found = true
248						break
249					}
250				}
251			}
252
253			if !found {
254				for i := index; i < len(fileLines); i++ {
255					if strings.TrimSpace(fileLines[i]) == strings.TrimSpace(defStr) {
256						index = i + 1
257						p.fuzz++
258						found = true
259						break
260					}
261				}
262			}
263		}
264
265		nextChunkContext, chunks, endPatchIndex, eof := peekNextSection(p.lines, p.index)
266		newIndex, fuzz := findContext(fileLines, nextChunkContext, index, eof)
267		if newIndex == -1 {
268			ctxText := strings.Join(nextChunkContext, "\n")
269			return action, contextError(index, ctxText, eof)
270		}
271		p.fuzz += fuzz
272
273		for _, ch := range chunks {
274			ch.OrigIndex += newIndex
275			action.Chunks = append(action.Chunks, ch)
276		}
277		index = newIndex + len(nextChunkContext)
278		p.index = endPatchIndex
279	}
280	return action, nil
281}
282
283func (p *Parser) parseAddFile() (PatchAction, error) {
284	lines := make([]string, 0, 16) // Preallocate space for better performance
285	endPrefixes := []string{
286		"*** End Patch",
287		"*** Update File:",
288		"*** Delete File:",
289		"*** Add File:",
290	}
291
292	for !p.isDone(endPrefixes) {
293		s := p.readStr("", true)
294		if !strings.HasPrefix(s, "+") {
295			return PatchAction{}, NewDiffError(fmt.Sprintf("Invalid Add File Line: %s", s))
296		}
297		lines = append(lines, s[1:])
298	}
299
300	newFile := strings.Join(lines, "\n")
301	return PatchAction{
302		Type:    ActionAdd,
303		NewFile: &newFile,
304		Chunks:  []Chunk{},
305	}, nil
306}
307
308// Refactored to use a matcher function for each comparison type
309func findContextCore(lines []string, context []string, start int) (int, int) {
310	if len(context) == 0 {
311		return start, 0
312	}
313
314	// Try exact match
315	if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool {
316		return a == b
317	}); idx >= 0 {
318		return idx, fuzz
319	}
320
321	// Try trimming right whitespace
322	if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool {
323		return strings.TrimRight(a, " \t") == strings.TrimRight(b, " \t")
324	}); idx >= 0 {
325		return idx, fuzz
326	}
327
328	// Try trimming all whitespace
329	if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool {
330		return strings.TrimSpace(a) == strings.TrimSpace(b)
331	}); idx >= 0 {
332		return idx, fuzz
333	}
334
335	return -1, 0
336}
337
// tryFindMatch returns the first position i >= start at which every entry of
// context equals lines[i+j] under compareFunc, together with a fuzz level
// describing how loose compareFunc is: 0 when it is whitespace-sensitive,
// 1 when it ignores trailing whitespace only, and 100 when it also ignores
// leading whitespace. It returns (-1, 0) when there is no match.
//
// A negative start is treated as 0 so callers that derive start by
// subtraction cannot cause an out-of-range index.
func tryFindMatch(lines []string, context []string, start int,
	compareFunc func(string, string) bool,
) (int, int) {
	if start < 0 {
		start = 0
	}

	// Once the window no longer fits it never fits again, so the fit test can
	// serve as a loop bound.
	for i := start; i < len(lines) && i+len(context) <= len(lines); i++ {
		matched := true
		for j, want := range context {
			if !compareFunc(lines[i+j], want) {
				matched = false
				break
			}
		}
		if !matched {
			continue
		}

		// Probe the comparer to classify it. Leading whitespace must be
		// tested first: a full-trim comparer also accepts trailing
		// whitespace and would otherwise be reported as the milder level.
		fuzz := 0
		switch {
		case compareFunc(" a", "a"):
			fuzz = 100
		case compareFunc("a ", "a"):
			fuzz = 1
		}
		return i, fuzz
	}
	return -1, 0
}
365
366func findContext(lines []string, context []string, start int, eof bool) (int, int) {
367	if eof {
368		newIndex, fuzz := findContextCore(lines, context, len(lines)-len(context))
369		if newIndex != -1 {
370			return newIndex, fuzz
371		}
372		newIndex, fuzz = findContextCore(lines, context, start)
373		return newIndex, fuzz + 10000
374	}
375	return findContextCore(lines, context, start)
376}
377
378func peekNextSection(lines []string, initialIndex int) ([]string, []Chunk, int, bool) {
379	index := initialIndex
380	old := make([]string, 0, 32) // Preallocate for better performance
381	delLines := make([]string, 0, 8)
382	insLines := make([]string, 0, 8)
383	chunks := make([]Chunk, 0, 4)
384	mode := "keep"
385
386	// End conditions for the section
387	endSectionConditions := func(s string) bool {
388		return strings.HasPrefix(s, "@@") ||
389			strings.HasPrefix(s, "*** End Patch") ||
390			strings.HasPrefix(s, "*** Update File:") ||
391			strings.HasPrefix(s, "*** Delete File:") ||
392			strings.HasPrefix(s, "*** Add File:") ||
393			strings.HasPrefix(s, "*** End of File") ||
394			s == "***" ||
395			strings.HasPrefix(s, "***")
396	}
397
398	for index < len(lines) {
399		s := lines[index]
400		if endSectionConditions(s) {
401			break
402		}
403		index++
404		lastMode := mode
405		line := s
406
407		if len(line) > 0 {
408			switch line[0] {
409			case '+':
410				mode = "add"
411			case '-':
412				mode = "delete"
413			case ' ':
414				mode = "keep"
415			default:
416				mode = "keep"
417				line = " " + line
418			}
419		} else {
420			mode = "keep"
421			line = " "
422		}
423
424		line = line[1:]
425		if mode == "keep" && lastMode != mode {
426			if len(insLines) > 0 || len(delLines) > 0 {
427				chunks = append(chunks, Chunk{
428					OrigIndex: len(old) - len(delLines),
429					DelLines:  delLines,
430					InsLines:  insLines,
431				})
432			}
433			delLines = make([]string, 0, 8)
434			insLines = make([]string, 0, 8)
435		}
436		if mode == "delete" {
437			delLines = append(delLines, line)
438			old = append(old, line)
439		} else if mode == "add" {
440			insLines = append(insLines, line)
441		} else {
442			old = append(old, line)
443		}
444	}
445
446	if len(insLines) > 0 || len(delLines) > 0 {
447		chunks = append(chunks, Chunk{
448			OrigIndex: len(old) - len(delLines),
449			DelLines:  delLines,
450			InsLines:  insLines,
451		})
452	}
453
454	if index < len(lines) && lines[index] == "*** End of File" {
455		index++
456		return old, chunks, index, true
457	}
458	return old, chunks, index, false
459}
460
461func TextToPatch(text string, orig map[string]string) (Patch, int, error) {
462	text = strings.TrimSpace(text)
463	lines := strings.Split(text, "\n")
464	if len(lines) < 2 || !strings.HasPrefix(lines[0], "*** Begin Patch") || lines[len(lines)-1] != "*** End Patch" {
465		return Patch{}, 0, NewDiffError("Invalid patch text")
466	}
467	parser := NewParser(orig, lines)
468	parser.index = 1
469	if err := parser.Parse(); err != nil {
470		return Patch{}, 0, err
471	}
472	return parser.patch, parser.fuzz, nil
473}
474
// IdentifyFilesNeeded returns the paths named by "*** Update File: " and
// "*** Delete File: " lines in text, i.e. the files that must already exist
// for the patch to apply.
//
// Each path is returned once, in the order of its first appearance, so the
// result is deterministic. The returned slice is never nil.
func IdentifyFilesNeeded(text string) []string {
	const (
		updatePrefix = "*** Update File: "
		deletePrefix = "*** Delete File: "
	)

	seen := make(map[string]struct{})
	files := make([]string, 0)
	for _, line := range strings.Split(strings.TrimSpace(text), "\n") {
		var path string
		switch {
		case strings.HasPrefix(line, updatePrefix):
			path = line[len(updatePrefix):]
		case strings.HasPrefix(line, deletePrefix):
			path = line[len(deletePrefix):]
		default:
			continue
		}
		if _, dup := seen[path]; dup {
			continue
		}
		seen[path] = struct{}{}
		files = append(files, path)
	}
	return files
}
495
// IdentifyFilesAdded returns the paths named by "*** Add File: " lines in
// text.
//
// Each path is returned once, in the order of its first appearance, so the
// result is deterministic. The returned slice is never nil.
func IdentifyFilesAdded(text string) []string {
	const addPrefix = "*** Add File: "

	seen := make(map[string]struct{})
	files := make([]string, 0)
	for _, line := range strings.Split(strings.TrimSpace(text), "\n") {
		if !strings.HasPrefix(line, addPrefix) {
			continue
		}
		path := line[len(addPrefix):]
		if _, dup := seen[path]; dup {
			continue
		}
		seen[path] = struct{}{}
		files = append(files, path)
	}
	return files
}
513
514func getUpdatedFile(text string, action PatchAction, path string) (string, error) {
515	if action.Type != ActionUpdate {
516		return "", errors.New("Expected UPDATE action")
517	}
518	origLines := strings.Split(text, "\n")
519	destLines := make([]string, 0, len(origLines)) // Preallocate with capacity
520	origIndex := 0
521
522	for _, chunk := range action.Chunks {
523		if chunk.OrigIndex > len(origLines) {
524			return "", NewDiffError(fmt.Sprintf("%s: chunk.orig_index %d > len(lines) %d", path, chunk.OrigIndex, len(origLines)))
525		}
526		if origIndex > chunk.OrigIndex {
527			return "", NewDiffError(fmt.Sprintf("%s: orig_index %d > chunk.orig_index %d", path, origIndex, chunk.OrigIndex))
528		}
529		destLines = append(destLines, origLines[origIndex:chunk.OrigIndex]...)
530		delta := chunk.OrigIndex - origIndex
531		origIndex += delta
532
533		if len(chunk.InsLines) > 0 {
534			destLines = append(destLines, chunk.InsLines...)
535		}
536		origIndex += len(chunk.DelLines)
537	}
538
539	destLines = append(destLines, origLines[origIndex:]...)
540	return strings.Join(destLines, "\n"), nil
541}
542
543func PatchToCommit(patch Patch, orig map[string]string) (Commit, error) {
544	commit := Commit{Changes: make(map[string]FileChange, len(patch.Actions))}
545	for pathKey, action := range patch.Actions {
546		if action.Type == ActionDelete {
547			oldContent := orig[pathKey]
548			commit.Changes[pathKey] = FileChange{
549				Type:       ActionDelete,
550				OldContent: &oldContent,
551			}
552		} else if action.Type == ActionAdd {
553			commit.Changes[pathKey] = FileChange{
554				Type:       ActionAdd,
555				NewContent: action.NewFile,
556			}
557		} else if action.Type == ActionUpdate {
558			newContent, err := getUpdatedFile(orig[pathKey], action, pathKey)
559			if err != nil {
560				return Commit{}, err
561			}
562			oldContent := orig[pathKey]
563			fileChange := FileChange{
564				Type:       ActionUpdate,
565				OldContent: &oldContent,
566				NewContent: &newContent,
567			}
568			if action.MovePath != nil {
569				fileChange.MovePath = action.MovePath
570			}
571			commit.Changes[pathKey] = fileChange
572		}
573	}
574	return commit, nil
575}
576
577func AssembleChanges(orig map[string]string, updatedFiles map[string]string) Commit {
578	commit := Commit{Changes: make(map[string]FileChange, len(updatedFiles))}
579	for p, newContent := range updatedFiles {
580		oldContent, exists := orig[p]
581		if exists && oldContent == newContent {
582			continue
583		}
584
585		if exists && newContent != "" {
586			commit.Changes[p] = FileChange{
587				Type:       ActionUpdate,
588				OldContent: &oldContent,
589				NewContent: &newContent,
590			}
591		} else if newContent != "" {
592			commit.Changes[p] = FileChange{
593				Type:       ActionAdd,
594				NewContent: &newContent,
595			}
596		} else if exists {
597			commit.Changes[p] = FileChange{
598				Type:       ActionDelete,
599				OldContent: &oldContent,
600			}
601		} else {
602			return commit // Changed from panic to simply return current commit
603		}
604	}
605	return commit
606}
607
608func LoadFiles(paths []string, openFn func(string) (string, error)) (map[string]string, error) {
609	orig := make(map[string]string, len(paths))
610	for _, p := range paths {
611		content, err := openFn(p)
612		if err != nil {
613			return nil, fileError("Open", "File not found", p)
614		}
615		orig[p] = content
616	}
617	return orig, nil
618}
619
620func ApplyCommit(commit Commit, writeFn func(string, string) error, removeFn func(string) error) error {
621	for p, change := range commit.Changes {
622		if change.Type == ActionDelete {
623			if err := removeFn(p); err != nil {
624				return err
625			}
626		} else if change.Type == ActionAdd {
627			if change.NewContent == nil {
628				return NewDiffError(fmt.Sprintf("Add action for %s has nil new_content", p))
629			}
630			if err := writeFn(p, *change.NewContent); err != nil {
631				return err
632			}
633		} else if change.Type == ActionUpdate {
634			if change.NewContent == nil {
635				return NewDiffError(fmt.Sprintf("Update action for %s has nil new_content", p))
636			}
637			if change.MovePath != nil {
638				if err := writeFn(*change.MovePath, *change.NewContent); err != nil {
639					return err
640				}
641				if err := removeFn(p); err != nil {
642					return err
643				}
644			} else {
645				if err := writeFn(p, *change.NewContent); err != nil {
646					return err
647				}
648			}
649		}
650	}
651	return nil
652}
653
654func ProcessPatch(text string, openFn func(string) (string, error), writeFn func(string, string) error, removeFn func(string) error) (string, error) {
655	if !strings.HasPrefix(text, "*** Begin Patch") {
656		return "", NewDiffError("Patch must start with *** Begin Patch")
657	}
658	paths := IdentifyFilesNeeded(text)
659	orig, err := LoadFiles(paths, openFn)
660	if err != nil {
661		return "", err
662	}
663
664	patch, fuzz, err := TextToPatch(text, orig)
665	if err != nil {
666		return "", err
667	}
668
669	if fuzz > 0 {
670		return "", NewDiffError(fmt.Sprintf("Patch contains fuzzy matches (fuzz level: %d)", fuzz))
671	}
672
673	commit, err := PatchToCommit(patch, orig)
674	if err != nil {
675		return "", err
676	}
677
678	if err := ApplyCommit(commit, writeFn, removeFn); err != nil {
679		return "", err
680	}
681
682	return "Patch applied successfully", nil
683}
684
// OpenFile reads the file at p and returns its content as a string. On a read
// error it returns an empty string and that error.
func OpenFile(p string) (string, error) {
	raw, readErr := os.ReadFile(p)
	if readErr == nil {
		return string(raw), nil
	}
	return "", readErr
}
692
693func WriteFile(p string, content string) error {
694	if filepath.IsAbs(p) {
695		return NewDiffError("We do not support absolute paths.")
696	}
697
698	dir := filepath.Dir(p)
699	if dir != "." {
700		if err := os.MkdirAll(dir, 0o755); err != nil {
701			return err
702		}
703	}
704
705	return os.WriteFile(p, []byte(content), 0o644)
706}
707
// RemoveFile deletes the file at p and returns the error from os.Remove, if
// any.
func RemoveFile(p string) error {
	removeErr := os.Remove(p)
	return removeErr
}
711
712func ValidatePatch(patchText string, files map[string]string) (bool, string, error) {
713	if !strings.HasPrefix(patchText, "*** Begin Patch") {
714		return false, "Patch must start with *** Begin Patch", nil
715	}
716
717	neededFiles := IdentifyFilesNeeded(patchText)
718	for _, filePath := range neededFiles {
719		if _, exists := files[filePath]; !exists {
720			return false, fmt.Sprintf("File not found: %s", filePath), nil
721		}
722	}
723
724	patch, fuzz, err := TextToPatch(patchText, files)
725	if err != nil {
726		return false, err.Error(), nil
727	}
728
729	if fuzz > 0 {
730		return false, fmt.Sprintf("Patch contains fuzzy matches (fuzz level: %d)", fuzz), nil
731	}
732
733	_, err = PatchToCommit(patch, files)
734	if err != nil {
735		return false, err.Error(), nil
736	}
737
738	return true, "Patch is valid", nil
739}