1package tools
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/diff"
16 "github.com/charmbracelet/crush/internal/fsext"
17 "github.com/charmbracelet/crush/internal/history"
18 "github.com/charmbracelet/crush/internal/proto"
19
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/permission"
22)
23
// writeDescription is the tool description presented to the model. It is
// embedded at build time from write.md and returned as a string by Info.
//
//go:embed write.md
var writeDescription []byte
26
27type (
28 WriteParams = proto.WriteParams
29 WritePermissionsParams = proto.WritePermissionsParams
30 WriteResponseMetadata = proto.WriteResponseMetadata
31)
32
// writeTool implements the whole-file write tool. It resolves relative paths
// against workingDir, asks permissions before writing, records versions in
// files, and notifies/queries lspClients after a successful write.
type writeTool struct {
	// lspClients are notified of written files and queried for diagnostics.
	lspClients *csync.Map[string, *lsp.Client]

	// permissions is asked to approve every write before it happens.
	permissions permission.Service

	// files stores the per-session version history of written files.
	files history.Service

	// workingDir is the base for relative paths and is trimmed from paths
	// shown in diffs and permission prompts.
	workingDir string
}
39
40const WriteToolName = proto.WriteToolName
41
42func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
43 return &writeTool{
44 lspClients: lspClients,
45 permissions: permissions,
46 files: files,
47 workingDir: workingDir,
48 }
49}
50
51func (w *writeTool) Name() string {
52 return WriteToolName
53}
54
55func (w *writeTool) Info() ToolInfo {
56 return ToolInfo{
57 Name: WriteToolName,
58 Description: string(writeDescription),
59 Parameters: map[string]any{
60 "file_path": map[string]any{
61 "type": "string",
62 "description": "The path to the file to write",
63 },
64 "content": map[string]any{
65 "type": "string",
66 "description": "The content to write to the file",
67 },
68 },
69 Required: []string{"file_path", "content"},
70 }
71}
72
73func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
74 var params WriteParams
75 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
76 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
77 }
78
79 if params.FilePath == "" {
80 return NewTextErrorResponse("file_path is required"), nil
81 }
82
83 if params.Content == "" {
84 return NewTextErrorResponse("content is required"), nil
85 }
86
87 filePath := params.FilePath
88 if !filepath.IsAbs(filePath) {
89 filePath = filepath.Join(w.workingDir, filePath)
90 }
91
92 fileInfo, err := os.Stat(filePath)
93 if err == nil {
94 if fileInfo.IsDir() {
95 return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
96 }
97
98 modTime := fileInfo.ModTime()
99 lastRead := getLastReadTime(filePath)
100 if modTime.After(lastRead) {
101 return NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
102 filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
103 }
104
105 oldContent, readErr := os.ReadFile(filePath)
106 if readErr == nil && string(oldContent) == params.Content {
107 return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
108 }
109 } else if !os.IsNotExist(err) {
110 return ToolResponse{}, fmt.Errorf("error checking file: %w", err)
111 }
112
113 dir := filepath.Dir(filePath)
114 if err = os.MkdirAll(dir, 0o755); err != nil {
115 return ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
116 }
117
118 oldContent := ""
119 if fileInfo != nil && !fileInfo.IsDir() {
120 oldBytes, readErr := os.ReadFile(filePath)
121 if readErr == nil {
122 oldContent = string(oldBytes)
123 }
124 }
125
126 sessionID, messageID := GetContextValues(ctx)
127 if sessionID == "" || messageID == "" {
128 return ToolResponse{}, fmt.Errorf("session_id and message_id are required")
129 }
130
131 diff, additions, removals := diff.GenerateDiff(
132 oldContent,
133 params.Content,
134 strings.TrimPrefix(filePath, w.workingDir),
135 )
136
137 p := w.permissions.Request(
138 permission.CreatePermissionRequest{
139 SessionID: sessionID,
140 Path: fsext.PathOrPrefix(filePath, w.workingDir),
141 ToolCallID: call.ID,
142 ToolName: WriteToolName,
143 Action: "write",
144 Description: fmt.Sprintf("Create file %s", filePath),
145 Params: WritePermissionsParams{
146 FilePath: filePath,
147 OldContent: oldContent,
148 NewContent: params.Content,
149 },
150 },
151 )
152 if !p {
153 return ToolResponse{}, permission.ErrorPermissionDenied
154 }
155
156 err = os.WriteFile(filePath, []byte(params.Content), 0o644)
157 if err != nil {
158 return ToolResponse{}, fmt.Errorf("error writing file: %w", err)
159 }
160
161 // Check if file exists in history
162 file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID)
163 if err != nil {
164 _, err = w.files.Create(ctx, sessionID, filePath, oldContent)
165 if err != nil {
166 // Log error but don't fail the operation
167 return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
168 }
169 }
170 if file.Content != oldContent {
171 // User Manually changed the content store an intermediate version
172 _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent)
173 if err != nil {
174 slog.Debug("Error creating file history version", "error", err)
175 }
176 }
177 // Store the new version
178 _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content)
179 if err != nil {
180 slog.Debug("Error creating file history version", "error", err)
181 }
182
183 recordFileWrite(filePath)
184 recordFileRead(filePath)
185
186 notifyLSPs(ctx, w.lspClients, params.FilePath)
187
188 result := fmt.Sprintf("File successfully written: %s", filePath)
189 result = fmt.Sprintf("<result>\n%s\n</result>", result)
190 result += getDiagnostics(filePath, w.lspClients)
191 return WithResponseMetadata(NewTextResponse(result),
192 WriteResponseMetadata{
193 Diff: diff,
194 Additions: additions,
195 Removals: removals,
196 },
197 ), nil
198}