1package tools
2
3import (
4 "runtime"
5 "slices"
6 "testing"
7)
8
9func TestGetSafeReadOnlyCommands(t *testing.T) {
10 commands := getSafeReadOnlyCommands()
11
12 // Check that we have some commands
13 if len(commands) == 0 {
14 t.Fatal("Expected some safe commands, got none")
15 }
16
17 // Check for cross-platform commands that should always be present
18 crossPlatformCommands := []string{"echo", "hostname", "whoami", "git status", "go version"}
19 for _, cmd := range crossPlatformCommands {
20 found := slices.Contains(commands, cmd)
21 if !found {
22 t.Errorf("Expected cross-platform command %q to be in safe commands", cmd)
23 }
24 }
25
26 if runtime.GOOS == "windows" {
27 // Check for Windows-specific commands
28 windowsCommands := []string{"dir", "type", "Get-Process"}
29 for _, cmd := range windowsCommands {
30 found := slices.Contains(commands, cmd)
31 if !found {
32 t.Errorf("Expected Windows command %q to be in safe commands on Windows", cmd)
33 }
34 }
35
36 // Check that Unix commands are NOT present on Windows
37 unixCommands := []string{"ls", "pwd", "ps"}
38 for _, cmd := range unixCommands {
39 found := slices.Contains(commands, cmd)
40 if found {
41 t.Errorf("Unix command %q should not be in safe commands on Windows", cmd)
42 }
43 }
44 } else {
45 // Check for Unix-specific commands
46 unixCommands := []string{"ls", "pwd", "ps"}
47 for _, cmd := range unixCommands {
48 found := slices.Contains(commands, cmd)
49 if !found {
50 t.Errorf("Expected Unix command %q to be in safe commands on Unix", cmd)
51 }
52 }
53
54 // Check that Windows-specific commands are NOT present on Unix
55 windowsOnlyCommands := []string{"dir", "Get-Process", "systeminfo"}
56 for _, cmd := range windowsOnlyCommands {
57 found := slices.Contains(commands, cmd)
58 if found {
59 t.Errorf("Windows-only command %q should not be in safe commands on Unix", cmd)
60 }
61 }
62 }
63}
64
65func TestPlatformSpecificSafeCommands(t *testing.T) {
66 // Test that the function returns different results on different platforms
67 commands := getSafeReadOnlyCommands()
68
69 hasWindowsCommands := false
70 hasUnixCommands := false
71
72 for _, cmd := range commands {
73 if cmd == "dir" || cmd == "Get-Process" || cmd == "systeminfo" {
74 hasWindowsCommands = true
75 }
76 if cmd == "ls" || cmd == "ps" || cmd == "df" {
77 hasUnixCommands = true
78 }
79 }
80
81 if runtime.GOOS == "windows" {
82 if !hasWindowsCommands {
83 t.Error("Expected Windows commands on Windows platform")
84 }
85 if hasUnixCommands {
86 t.Error("Did not expect Unix commands on Windows platform")
87 }
88 } else {
89 if hasWindowsCommands {
90 t.Error("Did not expect Windows-only commands on Unix platform")
91 }
92 if !hasUnixCommands {
93 t.Error("Expected Unix commands on Unix platform")
94 }
95 }
96}
97
98func TestValidateCommand(t *testing.T) {
99 tests := []struct {
100 name string
101 command string
102 shouldError bool
103 }{
104 // Commands that should be blocked
105 {
106 name: "direct sudo",
107 command: "sudo ls",
108 shouldError: true,
109 },
110 {
111 name: "sudo in script",
112 command: "bash -c 'sudo ls'",
113 shouldError: true,
114 },
115 {
116 name: "sudo in command substitution",
117 command: "$(sudo whoami)",
118 shouldError: true,
119 },
120 {
121 name: "sudo in echo command substitution",
122 command: "echo $(sudo id)",
123 shouldError: true,
124 },
125 {
126 name: "sudo in command chain",
127 command: "ls && sudo rm file",
128 shouldError: true,
129 },
130 {
131 name: "sudo in if statement",
132 command: "if true; then sudo ls; fi",
133 shouldError: true,
134 },
135 {
136 name: "sudo in for loop",
137 command: "for i in 1; do sudo echo $i; done",
138 shouldError: true,
139 },
140 {
141 name: "direct curl",
142 command: "curl http://example.com",
143 shouldError: true,
144 },
145 {
146 name: "curl in script",
147 command: "bash -c 'curl malicious.com'",
148 shouldError: true,
149 },
150 {
151 name: "wget command",
152 command: "wget http://example.com",
153 shouldError: true,
154 },
155 {
156 name: "nc command",
157 command: "nc -l 8080",
158 shouldError: true,
159 },
160 // Commands that should be allowed
161 {
162 name: "simple ls",
163 command: "ls -la",
164 shouldError: false,
165 },
166 {
167 name: "echo command",
168 command: "echo hello",
169 shouldError: false,
170 },
171 {
172 name: "git status",
173 command: "git status",
174 shouldError: false,
175 },
176 {
177 name: "go build",
178 command: "go build",
179 shouldError: false,
180 },
181 {
182 name: "sudo as literal text",
183 command: "echo 'sudo is just text here'",
184 shouldError: false,
185 },
186 {
187 name: "complex allowed command",
188 command: "find . -name '*.go' | head -10",
189 shouldError: false,
190 },
191 {
192 name: "command with environment variables",
193 command: "FOO=bar go test",
194 shouldError: false,
195 },
196 }
197
198 for _, tt := range tests {
199 t.Run(tt.name, func(t *testing.T) {
200 err := validateCommand(tt.command)
201 if tt.shouldError && err == nil {
202 t.Errorf("Expected error for command %q, but got none", tt.command)
203 }
204 if !tt.shouldError && err != nil {
205 t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
206 }
207 })
208 }
209}
210
211func TestContainsBannedCommand(t *testing.T) {
212 // Test the helper functions directly with some edge cases
213 tests := []struct {
214 name string
215 command string
216 shouldError bool
217 }{
218 {
219 name: "nested command substitution",
220 command: "echo $(echo $(sudo id))",
221 shouldError: true,
222 },
223 {
224 name: "subshell with banned command",
225 command: "(sudo ls)",
226 shouldError: true,
227 },
228 {
229 name: "case statement with banned command",
230 command: "case $1 in start) sudo systemctl start service ;; esac",
231 shouldError: true,
232 },
233 {
234 name: "while loop with banned command",
235 command: "while true; do sudo echo test; done",
236 shouldError: true,
237 },
238 {
239 name: "function with banned command",
240 command: "function test() { sudo ls; }",
241 shouldError: true,
242 },
243 {
244 name: "complex valid command",
245 command: "if [ -f file ]; then echo exists; else echo missing; fi",
246 shouldError: false,
247 },
248 }
249
250 for _, tt := range tests {
251 t.Run(tt.name, func(t *testing.T) {
252 err := validateCommand(tt.command)
253 if tt.shouldError && err == nil {
254 t.Errorf("Expected error for command %q, but got none", tt.command)
255 }
256 if !tt.shouldError && err != nil {
257 t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
258 }
259 })
260 }
261}