write.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"strings"
 11	"time"
 12
 13	"charm.land/fantasy"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/diff"
 16	"github.com/charmbracelet/crush/internal/filepathext"
 17	"github.com/charmbracelet/crush/internal/filetracker"
 18	"github.com/charmbracelet/crush/internal/fsext"
 19	"github.com/charmbracelet/crush/internal/history"
 20
 21	"github.com/charmbracelet/crush/internal/lsp"
 22	"github.com/charmbracelet/crush/internal/permission"
 23)
 24
 25//go:embed write.md
 26var writeDescription []byte
 27
 28type WriteParams struct {
 29	FilePath string `json:"file_path" description:"The path to the file to write"`
 30	Content  string `json:"content" description:"The content to write to the file"`
 31}
 32
 33type WritePermissionsParams struct {
 34	FilePath   string `json:"file_path"`
 35	OldContent string `json:"old_content,omitempty"`
 36	NewContent string `json:"new_content,omitempty"`
 37}
 38
 39type WriteResponseMetadata struct {
 40	Diff      string `json:"diff"`
 41	Additions int    `json:"additions"`
 42	Removals  int    `json:"removals"`
 43}
 44
 45const WriteToolName = "write"
 46
 47func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) fantasy.AgentTool {
 48	return fantasy.NewAgentTool(
 49		WriteToolName,
 50		string(writeDescription),
 51		func(ctx context.Context, params WriteParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 52			if params.FilePath == "" {
 53				return fantasy.NewTextErrorResponse("file_path is required"), nil
 54			}
 55
 56			if params.Content == "" {
 57				return fantasy.NewTextErrorResponse("content is required"), nil
 58			}
 59
 60			filePath := filepathext.SmartJoin(workingDir, params.FilePath)
 61
 62			fileInfo, err := os.Stat(filePath)
 63			if err == nil {
 64				if fileInfo.IsDir() {
 65					return fantasy.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
 66				}
 67
 68				modTime := fileInfo.ModTime()
 69				lastRead := filetracker.LastReadTime(filePath)
 70				if modTime.After(lastRead) {
 71					return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
 72						filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
 73				}
 74
 75				oldContent, readErr := os.ReadFile(filePath)
 76				if readErr == nil && string(oldContent) == params.Content {
 77					return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
 78				}
 79			} else if !os.IsNotExist(err) {
 80				return fantasy.ToolResponse{}, fmt.Errorf("error checking file: %w", err)
 81			}
 82
 83			dir := filepath.Dir(filePath)
 84			if err = os.MkdirAll(dir, 0o755); err != nil {
 85				return fantasy.ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
 86			}
 87
 88			oldContent := ""
 89			if fileInfo != nil && !fileInfo.IsDir() {
 90				oldBytes, readErr := os.ReadFile(filePath)
 91				if readErr == nil {
 92					oldContent = string(oldBytes)
 93				}
 94			}
 95
 96			sessionID := GetSessionFromContext(ctx)
 97			if sessionID == "" {
 98				return fantasy.ToolResponse{}, fmt.Errorf("session_id is required")
 99			}
100
101			diff, additions, removals := diff.GenerateDiff(
102				oldContent,
103				params.Content,
104				strings.TrimPrefix(filePath, workingDir),
105			)
106
107			p, err := permissions.Request(ctx,
108				permission.CreatePermissionRequest{
109					SessionID:   sessionID,
110					Path:        fsext.PathOrPrefix(filePath, workingDir),
111					ToolCallID:  call.ID,
112					ToolName:    WriteToolName,
113					Action:      "write",
114					Description: fmt.Sprintf("Create file %s", filePath),
115					Params: WritePermissionsParams{
116						FilePath:   filePath,
117						OldContent: oldContent,
118						NewContent: params.Content,
119					},
120				},
121			)
122			if err != nil {
123				return fantasy.ToolResponse{}, err
124			}
125			if !p {
126				return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
127			}
128
129			err = os.WriteFile(filePath, []byte(params.Content), 0o644)
130			if err != nil {
131				return fantasy.ToolResponse{}, fmt.Errorf("error writing file: %w", err)
132			}
133
134			// Check if file exists in history
135			file, err := files.GetByPathAndSession(ctx, filePath, sessionID)
136			if err != nil {
137				_, err = files.Create(ctx, sessionID, filePath, oldContent)
138				if err != nil {
139					// Log error but don't fail the operation
140					return fantasy.ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
141				}
142			}
143			if file.Content != oldContent {
144				// User manually changed the content; store an intermediate version
145				_, err = files.CreateVersion(ctx, sessionID, filePath, oldContent)
146				if err != nil {
147					slog.Error("Error creating file history version", "error", err)
148				}
149			}
150			// Store the new version
151			_, err = files.CreateVersion(ctx, sessionID, filePath, params.Content)
152			if err != nil {
153				slog.Error("Error creating file history version", "error", err)
154			}
155
156			filetracker.RecordWrite(filePath)
157			filetracker.RecordRead(filePath)
158
159			notifyLSPs(ctx, lspClients, params.FilePath)
160
161			result := fmt.Sprintf("File successfully written: %s", filePath)
162			result = fmt.Sprintf("<result>\n%s\n</result>", result)
163			result += getDiagnostics(filePath, lspClients)
164			return fantasy.WithResponseMetadata(fantasy.NewTextResponse(result),
165				WriteResponseMetadata{
166					Diff:      diff,
167					Additions: additions,
168					Removals:  removals,
169				},
170			), nil
171		})
172}