write.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/csync"
 14	"github.com/charmbracelet/crush/internal/diff"
 15	"github.com/charmbracelet/crush/internal/fsext"
 16	"github.com/charmbracelet/crush/internal/history"
 17	"github.com/charmbracelet/fantasy/ai"
 18
 19	"github.com/charmbracelet/crush/internal/lsp"
 20	"github.com/charmbracelet/crush/internal/permission"
 21)
 22
 23//go:embed write.md
 24var writeDescription []byte
 25
 26type WriteParams struct {
 27	FilePath string `json:"file_path" description:"The path to the file to write"`
 28	Content  string `json:"content" description:"The content to write to the file"`
 29}
 30
 31type WritePermissionsParams struct {
 32	FilePath   string `json:"file_path"`
 33	OldContent string `json:"old_content,omitempty"`
 34	NewContent string `json:"new_content,omitempty"`
 35}
 36
 37type writeTool struct {
 38	lspClients  *csync.Map[string, *lsp.Client]
 39	permissions permission.Service
 40	files       history.Service
 41	workingDir  string
 42}
 43
 44type WriteResponseMetadata struct {
 45	Diff      string `json:"diff"`
 46	Additions int    `json:"additions"`
 47	Removals  int    `json:"removals"`
 48}
 49
 50const WriteToolName = "write"
 51
 52func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) ai.AgentTool {
 53	return ai.NewAgentTool(
 54		WriteToolName,
 55		string(writeDescription),
 56		func(ctx context.Context, params WriteParams, call ai.ToolCall) (ai.ToolResponse, error) {
 57			if params.FilePath == "" {
 58				return ai.NewTextErrorResponse("file_path is required"), nil
 59			}
 60
 61			if params.Content == "" {
 62				return ai.NewTextErrorResponse("content is required"), nil
 63			}
 64
 65			filePath := params.FilePath
 66			if !filepath.IsAbs(filePath) {
 67				filePath = filepath.Join(workingDir, filePath)
 68			}
 69
 70			fileInfo, err := os.Stat(filePath)
 71			if err == nil {
 72				if fileInfo.IsDir() {
 73					return ai.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
 74				}
 75
 76				modTime := fileInfo.ModTime()
 77				lastRead := getLastReadTime(filePath)
 78				if modTime.After(lastRead) {
 79					return ai.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
 80						filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
 81				}
 82
 83				oldContent, readErr := os.ReadFile(filePath)
 84				if readErr == nil && string(oldContent) == params.Content {
 85					return ai.NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
 86				}
 87			} else if !os.IsNotExist(err) {
 88				return ai.ToolResponse{}, fmt.Errorf("error checking file: %w", err)
 89			}
 90
 91			dir := filepath.Dir(filePath)
 92			if err = os.MkdirAll(dir, 0o755); err != nil {
 93				return ai.ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
 94			}
 95
 96			oldContent := ""
 97			if fileInfo != nil && !fileInfo.IsDir() {
 98				oldBytes, readErr := os.ReadFile(filePath)
 99				if readErr == nil {
100					oldContent = string(oldBytes)
101				}
102			}
103
104			sessionID := GetSessionFromContext(ctx)
105			if sessionID == "" {
106				return ai.ToolResponse{}, fmt.Errorf("session_id is required")
107			}
108
109			diff, additions, removals := diff.GenerateDiff(
110				oldContent,
111				params.Content,
112				strings.TrimPrefix(filePath, workingDir),
113			)
114
115			p := permissions.Request(
116				permission.CreatePermissionRequest{
117					SessionID:   sessionID,
118					Path:        fsext.PathOrPrefix(filePath, workingDir),
119					ToolCallID:  call.ID,
120					ToolName:    WriteToolName,
121					Action:      "write",
122					Description: fmt.Sprintf("Create file %s", filePath),
123					Params: WritePermissionsParams{
124						FilePath:   filePath,
125						OldContent: oldContent,
126						NewContent: params.Content,
127					},
128				},
129			)
130			if !p {
131				return ai.ToolResponse{}, permission.ErrorPermissionDenied
132			}
133
134			err = os.WriteFile(filePath, []byte(params.Content), 0o644)
135			if err != nil {
136				return ai.ToolResponse{}, fmt.Errorf("error writing file: %w", err)
137			}
138
139			// Check if file exists in history
140			file, err := files.GetByPathAndSession(ctx, filePath, sessionID)
141			if err != nil {
142				_, err = files.Create(ctx, sessionID, filePath, oldContent)
143				if err != nil {
144					// Log error but don't fail the operation
145					return ai.ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
146				}
147			}
148			if file.Content != oldContent {
149				// User Manually changed the content store an intermediate version
150				_, err = files.CreateVersion(ctx, sessionID, filePath, oldContent)
151				if err != nil {
152					slog.Debug("Error creating file history version", "error", err)
153				}
154			}
155			// Store the new version
156			_, err = files.CreateVersion(ctx, sessionID, filePath, params.Content)
157			if err != nil {
158				slog.Debug("Error creating file history version", "error", err)
159			}
160
161			recordFileWrite(filePath)
162			recordFileRead(filePath)
163
164			notifyLSPs(ctx, lspClients, params.FilePath)
165
166			result := fmt.Sprintf("File successfully written: %s", filePath)
167			result = fmt.Sprintf("<result>\n%s\n</result>", result)
168			result += getDiagnostics(filePath, lspClients)
169			return ai.WithResponseMetadata(ai.NewTextResponse(result),
170				WriteResponseMetadata{
171					Diff:      diff,
172					Additions: additions,
173					Removals:  removals,
174				},
175			), nil
176		})
177}