1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "os"
8 "path/filepath"
9 "strings"
10 "time"
11
12 "github.com/cloudwego/eino/components/tool"
13 "github.com/cloudwego/eino/schema"
14 "github.com/kujtimiihoxha/termai/internal/permission"
15)
16
// writeTool implements the "write" tool, which creates or overwrites a file on
// the local filesystem. Instances are built by NewWriteTool.
type writeTool struct {
	// workingDir is sent as the Path of every permission request the tool makes.
	workingDir string
}
20
// WriteToolName is the name the write tool reports from Info and uses as the
// ToolName of its permission requests.
const WriteToolName = "write"
24
25type WriteParams struct {
26 FilePath string `json:"file_path"`
27 Content string `json:"content"`
28}
29
30func (b *writeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
31 return &schema.ToolInfo{
32 Name: WriteToolName,
33 Desc: "Write a file to the local filesystem. Overwrites the existing file if there is one.\n\nBefore using this tool:\n\n1. Use the ReadFile tool to understand the file's contents and context\n\n2. Directory Verification (only applicable when creating new files):\n - Use the LS tool to verify the parent directory exists and is the correct location",
34 ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
35 "file_path": {
36 Type: "string",
37 Desc: "The absolute path to the file to write (must be absolute, not relative)",
38 Required: true,
39 },
40 "content": {
41 Type: "string",
42 Desc: "The content to write to the file",
43 Required: true,
44 },
45 }),
46 }, nil
47}
48
49func (b *writeTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
50 var params WriteParams
51 if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
52 return "", fmt.Errorf("failed to parse parameters: %w", err)
53 }
54
55 if params.FilePath == "" {
56 return "file_path is required", nil
57 }
58
59 if !filepath.IsAbs(params.FilePath) {
60 return fmt.Sprintf("file path must be absolute, got: %s", params.FilePath), nil
61 }
62
63 // fileExists := false
64 // oldContent := ""
65 fileInfo, err := os.Stat(params.FilePath)
66 if err == nil {
67 if fileInfo.IsDir() {
68 return fmt.Sprintf("path is a directory, not a file: %s", params.FilePath), nil
69 }
70
71 modTime := fileInfo.ModTime()
72 lastRead := getLastReadTime(params.FilePath)
73 if modTime.After(lastRead) {
74 return fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
75 params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)), nil
76 }
77
78 // oldContentBytes, readErr := os.ReadFile(params.FilePath)
79 // if readErr != nil {
80 // oldContent = string(oldContentBytes)
81 // }
82 } else if !os.IsNotExist(err) {
83 return fmt.Sprintf("failed to access file: %s", err), nil
84 }
85
86 p := permission.Default.Request(
87 permission.CreatePermissionRequest{
88 Path: b.workingDir,
89 ToolName: WriteToolName,
90 Action: "write",
91 Description: fmt.Sprintf("Write to file %s", params.FilePath),
92 Params: map[string]interface{}{
93 "file_path": params.FilePath,
94 "contnet": params.Content,
95 },
96 },
97 )
98 if !p {
99 return "", fmt.Errorf("permission denied")
100 }
101 dir := filepath.Dir(params.FilePath)
102 if err = os.MkdirAll(dir, 0o755); err != nil {
103 return fmt.Sprintf("failed to create parent directories: %s", err), nil
104 }
105
106 err = os.WriteFile(params.FilePath, []byte(params.Content), 0o644)
107 if err != nil {
108 return fmt.Sprintf("failed to write file: %s", err), nil
109 }
110
111 recordFileWrite(params.FilePath)
112
113 output := "File written: " + params.FilePath
114
115 // if fileExists && oldContent != params.Content {
116 // output = generateSimpleDiff(oldContent, params.Content)
117 // }
118
119 return output, nil
120}
121
122func generateSimpleDiff(oldContent, newContent string) string {
123 if oldContent == newContent {
124 return "[No changes]"
125 }
126
127 oldLines := strings.Split(oldContent, "\n")
128 newLines := strings.Split(newContent, "\n")
129
130 var diffBuilder strings.Builder
131 diffBuilder.WriteString(fmt.Sprintf("@@ -%d,+%d @@\n", len(oldLines), len(newLines)))
132
133 maxLines := max(len(oldLines), len(newLines))
134 for i := range maxLines {
135 oldLine := ""
136 newLine := ""
137
138 if i < len(oldLines) {
139 oldLine = oldLines[i]
140 }
141
142 if i < len(newLines) {
143 newLine = newLines[i]
144 }
145
146 if oldLine != newLine {
147 if i < len(oldLines) {
148 diffBuilder.WriteString(fmt.Sprintf("- %s\n", oldLine))
149 }
150 if i < len(newLines) {
151 diffBuilder.WriteString(fmt.Sprintf("+ %s\n", newLine))
152 }
153 } else {
154 diffBuilder.WriteString(fmt.Sprintf(" %s\n", oldLine))
155 }
156 }
157
158 return diffBuilder.String()
159}
160
161func NewWriteTool(workingDir string) tool.InvokableTool {
162 return &writeTool{
163 workingDir: workingDir,
164 }
165}