1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/diff"
 16	"github.com/charmbracelet/crush/internal/fsext"
 17	"github.com/charmbracelet/crush/internal/history"
 18
 19	"github.com/charmbracelet/crush/internal/lsp"
 20	"github.com/charmbracelet/crush/internal/permission"
 21)
 22
// writeDescription holds the contents of write.md, embedded at build time; it
// is returned as the tool description by (*writeTool).Info.
//go:embed write.md
var writeDescription []byte
 25
// WriteParams is the JSON input of the write tool. FilePath may be absolute or
// relative to the tool's working directory, and must be non-empty. Content is
// the complete new file content; an empty Content is rejected by Run.
type WriteParams struct {
	FilePath string `json:"file_path"`
	Content  string `json:"content"`
}
 30
// WritePermissionsParams is attached to the permission request raised before a
// write, so the approver can see the resolved target path together with the
// file content before (OldContent) and after (NewContent) the write.
// OldContent is empty when the file does not exist yet; empty content fields
// are omitted from the JSON encoding.
type WritePermissionsParams struct {
	FilePath   string `json:"file_path"`
	OldContent string `json:"old_content,omitempty"`
	NewContent string `json:"new_content,omitempty"`
}
 36
// writeTool implements the "write" tool, which creates or overwrites a file
// with caller-supplied content.
type writeTool struct {
	// lspClients are notified after a write and queried for diagnostics on
	// the written file.
	lspClients  *csync.Map[string, *lsp.Client]
	// permissions is asked to approve every write before it happens.
	permissions permission.Service
	// files records per-session file history (initial content and versions).
	files       history.Service
	// workingDir resolves relative paths and is trimmed from paths in diffs.
	workingDir  string
}
 43
// WriteResponseMetadata is attached to a successful write response. Diff is
// the output of diff.GenerateDiff for the previous and new file content.
// Additions and Removals are the counts returned alongside it (presumably
// line counts — confirm against the diff package).
type WriteResponseMetadata struct {
	Diff      string `json:"diff"`
	Additions int    `json:"additions"`
	Removals  int    `json:"removals"`
}
 49
// WriteToolName is the name the write tool reports from Name and Info and
// uses as the tool name in its permission requests.
const WriteToolName = "write"
 51
 52func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
 53	return &writeTool{
 54		lspClients:  lspClients,
 55		permissions: permissions,
 56		files:       files,
 57		workingDir:  workingDir,
 58	}
 59}
 60
 61func (w *writeTool) Name() string {
 62	return WriteToolName
 63}
 64
 65func (w *writeTool) Info() ToolInfo {
 66	return ToolInfo{
 67		Name:        WriteToolName,
 68		Description: string(writeDescription),
 69		Parameters: map[string]any{
 70			"file_path": map[string]any{
 71				"type":        "string",
 72				"description": "The path to the file to write",
 73			},
 74			"content": map[string]any{
 75				"type":        "string",
 76				"description": "The content to write to the file",
 77			},
 78		},
 79		Required: []string{"file_path", "content"},
 80	}
 81}
 82
 83func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 84	var params WriteParams
 85	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 86		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 87	}
 88
 89	if params.FilePath == "" {
 90		return NewTextErrorResponse("file_path is required"), nil
 91	}
 92
 93	if params.Content == "" {
 94		return NewTextErrorResponse("content is required"), nil
 95	}
 96
 97	filePath := params.FilePath
 98	if !filepath.IsAbs(filePath) {
 99		filePath = filepath.Join(w.workingDir, filePath)
100	}
101
102	fileInfo, err := os.Stat(filePath)
103	if err == nil {
104		if fileInfo.IsDir() {
105			return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
106		}
107
108		modTime := fileInfo.ModTime()
109		lastRead := getLastReadTime(filePath)
110		if modTime.After(lastRead) {
111			return NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
112				filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
113		}
114
115		oldContent, readErr := os.ReadFile(filePath)
116		if readErr == nil && string(oldContent) == params.Content {
117			return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
118		}
119	} else if !os.IsNotExist(err) {
120		return ToolResponse{}, fmt.Errorf("error checking file: %w", err)
121	}
122
123	dir := filepath.Dir(filePath)
124	if err = os.MkdirAll(dir, 0o755); err != nil {
125		return ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
126	}
127
128	oldContent := ""
129	if fileInfo != nil && !fileInfo.IsDir() {
130		oldBytes, readErr := os.ReadFile(filePath)
131		if readErr == nil {
132			oldContent = string(oldBytes)
133		}
134	}
135
136	sessionID, messageID := GetContextValues(ctx)
137	if sessionID == "" || messageID == "" {
138		return ToolResponse{}, fmt.Errorf("session_id and message_id are required")
139	}
140
141	diff, additions, removals := diff.GenerateDiff(
142		oldContent,
143		params.Content,
144		strings.TrimPrefix(filePath, w.workingDir),
145	)
146
147	p := w.permissions.Request(
148		permission.CreatePermissionRequest{
149			SessionID:   sessionID,
150			Path:        fsext.PathOrPrefix(filePath, w.workingDir),
151			ToolCallID:  call.ID,
152			ToolName:    WriteToolName,
153			Action:      "write",
154			Description: fmt.Sprintf("Create file %s", filePath),
155			Params: WritePermissionsParams{
156				FilePath:   filePath,
157				OldContent: oldContent,
158				NewContent: params.Content,
159			},
160		},
161	)
162	if !p {
163		return ToolResponse{}, permission.ErrorPermissionDenied
164	}
165
166	err = os.WriteFile(filePath, []byte(params.Content), 0o644)
167	if err != nil {
168		return ToolResponse{}, fmt.Errorf("error writing file: %w", err)
169	}
170
171	// Check if file exists in history
172	file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID)
173	if err != nil {
174		_, err = w.files.Create(ctx, sessionID, filePath, oldContent)
175		if err != nil {
176			// Log error but don't fail the operation
177			return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
178		}
179	}
180	if file.Content != oldContent {
181		// User Manually changed the content store an intermediate version
182		_, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent)
183		if err != nil {
184			slog.Debug("Error creating file history version", "error", err)
185		}
186	}
187	// Store the new version
188	_, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content)
189	if err != nil {
190		slog.Debug("Error creating file history version", "error", err)
191	}
192
193	recordFileWrite(filePath)
194	recordFileRead(filePath)
195
196	notifyLSPs(ctx, w.lspClients, params.FilePath)
197
198	result := fmt.Sprintf("File successfully written: %s", filePath)
199	result = fmt.Sprintf("<result>\n%s\n</result>", result)
200	result += getDiagnostics(filePath, w.lspClients)
201	return WithResponseMetadata(NewTextResponse(result),
202		WriteResponseMetadata{
203			Diff:      diff,
204			Additions: additions,
205			Removals:  removals,
206		},
207	), nil
208}