1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "strings"
11 "time"
12
13 "charm.land/fantasy"
14 "github.com/charmbracelet/crush/internal/diff"
15 "github.com/charmbracelet/crush/internal/filepathext"
16 "github.com/charmbracelet/crush/internal/filetracker"
17 "github.com/charmbracelet/crush/internal/fsext"
18 "github.com/charmbracelet/crush/internal/history"
19
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/permission"
22)
23
// writeDescription is the description passed to fantasy.NewAgentTool for the
// write tool. Its text is embedded from write.md at build time.
//
//go:embed write.md
var writeDescription string
26
// WriteParams holds the arguments supplied to the write tool.
//
// FilePath is the file to write. It is joined with the tool's working
// directory via filepathext.SmartJoin before use. Content is written to the
// file in full, replacing any existing contents.
type WriteParams struct {
	FilePath string `json:"file_path" description:"The path to the file to write"`
	Content  string `json:"content" description:"The content to write to the file"`
}
31
// WritePermissionsParams is attached as Params to the permission request the
// write tool issues before touching the file.
//
// FilePath is the resolved path that will be written. OldContent is the file's
// current content (empty when the file does not exist yet). NewContent is the
// content that will be written.
type WritePermissionsParams struct {
	FilePath   string `json:"file_path"`
	OldContent string `json:"old_content,omitempty"`
	NewContent string `json:"new_content,omitempty"`
}
37
// WriteResponseMetadata is attached to the write tool's response, both on
// success and when permission is denied.
//
// Diff, Additions and Removals are the three values returned by
// diff.GenerateDiff for the old and new content.
type WriteResponseMetadata struct {
	Diff      string `json:"diff"`
	Additions int    `json:"additions"`
	Removals  int    `json:"removals"`
}
43
// WriteToolName is the name the write tool is registered under. It is also
// used as the ToolName of the permission requests the tool issues.
const WriteToolName = "write"
45
46func NewWriteTool(
47 lspManager *lsp.Manager,
48 permissions permission.Service,
49 files history.Service,
50 filetracker filetracker.Service,
51 workingDir string,
52) fantasy.AgentTool {
53 return fantasy.NewAgentTool(
54 WriteToolName,
55 writeDescription,
56 func(ctx context.Context, params WriteParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
57 if params.FilePath == "" {
58 return fantasy.NewTextErrorResponse("file_path is required"), nil
59 }
60
61 sessionID := GetSessionFromContext(ctx)
62 if sessionID == "" {
63 return fantasy.ToolResponse{}, fmt.Errorf("session_id is required")
64 }
65
66 filePath := filepathext.SmartJoin(workingDir, params.FilePath)
67
68 fileInfo, err := os.Stat(filePath)
69 if err == nil {
70 if fileInfo.IsDir() {
71 return fantasy.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
72 }
73
74 modTime := fileInfo.ModTime().Truncate(time.Second)
75 lastRead := filetracker.LastReadTime(ctx, sessionID, filePath)
76 if modTime.After(lastRead) {
77 return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
78 filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
79 }
80
81 oldContent, readErr := os.ReadFile(filePath)
82 if readErr == nil && string(oldContent) == params.Content {
83 return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
84 }
85 } else if !os.IsNotExist(err) {
86 return fantasy.ToolResponse{}, fmt.Errorf("error checking file: %w", err)
87 }
88
89 dir := filepath.Dir(filePath)
90 if err = os.MkdirAll(dir, 0o755); err != nil {
91 return fantasy.ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
92 }
93
94 oldContent := ""
95 if fileInfo != nil && !fileInfo.IsDir() {
96 oldBytes, readErr := os.ReadFile(filePath)
97 if readErr == nil {
98 oldContent = string(oldBytes)
99 }
100 }
101
102 diff, additions, removals := diff.GenerateDiff(
103 oldContent,
104 params.Content,
105 strings.TrimPrefix(filePath, workingDir),
106 )
107
108 p, err := permissions.Request(
109 ctx,
110 permission.CreatePermissionRequest{
111 SessionID: sessionID,
112 Path: fsext.PathOrPrefix(filePath, workingDir),
113 ToolCallID: call.ID,
114 ToolName: WriteToolName,
115 Action: "write",
116 Description: fmt.Sprintf("Create file %s", filePath),
117 Params: WritePermissionsParams{
118 FilePath: filePath,
119 OldContent: oldContent,
120 NewContent: params.Content,
121 },
122 },
123 )
124 if err != nil {
125 return fantasy.ToolResponse{}, err
126 }
127 if !p {
128 resp := NewPermissionDeniedResponse()
129 resp = fantasy.WithResponseMetadata(resp, WriteResponseMetadata{
130 Diff: diff,
131 Additions: additions,
132 Removals: removals,
133 })
134 return resp, nil
135 }
136
137 err = os.WriteFile(filePath, []byte(params.Content), 0o644)
138 if err != nil {
139 return fantasy.ToolResponse{}, fmt.Errorf("error writing file: %w", err)
140 }
141
142 // Check if file exists in history
143 file, err := files.GetByPathAndSession(ctx, filePath, sessionID)
144 if err != nil {
145 _, err = files.Create(ctx, sessionID, filePath, oldContent)
146 if err != nil {
147 // Log error but don't fail the operation
148 return fantasy.ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
149 }
150 }
151 if file.Content != oldContent {
152 // User manually changed the content; store an intermediate version
153 _, err = files.CreateVersion(ctx, sessionID, filePath, oldContent)
154 if err != nil {
155 slog.Error("Error creating file history version", "error", err)
156 }
157 }
158 // Store the new version
159 _, err = files.CreateVersion(ctx, sessionID, filePath, params.Content)
160 if err != nil {
161 slog.Error("Error creating file history version", "error", err)
162 }
163
164 filetracker.RecordRead(ctx, sessionID, filePath)
165
166 notifyLSPs(ctx, lspManager, params.FilePath)
167
168 result := fmt.Sprintf("File successfully written: %s", filePath)
169 result = fmt.Sprintf("<result>\n%s\n</result>", result)
170 result += getDiagnostics(filePath, lspManager)
171 return fantasy.WithResponseMetadata(
172 fantasy.NewTextResponse(result),
173 WriteResponseMetadata{
174 Diff: diff,
175 Additions: additions,
176 Removals: removals,
177 },
178 ), nil
179 },
180 )
181}