diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 4b4a1ba526ae1af60f2c4a2b9307afe37eca2e58..add62210f7179a15912959a07c64f5da333b582d 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -275,30 +275,29 @@ Important: func blockFuncs() []shell.BlockFunc { return []shell.BlockFunc{ shell.CommandsBlocker(bannedCommands), - shell.ArgumentsBlocker([][]string{ - // System package managers - {"apk", "add"}, - {"apt", "install"}, - {"apt-get", "install"}, - {"dnf", "install"}, - {"pacman", "-S"}, - {"pkg", "install"}, - {"yum", "install"}, - {"zypper", "install"}, - - // Language-specific package managers - {"brew", "install"}, - {"cargo", "install"}, - {"gem", "install"}, - {"go", "install"}, - {"npm", "install", "-g"}, - {"npm", "install", "--global"}, - {"pip", "install", "--user"}, - {"pip3", "install", "--user"}, - {"pnpm", "add", "-g"}, - {"pnpm", "add", "--global"}, - {"yarn", "global", "add"}, - }), + + // System package managers + shell.ArgumentsBlocker("apk", []string{"add"}, nil), + shell.ArgumentsBlocker("apt", []string{"install"}, nil), + shell.ArgumentsBlocker("apt-get", []string{"install"}, nil), + shell.ArgumentsBlocker("dnf", []string{"install"}, nil), + shell.ArgumentsBlocker("pacman", nil, []string{"-S"}), + shell.ArgumentsBlocker("pkg", []string{"install"}, nil), + shell.ArgumentsBlocker("yum", []string{"install"}, nil), + shell.ArgumentsBlocker("zypper", []string{"install"}, nil), + + // Language-specific package managers + shell.ArgumentsBlocker("brew", []string{"install"}, nil), + shell.ArgumentsBlocker("cargo", []string{"install"}, nil), + shell.ArgumentsBlocker("gem", []string{"install"}, nil), + shell.ArgumentsBlocker("go", []string{"install"}, nil), + shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}), + shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}), + shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}), + shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}), + shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}), + shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}), + shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil), } } diff --git a/internal/shell/command_block_test.go b/internal/shell/command_block_test.go index 0d29b61e15091d9102d69cef8b84b610e98365b6..99eb5256446f604bb473f3130fd24950d1e058b3 100644 --- a/internal/shell/command_block_test.go +++ b/internal/shell/command_block_test.go @@ -4,6 +4,8 @@ import ( "context" "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestCommandBlocking(t *testing.T) { @@ -56,10 +58,7 @@ func TestCommandBlocking(t *testing.T) { { name: "block npm global install with -g", blockFuncs: []BlockFunc{ - ArgumentsBlocker([][]string{ - {"npm", "install", "-g"}, - {"npm", "install", "--global"}, - }), + ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}), }, command: "npm install -g typescript", shouldBlock: true, @@ -67,10 +66,7 @@ func TestCommandBlocking(t *testing.T) { { name: "block npm global install with --global", blockFuncs: []BlockFunc{ - ArgumentsBlocker([][]string{ - {"npm", "install", "-g"}, - {"npm", "install", "--global"}, - }), + ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}), }, command: "npm install --global typescript", shouldBlock: true, @@ -78,10 +74,8 @@ func TestCommandBlocking(t *testing.T) { { name: "allow npm local install", blockFuncs: []BlockFunc{ - ArgumentsBlocker([][]string{ - {"npm", "install", "-g"}, - {"npm", "install", "--global"}, - }), + ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}), + ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}), }, command: "npm install typescript", shouldBlock: false, @@ -116,3 +110,232 @@ func TestCommandBlocking(t *testing.T) { }) } } + +func TestArgumentsBlocker(t *testing.T) { + tests := []struct { + name string + cmd string + args []string + flags []string + input []string + shouldBlock bool + }{ + // Basic command blocking + { + name: "block exact command match", + cmd: "npm", + args: []string{"install"}, + flags: nil, + input: []string{"npm", "install", "package"}, + shouldBlock: true, + }, + { + name: "allow different command", + cmd: "npm", + args: []string{"install"}, + flags: nil, + input: []string{"yarn", "install", "package"}, + shouldBlock: false, + }, + { + name: "allow different subcommand", + cmd: "npm", + args: []string{"install"}, + flags: nil, + input: []string{"npm", "list"}, + shouldBlock: false, + }, + + // Flag-based blocking + { + name: "block with single flag", + cmd: "npm", + args: []string{"install"}, + flags: []string{"-g"}, + input: []string{"npm", "install", "-g", "typescript"}, + shouldBlock: true, + }, + { + name: "block with flag in different position", + cmd: "npm", + args: []string{"install"}, + flags: []string{"-g"}, + input: []string{"npm", "install", "typescript", "-g"}, + shouldBlock: true, + }, + { + name: "allow without required flag", + cmd: "npm", + args: []string{"install"}, + flags: []string{"-g"}, + input: []string{"npm", "install", "typescript"}, + shouldBlock: false, + }, + { + name: "block with multiple flags", + cmd: "pip", + args: []string{"install"}, + flags: []string{"--user"}, + input: []string{"pip", "install", "--user", "--upgrade", "package"}, + shouldBlock: true, + }, + + // Complex argument patterns + { + name: "block multi-arg subcommand", + cmd: "yarn", + args: []string{"global", "add"}, + flags: nil, + input: []string{"yarn", "global", "add", "typescript"}, + shouldBlock: true, + }, + { + name: "allow partial multi-arg match", + cmd: "yarn", + args: []string{"global", "add"}, + flags: nil, + input: []string{"yarn", "global", "list"}, + shouldBlock: false, + }, + + // Edge cases + { + name: "handle empty input", + cmd: "npm", + args: []string{"install"}, + flags: nil, + input: []string{}, + shouldBlock: false, + }, + { + name: "handle command only", + cmd: "npm", + args: []string{"install"}, + flags: nil, + input: []string{"npm"}, + shouldBlock: false, + }, + { + name: "block pacman with -S flag", + cmd: "pacman", + args: nil, + flags: []string{"-S"}, + input: []string{"pacman", "-S", "package"}, + shouldBlock: true, + }, + { + name: "allow pacman without -S flag", + cmd: "pacman", + args: nil, + flags: []string{"-S"}, + input: []string{"pacman", "-Q", "package"}, + shouldBlock: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + blocker := ArgumentsBlocker(tt.cmd, tt.args, tt.flags) + result := blocker(tt.input) + require.Equal(t, tt.shouldBlock, result, + "Expected block=%v for input %v", tt.shouldBlock, tt.input) + }) + } +} + +func TestCommandsBlocker(t *testing.T) { + tests := []struct { + name string + banned []string + input []string + shouldBlock bool + }{ + { + name: "block single banned command", + banned: []string{"curl"}, + input: []string{"curl", "https://example.com"}, + shouldBlock: true, + }, + { + name: "allow non-banned command", + banned: []string{"curl", "wget"}, + input: []string{"echo", "hello"}, + shouldBlock: false, + }, + { + name: "block from multiple banned", + banned: []string{"curl", "wget", "nc"}, + input: []string{"wget", "https://example.com"}, + shouldBlock: true, + }, + { + name: "handle empty input", + banned: []string{"curl"}, + input: []string{}, + shouldBlock: false, + }, + { + name: "case sensitive matching", + banned: []string{"curl"}, + input: []string{"CURL", "https://example.com"}, + shouldBlock: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + blocker := CommandsBlocker(tt.banned) + result := blocker(tt.input) + require.Equal(t, tt.shouldBlock, result, + "Expected block=%v for input %v", tt.shouldBlock, tt.input) + }) + } +} + +func TestSplitArgsFlags(t *testing.T) { + tests := []struct { + name string + input []string + wantArgs []string + wantFlags []string + }{ + { + name: "only args", + input: []string{"install", "package", "another"}, + wantArgs: []string{"install", "package", "another"}, + wantFlags: []string{}, + }, + { + name: "only flags", + input: []string{"-g", "--verbose", "-f"}, + wantArgs: []string{}, + wantFlags: []string{"-g", "--verbose", "-f"}, + }, + { + name: "mixed args and flags", + input: []string{"install", "-g", "package", "--verbose"}, + wantArgs: []string{"install", "package"}, + wantFlags: []string{"-g", "--verbose"}, + }, + { + name: "empty input", + input: []string{}, + wantArgs: []string{}, + wantFlags: []string{}, + }, + { + name: "single dash flag", + input: []string{"-S", "package"}, + wantArgs: []string{"package"}, + wantFlags: []string{"-S"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args, flags := splitArgsFlags(tt.input) + require.Equal(t, tt.wantArgs, args, "args mismatch") + require.Equal(t, tt.wantFlags, flags, "flags mismatch") + }) + } +} diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 9798eaa8c417fc54e5d0213547f6c9b644f70600..73dbfbffcc65cffba3e96443076dd80e940d7742 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -16,9 +16,11 @@ import ( "errors" "fmt" "os" + "slices" "strings" "sync" + "github.com/charmbracelet/crush/internal/slicesext" "mvdan.cc/sh/moreinterp/coreutils" "mvdan.cc/sh/v3/expand" "mvdan.cc/sh/v3/interp" @@ -171,25 +173,36 @@ func CommandsBlocker(cmds []string) BlockFunc { } } -// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands -func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc { - return func(args []string) bool { - for _, blocked := range blockedSubCommands { - if len(args) >= len(blocked) { - match := true - for i, part := range blocked { - if args[i] != part { - match = false - break - } - } - if match { - return true - } - } +// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand +func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc { + return func(parts []string) bool { + if len(parts) == 0 || parts[0] != cmd { + return false + } + + argParts, flagParts := splitArgsFlags(parts[1:]) + if len(argParts) < len(args) || len(flagParts) < len(flags) { + return false + } + + argsMatch := slices.Equal(argParts[:len(args)], args) + flagsMatch := slicesext.IsSubset(flags, flagParts) + + return argsMatch && flagsMatch + } +} + +func splitArgsFlags(parts []string) (args []string, flags []string) { + args = make([]string, 0, len(parts)) + flags = make([]string, 0, len(parts)) + for _, part := range parts { + if strings.HasPrefix(part, "-") { + flags = append(flags, part) + } else { + args = append(args, part) } - return false } + return } func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {