1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "strings"
11 "time"
12
13 "charm.land/fantasy"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/diff"
16 "github.com/charmbracelet/crush/internal/filepathext"
17 "github.com/charmbracelet/crush/internal/filetracker"
18 "github.com/charmbracelet/crush/internal/fsext"
19 "github.com/charmbracelet/crush/internal/history"
20
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/permission"
23)
24
// writeDescription holds the contents of write.md, embedded into the binary at
// build time. NewWriteTool passes it to fantasy.NewAgentTool as the tool's
// description.
//
//go:embed write.md
var writeDescription []byte
27
// WriteParams are the arguments accepted by the write tool. NewWriteTool
// rejects a call in which either field is empty.
type WriteParams struct {
	// FilePath is the file to write. NewWriteTool resolves it against its
	// workingDir with filepathext.SmartJoin.
	FilePath string `json:"file_path" description:"The path to the file to write"`
	// Content is the complete new content of the file.
	Content string `json:"content" description:"The content to write to the file"`
}
32
// WritePermissionsParams is sent as the Params of the permission request the
// write tool issues before writing to disk (presumably so the approver can
// review the change).
type WritePermissionsParams struct {
	// FilePath is the resolved path that will be written.
	FilePath string `json:"file_path"`
	// OldContent is the file's content before the write; it is "" for a new
	// file, in which case omitempty drops it from the JSON.
	OldContent string `json:"old_content,omitempty"`
	// NewContent is the content that will be written.
	NewContent string `json:"new_content,omitempty"`
}
38
// WriteResponseMetadata is attached to a successful write tool response.
type WriteResponseMetadata struct {
	// Diff is the text produced by diff.GenerateDiff between the previous and
	// the new file content.
	Diff string `json:"diff"`
	// Additions and Removals are the counts diff.GenerateDiff returns
	// alongside the diff (presumably added/removed lines — confirm in diff).
	Additions int `json:"additions"`
	Removals  int `json:"removals"`
}
44
// WriteToolName is the name the write tool is registered under. The same name
// is reported as ToolName in the tool's permission requests.
const WriteToolName = "write"
46
47func NewWriteTool(
48 lspClients *csync.Map[string, *lsp.Client],
49 permissions permission.Service,
50 files history.Service,
51 filetracker filetracker.Service,
52 workingDir string,
53) fantasy.AgentTool {
54 return fantasy.NewAgentTool(
55 WriteToolName,
56 string(writeDescription),
57 func(ctx context.Context, params WriteParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
58 if params.FilePath == "" {
59 return fantasy.NewTextErrorResponse("file_path is required"), nil
60 }
61
62 if params.Content == "" {
63 return fantasy.NewTextErrorResponse("content is required"), nil
64 }
65
66 sessionID := GetSessionFromContext(ctx)
67 if sessionID == "" {
68 return fantasy.ToolResponse{}, fmt.Errorf("session_id is required")
69 }
70
71 filePath := filepathext.SmartJoin(workingDir, params.FilePath)
72
73 fileInfo, err := os.Stat(filePath)
74 if err == nil {
75 if fileInfo.IsDir() {
76 return fantasy.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
77 }
78
79 modTime := fileInfo.ModTime().Truncate(time.Second)
80 lastRead := filetracker.LastReadTime(ctx, sessionID, filePath)
81 if modTime.After(lastRead) {
82 return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
83 filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
84 }
85
86 oldContent, readErr := os.ReadFile(filePath)
87 if readErr == nil && string(oldContent) == params.Content {
88 return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
89 }
90 } else if !os.IsNotExist(err) {
91 return fantasy.ToolResponse{}, fmt.Errorf("error checking file: %w", err)
92 }
93
94 dir := filepath.Dir(filePath)
95 if err = os.MkdirAll(dir, 0o755); err != nil {
96 return fantasy.ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
97 }
98
99 oldContent := ""
100 if fileInfo != nil && !fileInfo.IsDir() {
101 oldBytes, readErr := os.ReadFile(filePath)
102 if readErr == nil {
103 oldContent = string(oldBytes)
104 }
105 }
106
107 diff, additions, removals := diff.GenerateDiff(
108 oldContent,
109 params.Content,
110 strings.TrimPrefix(filePath, workingDir),
111 )
112
113 p, err := permissions.Request(ctx,
114 permission.CreatePermissionRequest{
115 SessionID: sessionID,
116 Path: fsext.PathOrPrefix(filePath, workingDir),
117 ToolCallID: call.ID,
118 ToolName: WriteToolName,
119 Action: "write",
120 Description: fmt.Sprintf("Create file %s", filePath),
121 Params: WritePermissionsParams{
122 FilePath: filePath,
123 OldContent: oldContent,
124 NewContent: params.Content,
125 },
126 },
127 )
128 if err != nil {
129 return fantasy.ToolResponse{}, err
130 }
131 if !p {
132 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
133 }
134
135 err = os.WriteFile(filePath, []byte(params.Content), 0o644)
136 if err != nil {
137 return fantasy.ToolResponse{}, fmt.Errorf("error writing file: %w", err)
138 }
139
140 // Check if file exists in history
141 file, err := files.GetByPathAndSession(ctx, filePath, sessionID)
142 if err != nil {
143 _, err = files.Create(ctx, sessionID, filePath, oldContent)
144 if err != nil {
145 // Log error but don't fail the operation
146 return fantasy.ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
147 }
148 }
149 if file.Content != oldContent {
150 // User manually changed the content; store an intermediate version
151 _, err = files.CreateVersion(ctx, sessionID, filePath, oldContent)
152 if err != nil {
153 slog.Error("Error creating file history version", "error", err)
154 }
155 }
156 // Store the new version
157 _, err = files.CreateVersion(ctx, sessionID, filePath, params.Content)
158 if err != nil {
159 slog.Error("Error creating file history version", "error", err)
160 }
161
162 filetracker.RecordRead(ctx, sessionID, filePath)
163
164 notifyLSPs(ctx, lspClients, params.FilePath)
165
166 result := fmt.Sprintf("File successfully written: %s", filePath)
167 result = fmt.Sprintf("<result>\n%s\n</result>", result)
168 result += getDiagnostics(filePath, lspClients)
169 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(result),
170 WriteResponseMetadata{
171 Diff: diff,
172 Additions: additions,
173 Removals: removals,
174 },
175 ), nil
176 })
177}