@@ -275,30 +275,29 @@ Important:
func blockFuncs() []shell.BlockFunc {
return []shell.BlockFunc{
shell.CommandsBlocker(bannedCommands),
- shell.ArgumentsBlocker([][]string{
- // System package managers
- {"apk", "add"},
- {"apt", "install"},
- {"apt-get", "install"},
- {"dnf", "install"},
- {"pacman", "-S"},
- {"pkg", "install"},
- {"yum", "install"},
- {"zypper", "install"},
-
- // Language-specific package managers
- {"brew", "install"},
- {"cargo", "install"},
- {"gem", "install"},
- {"go", "install"},
- {"npm", "install", "-g"},
- {"npm", "install", "--global"},
- {"pip", "install", "--user"},
- {"pip3", "install", "--user"},
- {"pnpm", "add", "-g"},
- {"pnpm", "add", "--global"},
- {"yarn", "global", "add"},
- }),
+
+ // System package managers
+ shell.ArgumentsBlocker("apk", []string{"add"}, nil),
+ shell.ArgumentsBlocker("apt", []string{"install"}, nil),
+ shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
+ shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
+ shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
+ shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
+ shell.ArgumentsBlocker("yum", []string{"install"}, nil),
+ shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
+
+ // Language-specific package managers
+ shell.ArgumentsBlocker("brew", []string{"install"}, nil),
+ shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
+ shell.ArgumentsBlocker("gem", []string{"install"}, nil),
+ shell.ArgumentsBlocker("go", []string{"install"}, nil),
+ shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
+ shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
+ shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
+ shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
+ shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
+ shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
+ shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
}
}
@@ -4,6 +4,8 @@ import (
"context"
"strings"
"testing"
+
+ "github.com/stretchr/testify/require"
)
func TestCommandBlocking(t *testing.T) {
@@ -56,10 +58,7 @@ func TestCommandBlocking(t *testing.T) {
{
name: "block npm global install with -g",
blockFuncs: []BlockFunc{
- ArgumentsBlocker([][]string{
- {"npm", "install", "-g"},
- {"npm", "install", "--global"},
- }),
+ ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
},
command: "npm install -g typescript",
shouldBlock: true,
@@ -67,10 +66,7 @@ func TestCommandBlocking(t *testing.T) {
{
name: "block npm global install with --global",
blockFuncs: []BlockFunc{
- ArgumentsBlocker([][]string{
- {"npm", "install", "-g"},
- {"npm", "install", "--global"},
- }),
+ ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
},
command: "npm install --global typescript",
shouldBlock: true,
@@ -78,10 +74,8 @@ func TestCommandBlocking(t *testing.T) {
{
name: "allow npm local install",
blockFuncs: []BlockFunc{
- ArgumentsBlocker([][]string{
- {"npm", "install", "-g"},
- {"npm", "install", "--global"},
- }),
+ ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
+ ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
},
command: "npm install typescript",
shouldBlock: false,
@@ -116,3 +110,232 @@ func TestCommandBlocking(t *testing.T) {
})
}
}
+
+func TestArgumentsBlocker(t *testing.T) {
+ tests := []struct {
+ name string
+ cmd string
+ args []string
+ flags []string
+ input []string
+ shouldBlock bool
+ }{
+ // Basic command blocking
+ {
+ name: "block exact command match",
+ cmd: "npm",
+ args: []string{"install"},
+ flags: nil,
+ input: []string{"npm", "install", "package"},
+ shouldBlock: true,
+ },
+ {
+ name: "allow different command",
+ cmd: "npm",
+ args: []string{"install"},
+ flags: nil,
+ input: []string{"yarn", "install", "package"},
+ shouldBlock: false,
+ },
+ {
+ name: "allow different subcommand",
+ cmd: "npm",
+ args: []string{"install"},
+ flags: nil,
+ input: []string{"npm", "list"},
+ shouldBlock: false,
+ },
+
+ // Flag-based blocking
+ {
+ name: "block with single flag",
+ cmd: "npm",
+ args: []string{"install"},
+ flags: []string{"-g"},
+ input: []string{"npm", "install", "-g", "typescript"},
+ shouldBlock: true,
+ },
+ {
+ name: "block with flag in different position",
+ cmd: "npm",
+ args: []string{"install"},
+ flags: []string{"-g"},
+ input: []string{"npm", "install", "typescript", "-g"},
+ shouldBlock: true,
+ },
+ {
+ name: "allow without required flag",
+ cmd: "npm",
+ args: []string{"install"},
+ flags: []string{"-g"},
+ input: []string{"npm", "install", "typescript"},
+ shouldBlock: false,
+ },
+ {
+ name: "block with multiple flags",
+ cmd: "pip",
+ args: []string{"install"},
+ flags: []string{"--user"},
+ input: []string{"pip", "install", "--user", "--upgrade", "package"},
+ shouldBlock: true,
+ },
+
+ // Complex argument patterns
+ {
+ name: "block multi-arg subcommand",
+ cmd: "yarn",
+ args: []string{"global", "add"},
+ flags: nil,
+ input: []string{"yarn", "global", "add", "typescript"},
+ shouldBlock: true,
+ },
+ {
+ name: "allow partial multi-arg match",
+ cmd: "yarn",
+ args: []string{"global", "add"},
+ flags: nil,
+ input: []string{"yarn", "global", "list"},
+ shouldBlock: false,
+ },
+
+ // Edge cases
+ {
+ name: "handle empty input",
+ cmd: "npm",
+ args: []string{"install"},
+ flags: nil,
+ input: []string{},
+ shouldBlock: false,
+ },
+ {
+ name: "handle command only",
+ cmd: "npm",
+ args: []string{"install"},
+ flags: nil,
+ input: []string{"npm"},
+ shouldBlock: false,
+ },
+ {
+ name: "block pacman with -S flag",
+ cmd: "pacman",
+ args: nil,
+ flags: []string{"-S"},
+ input: []string{"pacman", "-S", "package"},
+ shouldBlock: true,
+ },
+ {
+ name: "allow pacman without -S flag",
+ cmd: "pacman",
+ args: nil,
+ flags: []string{"-S"},
+ input: []string{"pacman", "-Q", "package"},
+ shouldBlock: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ blocker := ArgumentsBlocker(tt.cmd, tt.args, tt.flags)
+ result := blocker(tt.input)
+ require.Equal(t, tt.shouldBlock, result,
+ "Expected block=%v for input %v", tt.shouldBlock, tt.input)
+ })
+ }
+}
+
+func TestCommandsBlocker(t *testing.T) {
+ tests := []struct {
+ name string
+ banned []string
+ input []string
+ shouldBlock bool
+ }{
+ {
+ name: "block single banned command",
+ banned: []string{"curl"},
+ input: []string{"curl", "https://example.com"},
+ shouldBlock: true,
+ },
+ {
+ name: "allow non-banned command",
+ banned: []string{"curl", "wget"},
+ input: []string{"echo", "hello"},
+ shouldBlock: false,
+ },
+ {
+ name: "block from multiple banned",
+ banned: []string{"curl", "wget", "nc"},
+ input: []string{"wget", "https://example.com"},
+ shouldBlock: true,
+ },
+ {
+ name: "handle empty input",
+ banned: []string{"curl"},
+ input: []string{},
+ shouldBlock: false,
+ },
+ {
+ name: "case sensitive matching",
+ banned: []string{"curl"},
+ input: []string{"CURL", "https://example.com"},
+ shouldBlock: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ blocker := CommandsBlocker(tt.banned)
+ result := blocker(tt.input)
+ require.Equal(t, tt.shouldBlock, result,
+ "Expected block=%v for input %v", tt.shouldBlock, tt.input)
+ })
+ }
+}
+
+func TestSplitArgsFlags(t *testing.T) {
+ tests := []struct {
+ name string
+ input []string
+ wantArgs []string
+ wantFlags []string
+ }{
+ {
+ name: "only args",
+ input: []string{"install", "package", "another"},
+ wantArgs: []string{"install", "package", "another"},
+ wantFlags: []string{},
+ },
+ {
+ name: "only flags",
+ input: []string{"-g", "--verbose", "-f"},
+ wantArgs: []string{},
+ wantFlags: []string{"-g", "--verbose", "-f"},
+ },
+ {
+ name: "mixed args and flags",
+ input: []string{"install", "-g", "package", "--verbose"},
+ wantArgs: []string{"install", "package"},
+ wantFlags: []string{"-g", "--verbose"},
+ },
+ {
+ name: "empty input",
+ input: []string{},
+ wantArgs: []string{},
+ wantFlags: []string{},
+ },
+ {
+ name: "single dash flag",
+ input: []string{"-S", "package"},
+ wantArgs: []string{"package"},
+ wantFlags: []string{"-S"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ args, flags := splitArgsFlags(tt.input)
+ require.Equal(t, tt.wantArgs, args, "args mismatch")
+ require.Equal(t, tt.wantFlags, flags, "flags mismatch")
+ })
+ }
+}