1package tools
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/diff"
16 "github.com/charmbracelet/crush/internal/fsext"
17 "github.com/charmbracelet/crush/internal/history"
18
19 "github.com/charmbracelet/crush/internal/lsp"
20 "github.com/charmbracelet/crush/internal/permission"
21)
22
// writeDescription holds the contents of write.md, embedded at build time.
// It is returned as the tool Description by (*writeTool).Info.
//
//go:embed write.md
var writeDescription []byte
25
// WriteParams is the JSON input accepted by the write tool.
type WriteParams struct {
	// FilePath is the file to write; a relative path is resolved against
	// the tool's working directory.
	FilePath string `json:"file_path"`
	// Content is the complete new file content; the file is replaced, not
	// appended to.
	Content  string `json:"content"`
}
30
// WritePermissionsParams is attached as the Params of the permission request
// issued before a write, so the approver can see what will change.
type WritePermissionsParams struct {
	// FilePath is the resolved (absolute) path that will be written.
	FilePath   string `json:"file_path"`
	// OldContent is the file's current content, or empty for a new file.
	OldContent string `json:"old_content,omitempty"`
	// NewContent is the content that will be written.
	NewContent string `json:"new_content,omitempty"`
}
36
// writeTool implements the "write" tool, which creates or overwrites a file
// with caller-supplied content.
type writeTool struct {
	// lspClients are notified about the written file and queried for
	// diagnostics after a successful write.
	lspClients  *csync.Map[string, *lsp.Client]
	// permissions is consulted before the file is written.
	permissions permission.Service
	// files records per-session history versions of written files.
	files       history.Service
	// workingDir is the base for relative paths and is trimmed from the
	// path shown in the generated diff.
	workingDir  string
}
43
// WriteResponseMetadata is attached to a successful write response.
type WriteResponseMetadata struct {
	// Diff is the diff between the old and new content, as produced by
	// diff.GenerateDiff.
	Diff      string `json:"diff"`
	// Additions and Removals are the counts reported by diff.GenerateDiff
	// (presumably line counts — confirm in the diff package).
	Additions int    `json:"additions"`
	Removals  int    `json:"removals"`
}
49
// WriteToolName is the name the write tool reports in Name and Info, and the
// ToolName used in its permission requests.
const WriteToolName = "write"
51
52func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
53 return &writeTool{
54 lspClients: lspClients,
55 permissions: permissions,
56 files: files,
57 workingDir: workingDir,
58 }
59}
60
61func (w *writeTool) Name() string {
62 return WriteToolName
63}
64
65func (w *writeTool) Info() ToolInfo {
66 return ToolInfo{
67 Name: WriteToolName,
68 Description: string(writeDescription),
69 Parameters: map[string]any{
70 "file_path": map[string]any{
71 "type": "string",
72 "description": "The path to the file to write",
73 },
74 "content": map[string]any{
75 "type": "string",
76 "description": "The content to write to the file",
77 },
78 },
79 Required: []string{"file_path", "content"},
80 }
81}
82
83func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
84 var params WriteParams
85 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
86 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
87 }
88
89 if params.FilePath == "" {
90 return NewTextErrorResponse("file_path is required"), nil
91 }
92
93 if params.Content == "" {
94 return NewTextErrorResponse("content is required"), nil
95 }
96
97 filePath := params.FilePath
98 if !filepath.IsAbs(filePath) {
99 filePath = filepath.Join(w.workingDir, filePath)
100 }
101
102 fileInfo, err := os.Stat(filePath)
103 if err == nil {
104 if fileInfo.IsDir() {
105 return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
106 }
107
108 modTime := fileInfo.ModTime()
109 lastRead := getLastReadTime(filePath)
110 if modTime.After(lastRead) {
111 return NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
112 filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
113 }
114
115 oldContent, readErr := os.ReadFile(filePath)
116 if readErr == nil && string(oldContent) == params.Content {
117 return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
118 }
119 } else if !os.IsNotExist(err) {
120 return ToolResponse{}, fmt.Errorf("error checking file: %w", err)
121 }
122
123 dir := filepath.Dir(filePath)
124 if err = os.MkdirAll(dir, 0o755); err != nil {
125 return ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
126 }
127
128 oldContent := ""
129 if fileInfo != nil && !fileInfo.IsDir() {
130 oldBytes, readErr := os.ReadFile(filePath)
131 if readErr == nil {
132 oldContent = string(oldBytes)
133 }
134 }
135
136 sessionID, messageID := GetContextValues(ctx)
137 if sessionID == "" || messageID == "" {
138 return ToolResponse{}, fmt.Errorf("session_id and message_id are required")
139 }
140
141 diff, additions, removals := diff.GenerateDiff(
142 oldContent,
143 params.Content,
144 strings.TrimPrefix(filePath, w.workingDir),
145 )
146
147 p := w.permissions.Request(
148 permission.CreatePermissionRequest{
149 SessionID: sessionID,
150 Path: fsext.PathOrPrefix(filePath, w.workingDir),
151 ToolCallID: call.ID,
152 ToolName: WriteToolName,
153 Action: "write",
154 Description: fmt.Sprintf("Create file %s", filePath),
155 Params: WritePermissionsParams{
156 FilePath: filePath,
157 OldContent: oldContent,
158 NewContent: params.Content,
159 },
160 },
161 )
162 if !p {
163 return ToolResponse{}, permission.ErrorPermissionDenied
164 }
165
166 err = os.WriteFile(filePath, []byte(params.Content), 0o644)
167 if err != nil {
168 return ToolResponse{}, fmt.Errorf("error writing file: %w", err)
169 }
170
171 // Check if file exists in history
172 file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID)
173 if err != nil {
174 _, err = w.files.Create(ctx, sessionID, filePath, oldContent)
175 if err != nil {
176 // Log error but don't fail the operation
177 return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
178 }
179 }
180 if file.Content != oldContent {
181 // User Manually changed the content store an intermediate version
182 _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent)
183 if err != nil {
184 slog.Debug("Error creating file history version", "error", err)
185 }
186 }
187 // Store the new version
188 _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content)
189 if err != nil {
190 slog.Debug("Error creating file history version", "error", err)
191 }
192
193 recordFileWrite(filePath)
194 recordFileRead(filePath)
195
196 notifyLSPs(ctx, w.lspClients, params.FilePath)
197
198 result := fmt.Sprintf("File successfully written: %s", filePath)
199 result = fmt.Sprintf("<result>\n%s\n</result>", result)
200 result += getDiagnostics(filePath, w.lspClients)
201 return WithResponseMetadata(NewTextResponse(result),
202 WriteResponseMetadata{
203 Diff: diff,
204 Additions: additions,
205 Removals: removals,
206 },
207 ), nil
208}