write.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"strings"
 11	"time"
 12
 13	"charm.land/fantasy"
 14	"github.com/charmbracelet/crush/internal/diff"
 15	"github.com/charmbracelet/crush/internal/filepathext"
 16	"github.com/charmbracelet/crush/internal/filetracker"
 17	"github.com/charmbracelet/crush/internal/fsext"
 18	"github.com/charmbracelet/crush/internal/history"
 19
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/permission"
 22)
 23
// writeDescription is the write tool's description text. It is embedded at
// build time from write.md and handed to fantasy.NewAgentTool by
// NewWriteTool.
//
//go:embed write.md
var writeDescription string
 26
 27type WriteParams struct {
 28	FilePath string `json:"file_path" description:"The path to the file to write"`
 29	Content  string `json:"content" description:"The content to write to the file"`
 30}
 31
 32type WritePermissionsParams struct {
 33	FilePath   string `json:"file_path"`
 34	OldContent string `json:"old_content,omitempty"`
 35	NewContent string `json:"new_content,omitempty"`
 36}
 37
 38type WriteResponseMetadata struct {
 39	Diff      string `json:"diff"`
 40	Additions int    `json:"additions"`
 41	Removals  int    `json:"removals"`
 42}
 43
 44const WriteToolName = "write"
 45
 46func NewWriteTool(
 47	lspManager *lsp.Manager,
 48	permissions permission.Service,
 49	files history.Service,
 50	filetracker filetracker.Service,
 51	workingDir string,
 52) fantasy.AgentTool {
 53	return fantasy.NewAgentTool(
 54		WriteToolName,
 55		writeDescription,
 56		func(ctx context.Context, params WriteParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 57			if params.FilePath == "" {
 58				return fantasy.NewTextErrorResponse("file_path is required"), nil
 59			}
 60
 61			sessionID := GetSessionFromContext(ctx)
 62			if sessionID == "" {
 63				return fantasy.ToolResponse{}, fmt.Errorf("session_id is required")
 64			}
 65
 66			filePath := filepathext.SmartJoin(workingDir, params.FilePath)
 67
 68			fileInfo, err := os.Stat(filePath)
 69			if err == nil {
 70				if fileInfo.IsDir() {
 71					return fantasy.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
 72				}
 73
 74				modTime := fileInfo.ModTime().Truncate(time.Second)
 75				lastRead := filetracker.LastReadTime(ctx, sessionID, filePath)
 76				if modTime.After(lastRead) {
 77					return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
 78						filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
 79				}
 80
 81				oldContent, readErr := os.ReadFile(filePath)
 82				if readErr == nil && string(oldContent) == params.Content {
 83					return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
 84				}
 85			} else if !os.IsNotExist(err) {
 86				return fantasy.ToolResponse{}, fmt.Errorf("error checking file: %w", err)
 87			}
 88
 89			dir := filepath.Dir(filePath)
 90			if err = os.MkdirAll(dir, 0o755); err != nil {
 91				return fantasy.ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
 92			}
 93
 94			oldContent := ""
 95			if fileInfo != nil && !fileInfo.IsDir() {
 96				oldBytes, readErr := os.ReadFile(filePath)
 97				if readErr == nil {
 98					oldContent = string(oldBytes)
 99				}
100			}
101
102			diff, additions, removals := diff.GenerateDiff(
103				oldContent,
104				params.Content,
105				strings.TrimPrefix(filePath, workingDir),
106			)
107
108			p, err := permissions.Request(
109				ctx,
110				permission.CreatePermissionRequest{
111					SessionID:   sessionID,
112					Path:        fsext.PathOrPrefix(filePath, workingDir),
113					ToolCallID:  call.ID,
114					ToolName:    WriteToolName,
115					Action:      "write",
116					Description: fmt.Sprintf("Create file %s", filePath),
117					Params: WritePermissionsParams{
118						FilePath:   filePath,
119						OldContent: oldContent,
120						NewContent: params.Content,
121					},
122				},
123			)
124			if err != nil {
125				return fantasy.ToolResponse{}, err
126			}
127			if !p {
128				resp := NewPermissionDeniedResponse()
129				resp = fantasy.WithResponseMetadata(resp, WriteResponseMetadata{
130					Diff:      diff,
131					Additions: additions,
132					Removals:  removals,
133				})
134				return resp, nil
135			}
136
137			err = os.WriteFile(filePath, []byte(params.Content), 0o644)
138			if err != nil {
139				return fantasy.ToolResponse{}, fmt.Errorf("error writing file: %w", err)
140			}
141
142			// Check if file exists in history
143			file, err := files.GetByPathAndSession(ctx, filePath, sessionID)
144			if err != nil {
145				_, err = files.Create(ctx, sessionID, filePath, oldContent)
146				if err != nil {
147					// Log error but don't fail the operation
148					return fantasy.ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
149				}
150			}
151			if file.Content != oldContent {
152				// User manually changed the content; store an intermediate version
153				_, err = files.CreateVersion(ctx, sessionID, filePath, oldContent)
154				if err != nil {
155					slog.Error("Error creating file history version", "error", err)
156				}
157			}
158			// Store the new version
159			_, err = files.CreateVersion(ctx, sessionID, filePath, params.Content)
160			if err != nil {
161				slog.Error("Error creating file history version", "error", err)
162			}
163
164			filetracker.RecordRead(ctx, sessionID, filePath)
165
166			notifyLSPs(ctx, lspManager, params.FilePath)
167
168			result := fmt.Sprintf("File successfully written: %s", filePath)
169			result = fmt.Sprintf("<result>\n%s\n</result>", result)
170			result += getDiagnostics(filePath, lspManager)
171			return fantasy.WithResponseMetadata(
172				fantasy.NewTextResponse(result),
173				WriteResponseMetadata{
174					Diff:      diff,
175					Additions: additions,
176					Removals:  removals,
177				},
178			), nil
179		},
180	)
181}