write.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"strings"
 11	"time"
 12
 13	"charm.land/fantasy"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/diff"
 16	"github.com/charmbracelet/crush/internal/filepathext"
 17	"github.com/charmbracelet/crush/internal/filetracker"
 18	"github.com/charmbracelet/crush/internal/fsext"
 19	"github.com/charmbracelet/crush/internal/history"
 20
 21	"github.com/charmbracelet/crush/internal/lsp"
 22	"github.com/charmbracelet/crush/internal/permission"
 23)
 24
 // writeDescription holds the contents of write.md, embedded at build time.
// NewWriteTool passes it (as a string) as the write tool's description.
//go:embed write.md
var writeDescription []byte
 27
 // WriteParams holds the arguments of a write tool call.
//
// FilePath may be relative; NewWriteTool resolves it against the tool's
// working directory. Content replaces the file's entire content and must be
// non-empty.
type WriteParams struct {
	FilePath string `json:"file_path" description:"The path to the file to write"`
	Content  string `json:"content" description:"The content to write to the file"`
}
 32
 // WritePermissionsParams is attached to the permission request raised before
// a write. It carries the resolved file path plus the file's content before
// (empty for a new file) and after the proposed write.
type WritePermissionsParams struct {
	FilePath   string `json:"file_path"`
	OldContent string `json:"old_content,omitempty"`
	NewContent string `json:"new_content,omitempty"`
}
 38
 // WriteResponseMetadata is attached to a successful write tool response. It
// carries the diff between the old and new file content and the addition and
// removal counts reported by diff.GenerateDiff.
type WriteResponseMetadata struct {
	Diff      string `json:"diff"`
	Additions int    `json:"additions"`
	Removals  int    `json:"removals"`
}
 44
 // WriteToolName is the name the write tool is registered under. It is also
// used as the ToolName of the tool's permission requests.
const WriteToolName = "write"
 46
 47func NewWriteTool(
 48	lspClients *csync.Map[string, *lsp.Client],
 49	permissions permission.Service,
 50	files history.Service,
 51	filetracker filetracker.Service,
 52	workingDir string,
 53) fantasy.AgentTool {
 54	return fantasy.NewAgentTool(
 55		WriteToolName,
 56		string(writeDescription),
 57		func(ctx context.Context, params WriteParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 58			if params.FilePath == "" {
 59				return fantasy.NewTextErrorResponse("file_path is required"), nil
 60			}
 61
 62			if params.Content == "" {
 63				return fantasy.NewTextErrorResponse("content is required"), nil
 64			}
 65
 66			sessionID := GetSessionFromContext(ctx)
 67			if sessionID == "" {
 68				return fantasy.ToolResponse{}, fmt.Errorf("session_id is required")
 69			}
 70
 71			filePath := filepathext.SmartJoin(workingDir, params.FilePath)
 72
 73			fileInfo, err := os.Stat(filePath)
 74			if err == nil {
 75				if fileInfo.IsDir() {
 76					return fantasy.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
 77				}
 78
 79				modTime := fileInfo.ModTime().Truncate(time.Second)
 80				lastRead := filetracker.LastReadTime(ctx, sessionID, filePath)
 81				if modTime.After(lastRead) {
 82					return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
 83						filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
 84				}
 85
 86				oldContent, readErr := os.ReadFile(filePath)
 87				if readErr == nil && string(oldContent) == params.Content {
 88					return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
 89				}
 90			} else if !os.IsNotExist(err) {
 91				return fantasy.ToolResponse{}, fmt.Errorf("error checking file: %w", err)
 92			}
 93
 94			dir := filepath.Dir(filePath)
 95			if err = os.MkdirAll(dir, 0o755); err != nil {
 96				return fantasy.ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
 97			}
 98
 99			oldContent := ""
100			if fileInfo != nil && !fileInfo.IsDir() {
101				oldBytes, readErr := os.ReadFile(filePath)
102				if readErr == nil {
103					oldContent = string(oldBytes)
104				}
105			}
106
107			diff, additions, removals := diff.GenerateDiff(
108				oldContent,
109				params.Content,
110				strings.TrimPrefix(filePath, workingDir),
111			)
112
113			p, err := permissions.Request(ctx,
114				permission.CreatePermissionRequest{
115					SessionID:   sessionID,
116					Path:        fsext.PathOrPrefix(filePath, workingDir),
117					ToolCallID:  call.ID,
118					ToolName:    WriteToolName,
119					Action:      "write",
120					Description: fmt.Sprintf("Create file %s", filePath),
121					Params: WritePermissionsParams{
122						FilePath:   filePath,
123						OldContent: oldContent,
124						NewContent: params.Content,
125					},
126				},
127			)
128			if err != nil {
129				return fantasy.ToolResponse{}, err
130			}
131			if !p {
132				return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
133			}
134
135			err = os.WriteFile(filePath, []byte(params.Content), 0o644)
136			if err != nil {
137				return fantasy.ToolResponse{}, fmt.Errorf("error writing file: %w", err)
138			}
139
140			// Check if file exists in history
141			file, err := files.GetByPathAndSession(ctx, filePath, sessionID)
142			if err != nil {
143				_, err = files.Create(ctx, sessionID, filePath, oldContent)
144				if err != nil {
145					// Log error but don't fail the operation
146					return fantasy.ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
147				}
148			}
149			if file.Content != oldContent {
150				// User manually changed the content; store an intermediate version
151				_, err = files.CreateVersion(ctx, sessionID, filePath, oldContent)
152				if err != nil {
153					slog.Error("Error creating file history version", "error", err)
154				}
155			}
156			// Store the new version
157			_, err = files.CreateVersion(ctx, sessionID, filePath, params.Content)
158			if err != nil {
159				slog.Error("Error creating file history version", "error", err)
160			}
161
162			filetracker.RecordRead(ctx, sessionID, filePath)
163
164			notifyLSPs(ctx, lspClients, params.FilePath)
165
166			result := fmt.Sprintf("File successfully written: %s", filePath)
167			result = fmt.Sprintf("<result>\n%s\n</result>", result)
168			result += getDiagnostics(filePath, lspClients)
169			return fantasy.WithResponseMetadata(fantasy.NewTextResponse(result),
170				WriteResponseMetadata{
171					Diff:      diff,
172					Additions: additions,
173					Removals:  removals,
174				},
175			), nil
176		})
177}