1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/diff"
 16	"github.com/charmbracelet/crush/internal/fsext"
 17	"github.com/charmbracelet/crush/internal/history"
 18	"github.com/charmbracelet/crush/internal/proto"
 19
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/permission"
 22)
 23
// writeDescription holds the contents of write.md, embedded into the binary at
// build time. Info returns it (as a string) as the tool's description.
//
//go:embed write.md
var writeDescription []byte
 26
 27type (
 28	WriteParams            = proto.WriteParams
 29	WritePermissionsParams = proto.WritePermissionsParams
 30	WriteResponseMetadata  = proto.WriteResponseMetadata
 31)
 32
// writeTool implements the tool named WriteToolName, which writes
// model-supplied content to a file after obtaining permission.
//
// Fields:
//   - lspClients: LSP clients notified after a write and consulted for
//     diagnostics on the written file.
//   - permissions: service asked to approve each write before it happens.
//   - files: per-session file history, updated with old and new contents.
//   - workingDir: base for resolving relative paths; also trimmed from the
//     path shown in generated diffs.
type writeTool struct {
	lspClients  *csync.Map[string, *lsp.Client]
	permissions permission.Service
	files       history.Service
	workingDir  string
}
 39
// WriteToolName is the identifier reported by the write tool's Name and Info
// methods and used as the ToolName of its permission requests. It mirrors
// proto.WriteToolName.
const WriteToolName = proto.WriteToolName
 41
 42func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
 43	return &writeTool{
 44		lspClients:  lspClients,
 45		permissions: permissions,
 46		files:       files,
 47		workingDir:  workingDir,
 48	}
 49}
 50
 51func (w *writeTool) Name() string {
 52	return WriteToolName
 53}
 54
 55func (w *writeTool) Info() ToolInfo {
 56	return ToolInfo{
 57		Name:        WriteToolName,
 58		Description: string(writeDescription),
 59		Parameters: map[string]any{
 60			"file_path": map[string]any{
 61				"type":        "string",
 62				"description": "The path to the file to write",
 63			},
 64			"content": map[string]any{
 65				"type":        "string",
 66				"description": "The content to write to the file",
 67			},
 68		},
 69		Required: []string{"file_path", "content"},
 70	}
 71}
 72
 73func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 74	var params WriteParams
 75	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 76		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 77	}
 78
 79	if params.FilePath == "" {
 80		return NewTextErrorResponse("file_path is required"), nil
 81	}
 82
 83	if params.Content == "" {
 84		return NewTextErrorResponse("content is required"), nil
 85	}
 86
 87	filePath := params.FilePath
 88	if !filepath.IsAbs(filePath) {
 89		filePath = filepath.Join(w.workingDir, filePath)
 90	}
 91
 92	fileInfo, err := os.Stat(filePath)
 93	if err == nil {
 94		if fileInfo.IsDir() {
 95			return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
 96		}
 97
 98		modTime := fileInfo.ModTime()
 99		lastRead := getLastReadTime(filePath)
100		if modTime.After(lastRead) {
101			return NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
102				filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
103		}
104
105		oldContent, readErr := os.ReadFile(filePath)
106		if readErr == nil && string(oldContent) == params.Content {
107			return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
108		}
109	} else if !os.IsNotExist(err) {
110		return ToolResponse{}, fmt.Errorf("error checking file: %w", err)
111	}
112
113	dir := filepath.Dir(filePath)
114	if err = os.MkdirAll(dir, 0o755); err != nil {
115		return ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
116	}
117
118	oldContent := ""
119	if fileInfo != nil && !fileInfo.IsDir() {
120		oldBytes, readErr := os.ReadFile(filePath)
121		if readErr == nil {
122			oldContent = string(oldBytes)
123		}
124	}
125
126	sessionID, messageID := GetContextValues(ctx)
127	if sessionID == "" || messageID == "" {
128		return ToolResponse{}, fmt.Errorf("session_id and message_id are required")
129	}
130
131	diff, additions, removals := diff.GenerateDiff(
132		oldContent,
133		params.Content,
134		strings.TrimPrefix(filePath, w.workingDir),
135	)
136
137	p := w.permissions.Request(
138		permission.CreatePermissionRequest{
139			SessionID:   sessionID,
140			Path:        fsext.PathOrPrefix(filePath, w.workingDir),
141			ToolCallID:  call.ID,
142			ToolName:    WriteToolName,
143			Action:      "write",
144			Description: fmt.Sprintf("Create file %s", filePath),
145			Params: WritePermissionsParams{
146				FilePath:   filePath,
147				OldContent: oldContent,
148				NewContent: params.Content,
149			},
150		},
151	)
152	if !p {
153		return ToolResponse{}, permission.ErrorPermissionDenied
154	}
155
156	err = os.WriteFile(filePath, []byte(params.Content), 0o644)
157	if err != nil {
158		return ToolResponse{}, fmt.Errorf("error writing file: %w", err)
159	}
160
161	// Check if file exists in history
162	file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID)
163	if err != nil {
164		_, err = w.files.Create(ctx, sessionID, filePath, oldContent)
165		if err != nil {
166			// Log error but don't fail the operation
167			return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
168		}
169	}
170	if file.Content != oldContent {
171		// User Manually changed the content store an intermediate version
172		_, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent)
173		if err != nil {
174			slog.Debug("Error creating file history version", "error", err)
175		}
176	}
177	// Store the new version
178	_, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content)
179	if err != nil {
180		slog.Debug("Error creating file history version", "error", err)
181	}
182
183	recordFileWrite(filePath)
184	recordFileRead(filePath)
185
186	notifyLSPs(ctx, w.lspClients, params.FilePath)
187
188	result := fmt.Sprintf("File successfully written: %s", filePath)
189	result = fmt.Sprintf("<result>\n%s\n</result>", result)
190	result += getDiagnostics(filePath, w.lspClients)
191	return WithResponseMetadata(NewTextResponse(result),
192		WriteResponseMetadata{
193			Diff:      diff,
194			Additions: additions,
195			Removals:  removals,
196		},
197	), nil
198}