1package tools
2
3import (
4 "context"
5 "os"
6 "path/filepath"
7 "testing"
8
9 "github.com/charmbracelet/crush/internal/csync"
10 "github.com/charmbracelet/crush/internal/filetracker"
11 "github.com/charmbracelet/crush/internal/history"
12 "github.com/charmbracelet/crush/internal/lsp"
13 "github.com/charmbracelet/crush/internal/permission"
14 "github.com/charmbracelet/crush/internal/pubsub"
15 "github.com/stretchr/testify/require"
16)
17
18type mockPermissionService struct {
19 *pubsub.Broker[permission.PermissionRequest]
20}
21
22func (m *mockPermissionService) Request(req permission.CreatePermissionRequest) bool {
23 return true
24}
25
26func (m *mockPermissionService) Grant(req permission.PermissionRequest) {}
27
28func (m *mockPermissionService) Deny(req permission.PermissionRequest) {}
29
30func (m *mockPermissionService) GrantPersistent(req permission.PermissionRequest) {}
31
32func (m *mockPermissionService) AutoApproveSession(sessionID string) {}
33
34func (m *mockPermissionService) SetSkipRequests(skip bool) {}
35
36func (m *mockPermissionService) SkipRequests() bool {
37 return false
38}
39
40func (m *mockPermissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[permission.PermissionNotification] {
41 return make(<-chan pubsub.Event[permission.PermissionNotification])
42}
43
44type mockHistoryService struct {
45 *pubsub.Broker[history.File]
46}
47
48func (m *mockHistoryService) Create(ctx context.Context, sessionID, path, content string) (history.File, error) {
49 return history.File{Path: path, Content: content}, nil
50}
51
52func (m *mockHistoryService) CreateVersion(ctx context.Context, sessionID, path, content string) (history.File, error) {
53 return history.File{}, nil
54}
55
56func (m *mockHistoryService) GetByPathAndSession(ctx context.Context, path, sessionID string) (history.File, error) {
57 return history.File{Path: path, Content: ""}, nil
58}
59
60func (m *mockHistoryService) Get(ctx context.Context, id string) (history.File, error) {
61 return history.File{}, nil
62}
63
64func (m *mockHistoryService) ListBySession(ctx context.Context, sessionID string) ([]history.File, error) {
65 return nil, nil
66}
67
68func (m *mockHistoryService) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]history.File, error) {
69 return nil, nil
70}
71
72func (m *mockHistoryService) Delete(ctx context.Context, id string) error {
73 return nil
74}
75
76func (m *mockHistoryService) DeleteSessionFiles(ctx context.Context, sessionID string) error {
77 return nil
78}
79
80func TestApplyEditToContentPartialSuccess(t *testing.T) {
81 t.Parallel()
82
83 content := "line 1\nline 2\nline 3\n"
84
85 // Test successful edit.
86 newContent, err := applyEditToContent(content, MultiEditOperation{
87 OldString: "line 1",
88 NewString: "LINE 1",
89 })
90 require.NoError(t, err)
91 require.Contains(t, newContent, "LINE 1")
92 require.Contains(t, newContent, "line 2")
93
94 // Test failed edit (string not found).
95 _, err = applyEditToContent(content, MultiEditOperation{
96 OldString: "line 99",
97 NewString: "LINE 99",
98 })
99 require.Error(t, err)
100 require.Contains(t, err.Error(), "not found")
101}
102
// TestMultiEditSequentialApplication applies a batch of edits one after
// another with applyEditToContent. Failures are recorded instead of aborting
// the batch. The test checks that a failing edit in the middle neither stops
// the later edits nor alters the content.
//
// NOTE(review): the loop below applies the edits by hand; the multiedit tool
// built from the mocks is only constructed, never run. Driving the tool
// itself would keep this test in sync with the real implementation.
func TestMultiEditSequentialApplication(t *testing.T) {
	t.Parallel()

	tmpDir := t.TempDir()
	testFile := filepath.Join(tmpDir, "test.txt")

	// Create test file.
	content := "line 1\nline 2\nline 3\nline 4\n"
	require.NoError(t, os.WriteFile(testFile, []byte(content), 0o644))

	// Mock components.
	lspClients := csync.NewMap[string, *lsp.Client]()
	permissions := &mockPermissionService{Broker: pubsub.NewBroker[permission.PermissionRequest]()}
	files := &mockHistoryService{Broker: pubsub.NewBroker[history.File]()}

	// Construction with the mock dependencies must produce a usable tool.
	tool := NewMultiEditTool(lspClients, permissions, files, tmpDir)
	require.NotNil(t, tool)

	// Simulate reading the file first.
	filetracker.RecordRead(testFile)

	// Manually test the sequential application logic.
	currentContent := content

	// Apply edits sequentially, tracking failures.
	edits := []MultiEditOperation{
		{OldString: "line 1", NewString: "LINE 1"},   // Should succeed
		{OldString: "line 99", NewString: "LINE 99"}, // Should fail - doesn't exist
		{OldString: "line 3", NewString: "LINE 3"},   // Should succeed
		{OldString: "line 2", NewString: "LINE 2"},   // Should succeed - still exists
	}

	var failedEdits []FailedEdit
	successCount := 0

	for i, edit := range edits {
		// A failed edit is recorded with its 1-based position and skipped,
		// leaving currentContent untouched for the next edit.
		newContent, err := applyEditToContent(currentContent, edit)
		if err != nil {
			failedEdits = append(failedEdits, FailedEdit{
				Index: i + 1,
				Error: err.Error(),
				Edit:  edit,
			})
			continue
		}
		currentContent = newContent
		successCount++
	}

	// Verify results.
	require.Equal(t, 3, successCount, "Expected 3 successful edits")
	require.Len(t, failedEdits, 1, "Expected 1 failed edit")

	// Check failed edit details.
	require.Equal(t, 2, failedEdits[0].Index)
	require.Contains(t, failedEdits[0].Error, "not found")

	// Verify content changes.
	require.Contains(t, currentContent, "LINE 1")
	require.Contains(t, currentContent, "LINE 2")
	require.Contains(t, currentContent, "LINE 3")
	require.Contains(t, currentContent, "line 4") // Original unchanged
	require.NotContains(t, currentContent, "LINE 99")
}
168
// TestMultiEditAllEditsSucceed applies a batch of edits that all match and
// checks that every replacement text ends up in the final content.
func TestMultiEditAllEditsSucceed(t *testing.T) {
	t.Parallel()

	content := "line 1\nline 2\nline 3\n"

	edits := []MultiEditOperation{
		{OldString: "line 1", NewString: "LINE 1"},
		{OldString: "line 2", NewString: "LINE 2"},
		{OldString: "line 3", NewString: "LINE 3"},
	}

	currentContent := content
	successCount := 0

	for i, edit := range edits {
		newContent, err := applyEditToContent(currentContent, edit)
		// Use require like the rest of the file, and name the failing edit so
		// a regression points straight at the offending operation.
		require.NoError(t, err, "edit %d (%q) should apply", i+1, edit.OldString)
		currentContent = newContent
		successCount++
	}

	require.Equal(t, len(edits), successCount)
	require.Contains(t, currentContent, "LINE 1")
	require.Contains(t, currentContent, "LINE 2")
	require.Contains(t, currentContent, "LINE 3")
}
197
// TestMultiEditAllEditsFail applies a batch in which no edit matches. It
// checks that each failure is recorded with its 1-based index and a
// "not found" message, and that the content is left untouched.
func TestMultiEditAllEditsFail(t *testing.T) {
	t.Parallel()

	content := "line 1\nline 2\n"

	edits := []MultiEditOperation{
		{OldString: "line 99", NewString: "LINE 99"},
		{OldString: "line 100", NewString: "LINE 100"},
	}

	currentContent := content
	var failedEdits []FailedEdit

	for i, edit := range edits {
		newContent, err := applyEditToContent(currentContent, edit)
		if err != nil {
			failedEdits = append(failedEdits, FailedEdit{
				Index: i + 1,
				Error: err.Error(),
				Edit:  edit,
			})
			continue
		}
		currentContent = newContent
	}

	require.Len(t, failedEdits, 2)
	// Each failure must be reported at its own position with the same
	// "not found" diagnostic the single-edit test expects.
	for i, failed := range failedEdits {
		require.Equal(t, i+1, failed.Index)
		require.Contains(t, failed.Error, "not found")
	}
	require.Equal(t, content, currentContent, "Content should be unchanged")
}