1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "strings"
11 "time"
12
13 "charm.land/fantasy"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/diff"
16 "github.com/charmbracelet/crush/internal/filepathext"
17 "github.com/charmbracelet/crush/internal/filetracker"
18 "github.com/charmbracelet/crush/internal/fsext"
19 "github.com/charmbracelet/crush/internal/history"
20
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/permission"
23)
24
// writeDescription holds the contents of write.md, embedded at build time.
// NewWriteTool passes it (as a string) to fantasy.NewAgentTool as the tool's
// description.
//
//go:embed write.md
var writeDescription []byte
27
// WriteParams are the arguments of a write tool call (presumably decoded from
// the call's JSON arguments by fantasy; the json tags name the fields and the
// description tags describe them).
type WriteParams struct {
	// FilePath is the target file. NewWriteTool resolves it against the tool's
	// working directory with filepathext.SmartJoin; it must be non-empty.
	FilePath string `json:"file_path" description:"The path to the file to write"`
	// Content is the complete new file content. NewWriteTool rejects an empty
	// value.
	Content string `json:"content" description:"The content to write to the file"`
}
32
// WritePermissionsParams is attached as Params to the permission request the
// write tool issues before touching the file.
type WritePermissionsParams struct {
	// FilePath is the resolved path that will be written.
	FilePath string `json:"file_path"`
	// OldContent is the file's current content; empty (and omitted from JSON)
	// when the file does not exist yet.
	OldContent string `json:"old_content,omitempty"`
	// NewContent is the content the tool is asking to write.
	NewContent string `json:"new_content,omitempty"`
}
38
// writeTool groups the dependencies of the write tool: LSP clients, the
// permission service, the file-history service and the working directory.
//
// NOTE(review): NewWriteTool captures these values in a closure and does not
// construct this type; confirm nothing else in the package uses it before
// removing it.
type writeTool struct {
	lspClients  *csync.Map[string, *lsp.Client]
	permissions permission.Service
	files       history.Service
	workingDir  string
}
45
// WriteResponseMetadata is attached to a successful write tool response.
type WriteResponseMetadata struct {
	// Diff is the text produced by diff.GenerateDiff for the old versus the new
	// file content.
	Diff string `json:"diff"`
	// Additions and Removals are the counts returned alongside Diff by
	// diff.GenerateDiff (presumably added/removed lines — confirm in package diff).
	Additions int `json:"additions"`
	Removals  int `json:"removals"`
}
51
// WriteToolName is the name the write tool is registered under; it is also
// reported as the ToolName of the tool's permission requests.
const WriteToolName = "write"
53
54func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) fantasy.AgentTool {
55 return fantasy.NewAgentTool(
56 WriteToolName,
57 string(writeDescription),
58 func(ctx context.Context, params WriteParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
59 if params.FilePath == "" {
60 return fantasy.NewTextErrorResponse("file_path is required"), nil
61 }
62
63 if params.Content == "" {
64 return fantasy.NewTextErrorResponse("content is required"), nil
65 }
66
67 filePath := filepathext.SmartJoin(workingDir, params.FilePath)
68
69 fileInfo, err := os.Stat(filePath)
70 if err == nil {
71 if fileInfo.IsDir() {
72 return fantasy.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
73 }
74
75 modTime := fileInfo.ModTime()
76 lastRead := filetracker.LastReadTime(filePath)
77 if modTime.After(lastRead) {
78 return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
79 filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
80 }
81
82 oldContent, readErr := os.ReadFile(filePath)
83 if readErr == nil && string(oldContent) == params.Content {
84 return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
85 }
86 } else if !os.IsNotExist(err) {
87 return fantasy.ToolResponse{}, fmt.Errorf("error checking file: %w", err)
88 }
89
90 dir := filepath.Dir(filePath)
91 if err = os.MkdirAll(dir, 0o755); err != nil {
92 return fantasy.ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
93 }
94
95 oldContent := ""
96 if fileInfo != nil && !fileInfo.IsDir() {
97 oldBytes, readErr := os.ReadFile(filePath)
98 if readErr == nil {
99 oldContent = string(oldBytes)
100 }
101 }
102
103 sessionID := GetSessionFromContext(ctx)
104 if sessionID == "" {
105 return fantasy.ToolResponse{}, fmt.Errorf("session_id is required")
106 }
107
108 diff, additions, removals := diff.GenerateDiff(
109 oldContent,
110 params.Content,
111 strings.TrimPrefix(filePath, workingDir),
112 )
113
114 p, err := permissions.Request(ctx,
115 permission.CreatePermissionRequest{
116 SessionID: sessionID,
117 Path: fsext.PathOrPrefix(filePath, workingDir),
118 ToolCallID: call.ID,
119 ToolName: WriteToolName,
120 Action: "write",
121 Description: fmt.Sprintf("Create file %s", filePath),
122 Params: WritePermissionsParams{
123 FilePath: filePath,
124 OldContent: oldContent,
125 NewContent: params.Content,
126 },
127 },
128 )
129 if err != nil {
130 return fantasy.ToolResponse{}, err
131 }
132 if !p {
133 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
134 }
135
136 err = os.WriteFile(filePath, []byte(params.Content), 0o644)
137 if err != nil {
138 return fantasy.ToolResponse{}, fmt.Errorf("error writing file: %w", err)
139 }
140
141 // Check if file exists in history
142 file, err := files.GetByPathAndSession(ctx, filePath, sessionID)
143 if err != nil {
144 _, err = files.Create(ctx, sessionID, filePath, oldContent)
145 if err != nil {
146 // Log error but don't fail the operation
147 return fantasy.ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
148 }
149 }
150 if file.Content != oldContent {
151 // User Manually changed the content store an intermediate version
152 _, err = files.CreateVersion(ctx, sessionID, filePath, oldContent)
153 if err != nil {
154 slog.Error("Error creating file history version", "error", err)
155 }
156 }
157 // Store the new version
158 _, err = files.CreateVersion(ctx, sessionID, filePath, params.Content)
159 if err != nil {
160 slog.Error("Error creating file history version", "error", err)
161 }
162
163 filetracker.RecordWrite(filePath)
164 filetracker.RecordRead(filePath)
165
166 notifyLSPs(ctx, lspClients, params.FilePath)
167
168 result := fmt.Sprintf("File successfully written: %s", filePath)
169 result = fmt.Sprintf("<result>\n%s\n</result>", result)
170 result += getDiagnostics(filePath, lspClients)
171 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(result),
172 WriteResponseMetadata{
173 Diff: diff,
174 Additions: additions,
175 Removals: removals,
176 },
177 ), nil
178 })
179}