write.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"os"
  8	"path/filepath"
  9	"strings"
 10	"time"
 11
 12	"github.com/cloudwego/eino/components/tool"
 13	"github.com/cloudwego/eino/schema"
 14	"github.com/kujtimiihoxha/termai/internal/permission"
 15)
 16
// writeTool is the tool implementation that writes files to the local
// filesystem.
type writeTool struct {
	// workingDir is the path sent as the Path of every permission request.
	workingDir string
}
 20
 21const (
 22	WriteToolName = "write"
 23)
 24
// WriteParams holds the arguments of a write tool call, decoded from JSON.
type WriteParams struct {
	// FilePath is the path of the file to write; the tool rejects empty and
	// non-absolute values.
	FilePath string `json:"file_path"`
	// Content is the full content written to the file.
	Content string `json:"content"`
}
 29
 30func (b *writeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
 31	return &schema.ToolInfo{
 32		Name: WriteToolName,
 33		Desc: "Write a file to the local filesystem. Overwrites the existing file if there is one.\n\nBefore using this tool:\n\n1. Use the ReadFile tool to understand the file's contents and context\n\n2. Directory Verification (only applicable when creating new files):\n   - Use the LS tool to verify the parent directory exists and is the correct location",
 34		ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
 35			"file_path": {
 36				Type:     "string",
 37				Desc:     "The absolute path to the file to write (must be absolute, not relative)",
 38				Required: true,
 39			},
 40			"content": {
 41				Type:     "string",
 42				Desc:     "The content to write to the file",
 43				Required: true,
 44			},
 45		}),
 46	}, nil
 47}
 48
 49func (b *writeTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
 50	var params WriteParams
 51	if err := json.Unmarshal([]byte(args), &params); err != nil {
 52		return "", fmt.Errorf("failed to parse parameters: %w", err)
 53	}
 54
 55	if params.FilePath == "" {
 56		return "file_path is required", nil
 57	}
 58
 59	if !filepath.IsAbs(params.FilePath) {
 60		return fmt.Sprintf("file path must be absolute, got: %s", params.FilePath), nil
 61	}
 62
 63	// fileExists := false
 64	// oldContent := ""
 65	fileInfo, err := os.Stat(params.FilePath)
 66	if err == nil {
 67		if fileInfo.IsDir() {
 68			return fmt.Sprintf("path is a directory, not a file: %s", params.FilePath), nil
 69		}
 70
 71		modTime := fileInfo.ModTime()
 72		lastRead := getLastReadTime(params.FilePath)
 73		if modTime.After(lastRead) {
 74			return fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
 75				params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)), nil
 76		}
 77
 78		// oldContentBytes, readErr := os.ReadFile(params.FilePath)
 79		// if readErr != nil {
 80		// 	oldContent = string(oldContentBytes)
 81		// }
 82	} else if !os.IsNotExist(err) {
 83		return fmt.Sprintf("failed to access file: %s", err), nil
 84	}
 85
 86	p := permission.Default.Request(
 87		permission.CreatePermissionRequest{
 88			Path:        b.workingDir,
 89			ToolName:    WriteToolName,
 90			Action:      "write",
 91			Description: fmt.Sprintf("Write to file %s", params.FilePath),
 92			Params: map[string]interface{}{
 93				"file_path": params.FilePath,
 94				"contnet":   params.Content,
 95			},
 96		},
 97	)
 98	if !p {
 99		return "", fmt.Errorf("permission denied")
100	}
101	dir := filepath.Dir(params.FilePath)
102	if err = os.MkdirAll(dir, 0o755); err != nil {
103		return fmt.Sprintf("failed to create parent directories: %s", err), nil
104	}
105
106	err = os.WriteFile(params.FilePath, []byte(params.Content), 0o644)
107	if err != nil {
108		return fmt.Sprintf("failed to write file: %s", err), nil
109	}
110
111	recordFileWrite(params.FilePath)
112
113	output := "File written: " + params.FilePath
114
115	// if fileExists && oldContent != params.Content {
116	// 	output = generateSimpleDiff(oldContent, params.Content)
117	// }
118
119	return output, nil
120}
121
122func generateSimpleDiff(oldContent, newContent string) string {
123	if oldContent == newContent {
124		return "[No changes]"
125	}
126
127	oldLines := strings.Split(oldContent, "\n")
128	newLines := strings.Split(newContent, "\n")
129
130	var diffBuilder strings.Builder
131	diffBuilder.WriteString(fmt.Sprintf("@@ -%d,+%d @@\n", len(oldLines), len(newLines)))
132
133	maxLines := max(len(oldLines), len(newLines))
134	for i := range maxLines {
135		oldLine := ""
136		newLine := ""
137
138		if i < len(oldLines) {
139			oldLine = oldLines[i]
140		}
141
142		if i < len(newLines) {
143			newLine = newLines[i]
144		}
145
146		if oldLine != newLine {
147			if i < len(oldLines) {
148				diffBuilder.WriteString(fmt.Sprintf("- %s\n", oldLine))
149			}
150			if i < len(newLines) {
151				diffBuilder.WriteString(fmt.Sprintf("+ %s\n", newLine))
152			}
153		} else {
154			diffBuilder.WriteString(fmt.Sprintf("  %s\n", oldLine))
155		}
156	}
157
158	return diffBuilder.String()
159}
160
161func NewWriteTool(workingDir string) tool.InvokableTool {
162	return &writeTool{
163		workingDir: workingDir,
164	}
165}