write.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"strings"
 11	"time"
 12
 13	"charm.land/fantasy"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/diff"
 16	"github.com/charmbracelet/crush/internal/filepathext"
 17	"github.com/charmbracelet/crush/internal/filetracker"
 18	"github.com/charmbracelet/crush/internal/fsext"
 19	"github.com/charmbracelet/crush/internal/history"
 20
 21	"github.com/charmbracelet/crush/internal/lsp"
 22	"github.com/charmbracelet/crush/internal/permission"
 23)
 24
 25//go:embed write.md
 26var writeDescription []byte
 27
 28type WriteParams struct {
 29	FilePath string `json:"file_path" description:"The path to the file to write"`
 30	Content  string `json:"content" description:"The content to write to the file"`
 31}
 32
 33type WritePermissionsParams struct {
 34	FilePath   string `json:"file_path"`
 35	OldContent string `json:"old_content,omitempty"`
 36	NewContent string `json:"new_content,omitempty"`
 37}
 38
 39type writeTool struct {
 40	lspClients  *csync.Map[string, *lsp.Client]
 41	permissions permission.Service
 42	files       history.Service
 43	workingDir  string
 44}
 45
 46type WriteResponseMetadata struct {
 47	Diff      string `json:"diff"`
 48	Additions int    `json:"additions"`
 49	Removals  int    `json:"removals"`
 50}
 51
 52const WriteToolName = "write"
 53
 54func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) fantasy.AgentTool {
 55	return fantasy.NewAgentTool(
 56		WriteToolName,
 57		string(writeDescription),
 58		func(ctx context.Context, params WriteParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 59			if params.FilePath == "" {
 60				return fantasy.NewTextErrorResponse("file_path is required"), nil
 61			}
 62
 63			if params.Content == "" {
 64				return fantasy.NewTextErrorResponse("content is required"), nil
 65			}
 66
 67			filePath := filepathext.SmartJoin(workingDir, params.FilePath)
 68
 69			fileInfo, err := os.Stat(filePath)
 70			if err == nil {
 71				if fileInfo.IsDir() {
 72					return fantasy.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
 73				}
 74
 75				modTime := fileInfo.ModTime()
 76				lastRead := filetracker.LastReadTime(filePath)
 77				if modTime.After(lastRead) {
 78					return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
 79						filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
 80				}
 81
 82				oldContent, readErr := os.ReadFile(filePath)
 83				if readErr == nil && string(oldContent) == params.Content {
 84					return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
 85				}
 86			} else if !os.IsNotExist(err) {
 87				return fantasy.ToolResponse{}, fmt.Errorf("error checking file: %w", err)
 88			}
 89
 90			dir := filepath.Dir(filePath)
 91			if err = os.MkdirAll(dir, 0o755); err != nil {
 92				return fantasy.ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
 93			}
 94
 95			oldContent := ""
 96			if fileInfo != nil && !fileInfo.IsDir() {
 97				oldBytes, readErr := os.ReadFile(filePath)
 98				if readErr == nil {
 99					oldContent = string(oldBytes)
100				}
101			}
102
103			sessionID := GetSessionFromContext(ctx)
104			if sessionID == "" {
105				return fantasy.ToolResponse{}, fmt.Errorf("session_id is required")
106			}
107
108			diff, additions, removals := diff.GenerateDiff(
109				oldContent,
110				params.Content,
111				strings.TrimPrefix(filePath, workingDir),
112			)
113
114			p := permissions.Request(
115				permission.CreatePermissionRequest{
116					SessionID:   sessionID,
117					Path:        fsext.PathOrPrefix(filePath, workingDir),
118					ToolCallID:  call.ID,
119					ToolName:    WriteToolName,
120					Action:      "write",
121					Description: fmt.Sprintf("Create file %s", filePath),
122					Params: WritePermissionsParams{
123						FilePath:   filePath,
124						OldContent: oldContent,
125						NewContent: params.Content,
126					},
127				},
128			)
129			if !p {
130				return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
131			}
132
133			err = os.WriteFile(filePath, []byte(params.Content), 0o644)
134			if err != nil {
135				return fantasy.ToolResponse{}, fmt.Errorf("error writing file: %w", err)
136			}
137
138			// Check if file exists in history
139			file, err := files.GetByPathAndSession(ctx, filePath, sessionID)
140			if err != nil {
141				_, err = files.Create(ctx, sessionID, filePath, oldContent)
142				if err != nil {
143					// Log error but don't fail the operation
144					return fantasy.ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
145				}
146			}
147			if file.Content != oldContent {
148				// User Manually changed the content store an intermediate version
149				_, err = files.CreateVersion(ctx, sessionID, filePath, oldContent)
150				if err != nil {
151					slog.Error("Error creating file history version", "error", err)
152				}
153			}
154			// Store the new version
155			_, err = files.CreateVersion(ctx, sessionID, filePath, params.Content)
156			if err != nil {
157				slog.Error("Error creating file history version", "error", err)
158			}
159
160			filetracker.RecordWrite(filePath)
161			filetracker.RecordRead(filePath)
162
163			notifyLSPs(ctx, lspClients, params.FilePath)
164
165			result := fmt.Sprintf("File successfully written: %s", filePath)
166			result = fmt.Sprintf("<result>\n%s\n</result>", result)
167			result += getDiagnostics(filePath, lspClients)
168			return fantasy.WithResponseMetadata(fantasy.NewTextResponse(result),
169				WriteResponseMetadata{
170					Diff:      diff,
171					Additions: additions,
172					Removals:  removals,
173				},
174			), nil
175		})
176}