1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/csync"
14 "github.com/charmbracelet/crush/internal/diff"
15 "github.com/charmbracelet/crush/internal/fsext"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/fantasy/ai"
18
19 "github.com/charmbracelet/crush/internal/lsp"
20 "github.com/charmbracelet/crush/internal/permission"
21)
22
// writeDescription holds the contents of write.md, embedded at build time.
// NewWriteTool passes it (as a string) as the description argument to
// ai.NewAgentTool.
//
//go:embed write.md
var writeDescription []byte
25
// WriteParams are the arguments accepted by the write tool.
//
// FilePath may be absolute or relative; NewWriteTool resolves relative paths
// against its workingDir. Content replaces the whole file and must be
// non-empty (NewWriteTool rejects an empty value).
// NOTE(review): the `description` tags are presumably consumed by
// ai.NewAgentTool to build the tool's input schema — confirm.
type WriteParams struct {
	FilePath string `json:"file_path" description:"The path to the file to write"`
	Content  string `json:"content" description:"The content to write to the file"`
}
30
// WritePermissionsParams is the Params payload of the permission request the
// write tool raises before writing.
//
// FilePath is the resolved (absolute) target path. OldContent is the file's
// current content, or empty when the file does not exist yet. NewContent is
// the content that will be written.
type WritePermissionsParams struct {
	FilePath   string `json:"file_path"`
	OldContent string `json:"old_content,omitempty"`
	NewContent string `json:"new_content,omitempty"`
}
36
// writeTool groups the dependencies of the write tool.
//
// NOTE(review): NewWriteTool in this file captures these same values in a
// closure and never constructs a writeTool; confirm whether this type is still
// referenced elsewhere in the package before relying on (or removing) it.
type writeTool struct {
	// lspClients are the LSP clients to notify after a write
	// (presumably keyed by server name — verify).
	lspClients *csync.Map[string, *lsp.Client]
	// permissions is asked to approve each write.
	permissions permission.Service
	// files records file content versions per session.
	files history.Service
	// workingDir is the base directory for relative paths.
	workingDir string
}
43
// WriteResponseMetadata is attached to a successful write tool response.
//
// Diff, Additions and Removals are the three values returned by
// diff.GenerateDiff for the previous content versus the newly written content.
type WriteResponseMetadata struct {
	Diff      string `json:"diff"`
	Additions int    `json:"additions"`
	Removals  int    `json:"removals"`
}
49
// WriteToolName is the name the write tool is registered under with
// ai.NewAgentTool; it is also the ToolName in the tool's permission requests.
const WriteToolName = "write"
51
52func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) ai.AgentTool {
53 return ai.NewAgentTool(
54 WriteToolName,
55 string(writeDescription),
56 func(ctx context.Context, params WriteParams, call ai.ToolCall) (ai.ToolResponse, error) {
57 if params.FilePath == "" {
58 return ai.NewTextErrorResponse("file_path is required"), nil
59 }
60
61 if params.Content == "" {
62 return ai.NewTextErrorResponse("content is required"), nil
63 }
64
65 filePath := params.FilePath
66 if !filepath.IsAbs(filePath) {
67 filePath = filepath.Join(workingDir, filePath)
68 }
69
70 fileInfo, err := os.Stat(filePath)
71 if err == nil {
72 if fileInfo.IsDir() {
73 return ai.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
74 }
75
76 modTime := fileInfo.ModTime()
77 lastRead := getLastReadTime(filePath)
78 if modTime.After(lastRead) {
79 return ai.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
80 filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
81 }
82
83 oldContent, readErr := os.ReadFile(filePath)
84 if readErr == nil && string(oldContent) == params.Content {
85 return ai.NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
86 }
87 } else if !os.IsNotExist(err) {
88 return ai.ToolResponse{}, fmt.Errorf("error checking file: %w", err)
89 }
90
91 dir := filepath.Dir(filePath)
92 if err = os.MkdirAll(dir, 0o755); err != nil {
93 return ai.ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
94 }
95
96 oldContent := ""
97 if fileInfo != nil && !fileInfo.IsDir() {
98 oldBytes, readErr := os.ReadFile(filePath)
99 if readErr == nil {
100 oldContent = string(oldBytes)
101 }
102 }
103
104 sessionID := GetSessionFromContext(ctx)
105 if sessionID == "" {
106 return ai.ToolResponse{}, fmt.Errorf("session_id is required")
107 }
108
109 diff, additions, removals := diff.GenerateDiff(
110 oldContent,
111 params.Content,
112 strings.TrimPrefix(filePath, workingDir),
113 )
114
115 p := permissions.Request(
116 permission.CreatePermissionRequest{
117 SessionID: sessionID,
118 Path: fsext.PathOrPrefix(filePath, workingDir),
119 ToolCallID: call.ID,
120 ToolName: WriteToolName,
121 Action: "write",
122 Description: fmt.Sprintf("Create file %s", filePath),
123 Params: WritePermissionsParams{
124 FilePath: filePath,
125 OldContent: oldContent,
126 NewContent: params.Content,
127 },
128 },
129 )
130 if !p {
131 return ai.ToolResponse{}, permission.ErrorPermissionDenied
132 }
133
134 err = os.WriteFile(filePath, []byte(params.Content), 0o644)
135 if err != nil {
136 return ai.ToolResponse{}, fmt.Errorf("error writing file: %w", err)
137 }
138
139 // Check if file exists in history
140 file, err := files.GetByPathAndSession(ctx, filePath, sessionID)
141 if err != nil {
142 _, err = files.Create(ctx, sessionID, filePath, oldContent)
143 if err != nil {
144 // Log error but don't fail the operation
145 return ai.ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
146 }
147 }
148 if file.Content != oldContent {
149 // User Manually changed the content store an intermediate version
150 _, err = files.CreateVersion(ctx, sessionID, filePath, oldContent)
151 if err != nil {
152 slog.Debug("Error creating file history version", "error", err)
153 }
154 }
155 // Store the new version
156 _, err = files.CreateVersion(ctx, sessionID, filePath, params.Content)
157 if err != nil {
158 slog.Debug("Error creating file history version", "error", err)
159 }
160
161 recordFileWrite(filePath)
162 recordFileRead(filePath)
163
164 notifyLSPs(ctx, lspClients, params.FilePath)
165
166 result := fmt.Sprintf("File successfully written: %s", filePath)
167 result = fmt.Sprintf("<result>\n%s\n</result>", result)
168 result += getDiagnostics(filePath, lspClients)
169 return ai.WithResponseMetadata(ai.NewTextResponse(result),
170 WriteResponseMetadata{
171 Diff: diff,
172 Additions: additions,
173 Removals: removals,
174 },
175 ), nil
176 })
177}