1package tools
  2
  3import (
  4	"runtime"
  5	"slices"
  6	"testing"
  7)
  8
  9func TestGetSafeReadOnlyCommands(t *testing.T) {
 10	commands := getSafeReadOnlyCommands()
 11
 12	// Check that we have some commands
 13	if len(commands) == 0 {
 14		t.Fatal("Expected some safe commands, got none")
 15	}
 16
 17	// Check for cross-platform commands that should always be present
 18	crossPlatformCommands := []string{"echo", "hostname", "whoami", "git status", "go version"}
 19	for _, cmd := range crossPlatformCommands {
 20		found := slices.Contains(commands, cmd)
 21		if !found {
 22			t.Errorf("Expected cross-platform command %q to be in safe commands", cmd)
 23		}
 24	}
 25
 26	if runtime.GOOS == "windows" {
 27		// Check for Windows-specific commands
 28		windowsCommands := []string{"dir", "type", "Get-Process"}
 29		for _, cmd := range windowsCommands {
 30			found := slices.Contains(commands, cmd)
 31			if !found {
 32				t.Errorf("Expected Windows command %q to be in safe commands on Windows", cmd)
 33			}
 34		}
 35
 36		// Check that Unix commands are NOT present on Windows
 37		unixCommands := []string{"ls", "pwd", "ps"}
 38		for _, cmd := range unixCommands {
 39			found := slices.Contains(commands, cmd)
 40			if found {
 41				t.Errorf("Unix command %q should not be in safe commands on Windows", cmd)
 42			}
 43		}
 44	} else {
 45		// Check for Unix-specific commands
 46		unixCommands := []string{"ls", "pwd", "ps"}
 47		for _, cmd := range unixCommands {
 48			found := slices.Contains(commands, cmd)
 49			if !found {
 50				t.Errorf("Expected Unix command %q to be in safe commands on Unix", cmd)
 51			}
 52		}
 53
 54		// Check that Windows-specific commands are NOT present on Unix
 55		windowsOnlyCommands := []string{"dir", "Get-Process", "systeminfo"}
 56		for _, cmd := range windowsOnlyCommands {
 57			found := slices.Contains(commands, cmd)
 58			if found {
 59				t.Errorf("Windows-only command %q should not be in safe commands on Unix", cmd)
 60			}
 61		}
 62	}
 63}
 64
 65func TestPlatformSpecificSafeCommands(t *testing.T) {
 66	// Test that the function returns different results on different platforms
 67	commands := getSafeReadOnlyCommands()
 68
 69	hasWindowsCommands := false
 70	hasUnixCommands := false
 71
 72	for _, cmd := range commands {
 73		if cmd == "dir" || cmd == "Get-Process" || cmd == "systeminfo" {
 74			hasWindowsCommands = true
 75		}
 76		if cmd == "ls" || cmd == "ps" || cmd == "df" {
 77			hasUnixCommands = true
 78		}
 79	}
 80
 81	if runtime.GOOS == "windows" {
 82		if !hasWindowsCommands {
 83			t.Error("Expected Windows commands on Windows platform")
 84		}
 85		if hasUnixCommands {
 86			t.Error("Did not expect Unix commands on Windows platform")
 87		}
 88	} else {
 89		if hasWindowsCommands {
 90			t.Error("Did not expect Windows-only commands on Unix platform")
 91		}
 92		if !hasUnixCommands {
 93			t.Error("Expected Unix commands on Unix platform")
 94		}
 95	}
 96}
 97
 98func TestValidateCommand(t *testing.T) {
 99	tests := []struct {
100		name        string
101		command     string
102		shouldError bool
103	}{
104		// Commands that should be blocked
105		{
106			name:        "direct sudo",
107			command:     "sudo ls",
108			shouldError: true,
109		},
110		{
111			name:        "sudo in script",
112			command:     "bash -c 'sudo ls'",
113			shouldError: true,
114		},
115		{
116			name:        "sudo in command substitution",
117			command:     "$(sudo whoami)",
118			shouldError: true,
119		},
120		{
121			name:        "sudo in echo command substitution",
122			command:     "echo $(sudo id)",
123			shouldError: true,
124		},
125		{
126			name:        "sudo in command chain",
127			command:     "ls && sudo rm file",
128			shouldError: true,
129		},
130		{
131			name:        "sudo in if statement",
132			command:     "if true; then sudo ls; fi",
133			shouldError: true,
134		},
135		{
136			name:        "sudo in for loop",
137			command:     "for i in 1; do sudo echo $i; done",
138			shouldError: true,
139		},
140		{
141			name:        "direct curl",
142			command:     "curl http://example.com",
143			shouldError: true,
144		},
145		{
146			name:        "curl in script",
147			command:     "bash -c 'curl malicious.com'",
148			shouldError: true,
149		},
150		{
151			name:        "wget command",
152			command:     "wget http://example.com",
153			shouldError: true,
154		},
155		{
156			name:        "nc command",
157			command:     "nc -l 8080",
158			shouldError: true,
159		},
160		// Commands that should be allowed
161		{
162			name:        "simple ls",
163			command:     "ls -la",
164			shouldError: false,
165		},
166		{
167			name:        "echo command",
168			command:     "echo hello",
169			shouldError: false,
170		},
171		{
172			name:        "git status",
173			command:     "git status",
174			shouldError: false,
175		},
176		{
177			name:        "go build",
178			command:     "go build",
179			shouldError: false,
180		},
181		{
182			name:        "sudo as literal text",
183			command:     "echo 'sudo is just text here'",
184			shouldError: false,
185		},
186		{
187			name:        "complex allowed command",
188			command:     "find . -name '*.go' | head -10",
189			shouldError: false,
190		},
191		{
192			name:        "command with environment variables",
193			command:     "FOO=bar go test",
194			shouldError: false,
195		},
196	}
197
198	for _, tt := range tests {
199		t.Run(tt.name, func(t *testing.T) {
200			err := validateCommand(tt.command)
201			if tt.shouldError && err == nil {
202				t.Errorf("Expected error for command %q, but got none", tt.command)
203			}
204			if !tt.shouldError && err != nil {
205				t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
206			}
207		})
208	}
209}
210
211func TestContainsBannedCommand(t *testing.T) {
212	// Test the helper functions directly with some edge cases
213	tests := []struct {
214		name        string
215		command     string
216		shouldError bool
217	}{
218		{
219			name:        "nested command substitution",
220			command:     "echo $(echo $(sudo id))",
221			shouldError: true,
222		},
223		{
224			name:        "subshell with banned command",
225			command:     "(sudo ls)",
226			shouldError: true,
227		},
228		{
229			name:        "case statement with banned command",
230			command:     "case $1 in start) sudo systemctl start service ;; esac",
231			shouldError: true,
232		},
233		{
234			name:        "while loop with banned command",
235			command:     "while true; do sudo echo test; done",
236			shouldError: true,
237		},
238		{
239			name:        "function with banned command",
240			command:     "function test() { sudo ls; }",
241			shouldError: true,
242		},
243		{
244			name:        "complex valid command",
245			command:     "if [ -f file ]; then echo exists; else echo missing; fi",
246			shouldError: false,
247		},
248	}
249
250	for _, tt := range tests {
251		t.Run(tt.name, func(t *testing.T) {
252			err := validateCommand(tt.command)
253			if tt.shouldError && err == nil {
254				t.Errorf("Expected error for command %q, but got none", tt.command)
255			}
256			if !tt.shouldError && err != nil {
257				t.Errorf("Expected no error for command %q, but got: %v", tt.command, err)
258			}
259		})
260	}
261}