fix(shell): refactor arguments blocker to check for flags in any position

Andrey Nering created

Change summary

internal/llm/tools/bash.go           |  47 ++--
internal/shell/command_block_test.go | 247 ++++++++++++++++++++++++++++-
internal/shell/shell.go              |  47 +++--
3 files changed, 288 insertions(+), 53 deletions(-)

Detailed changes

internal/llm/tools/bash.go 🔗

@@ -275,30 +275,29 @@ Important:
 func blockFuncs() []shell.BlockFunc {
 	return []shell.BlockFunc{
 		shell.CommandsBlocker(bannedCommands),
-		shell.ArgumentsBlocker([][]string{
-			// System package managers
-			{"apk", "add"},
-			{"apt", "install"},
-			{"apt-get", "install"},
-			{"dnf", "install"},
-			{"pacman", "-S"},
-			{"pkg", "install"},
-			{"yum", "install"},
-			{"zypper", "install"},
-
-			// Language-specific package managers
-			{"brew", "install"},
-			{"cargo", "install"},
-			{"gem", "install"},
-			{"go", "install"},
-			{"npm", "install", "-g"},
-			{"npm", "install", "--global"},
-			{"pip", "install", "--user"},
-			{"pip3", "install", "--user"},
-			{"pnpm", "add", "-g"},
-			{"pnpm", "add", "--global"},
-			{"yarn", "global", "add"},
-		}),
+
+		// System package managers
+		shell.ArgumentsBlocker("apk", []string{"add"}, nil),
+		shell.ArgumentsBlocker("apt", []string{"install"}, nil),
+		shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
+		shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
+		shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
+		shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
+		shell.ArgumentsBlocker("yum", []string{"install"}, nil),
+		shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
+
+		// Language-specific package managers
+		shell.ArgumentsBlocker("brew", []string{"install"}, nil),
+		shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
+		shell.ArgumentsBlocker("gem", []string{"install"}, nil),
+		shell.ArgumentsBlocker("go", []string{"install"}, nil),
+		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
+		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
+		shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
+		shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
+		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
+		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
+		shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
 	}
 }
 

internal/shell/command_block_test.go 🔗

@@ -4,6 +4,8 @@ import (
 	"context"
 	"strings"
 	"testing"
+
+	"github.com/stretchr/testify/require"
 )
 
 func TestCommandBlocking(t *testing.T) {
@@ -56,10 +58,7 @@ func TestCommandBlocking(t *testing.T) {
 		{
 			name: "block npm global install with -g",
 			blockFuncs: []BlockFunc{
-				ArgumentsBlocker([][]string{
-					{"npm", "install", "-g"},
-					{"npm", "install", "--global"},
-				}),
+				ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
 			},
 			command:     "npm install -g typescript",
 			shouldBlock: true,
@@ -67,10 +66,7 @@ func TestCommandBlocking(t *testing.T) {
 		{
 			name: "block npm global install with --global",
 			blockFuncs: []BlockFunc{
-				ArgumentsBlocker([][]string{
-					{"npm", "install", "-g"},
-					{"npm", "install", "--global"},
-				}),
+				ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
 			},
 			command:     "npm install --global typescript",
 			shouldBlock: true,
@@ -78,10 +74,8 @@ func TestCommandBlocking(t *testing.T) {
 		{
 			name: "allow npm local install",
 			blockFuncs: []BlockFunc{
-				ArgumentsBlocker([][]string{
-					{"npm", "install", "-g"},
-					{"npm", "install", "--global"},
-				}),
+				ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
+				ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
 			},
 			command:     "npm install typescript",
 			shouldBlock: false,
@@ -116,3 +110,232 @@ func TestCommandBlocking(t *testing.T) {
 		})
 	}
 }
+
+func TestArgumentsBlocker(t *testing.T) {
+	tests := []struct {
+		name        string
+		cmd         string
+		args        []string
+		flags       []string
+		input       []string
+		shouldBlock bool
+	}{
+		// Basic command blocking
+		{
+			name:        "block exact command match",
+			cmd:         "npm",
+			args:        []string{"install"},
+			flags:       nil,
+			input:       []string{"npm", "install", "package"},
+			shouldBlock: true,
+		},
+		{
+			name:        "allow different command",
+			cmd:         "npm",
+			args:        []string{"install"},
+			flags:       nil,
+			input:       []string{"yarn", "install", "package"},
+			shouldBlock: false,
+		},
+		{
+			name:        "allow different subcommand",
+			cmd:         "npm",
+			args:        []string{"install"},
+			flags:       nil,
+			input:       []string{"npm", "list"},
+			shouldBlock: false,
+		},
+
+		// Flag-based blocking
+		{
+			name:        "block with single flag",
+			cmd:         "npm",
+			args:        []string{"install"},
+			flags:       []string{"-g"},
+			input:       []string{"npm", "install", "-g", "typescript"},
+			shouldBlock: true,
+		},
+		{
+			name:        "block with flag in different position",
+			cmd:         "npm",
+			args:        []string{"install"},
+			flags:       []string{"-g"},
+			input:       []string{"npm", "install", "typescript", "-g"},
+			shouldBlock: true,
+		},
+		{
+			name:        "allow without required flag",
+			cmd:         "npm",
+			args:        []string{"install"},
+			flags:       []string{"-g"},
+			input:       []string{"npm", "install", "typescript"},
+			shouldBlock: false,
+		},
+		{
+			name:        "block with multiple flags",
+			cmd:         "pip",
+			args:        []string{"install"},
+			flags:       []string{"--user"},
+			input:       []string{"pip", "install", "--user", "--upgrade", "package"},
+			shouldBlock: true,
+		},
+
+		// Complex argument patterns
+		{
+			name:        "block multi-arg subcommand",
+			cmd:         "yarn",
+			args:        []string{"global", "add"},
+			flags:       nil,
+			input:       []string{"yarn", "global", "add", "typescript"},
+			shouldBlock: true,
+		},
+		{
+			name:        "allow partial multi-arg match",
+			cmd:         "yarn",
+			args:        []string{"global", "add"},
+			flags:       nil,
+			input:       []string{"yarn", "global", "list"},
+			shouldBlock: false,
+		},
+
+		// Edge cases
+		{
+			name:        "handle empty input",
+			cmd:         "npm",
+			args:        []string{"install"},
+			flags:       nil,
+			input:       []string{},
+			shouldBlock: false,
+		},
+		{
+			name:        "handle command only",
+			cmd:         "npm",
+			args:        []string{"install"},
+			flags:       nil,
+			input:       []string{"npm"},
+			shouldBlock: false,
+		},
+		{
+			name:        "block pacman with -S flag",
+			cmd:         "pacman",
+			args:        nil,
+			flags:       []string{"-S"},
+			input:       []string{"pacman", "-S", "package"},
+			shouldBlock: true,
+		},
+		{
+			name:        "allow pacman without -S flag",
+			cmd:         "pacman",
+			args:        nil,
+			flags:       []string{"-S"},
+			input:       []string{"pacman", "-Q", "package"},
+			shouldBlock: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			blocker := ArgumentsBlocker(tt.cmd, tt.args, tt.flags)
+			result := blocker(tt.input)
+			require.Equal(t, tt.shouldBlock, result,
+				"Expected block=%v for input %v", tt.shouldBlock, tt.input)
+		})
+	}
+}
+
+func TestCommandsBlocker(t *testing.T) {
+	tests := []struct {
+		name        string
+		banned      []string
+		input       []string
+		shouldBlock bool
+	}{
+		{
+			name:        "block single banned command",
+			banned:      []string{"curl"},
+			input:       []string{"curl", "https://example.com"},
+			shouldBlock: true,
+		},
+		{
+			name:        "allow non-banned command",
+			banned:      []string{"curl", "wget"},
+			input:       []string{"echo", "hello"},
+			shouldBlock: false,
+		},
+		{
+			name:        "block from multiple banned",
+			banned:      []string{"curl", "wget", "nc"},
+			input:       []string{"wget", "https://example.com"},
+			shouldBlock: true,
+		},
+		{
+			name:        "handle empty input",
+			banned:      []string{"curl"},
+			input:       []string{},
+			shouldBlock: false,
+		},
+		{
+			name:        "case sensitive matching",
+			banned:      []string{"curl"},
+			input:       []string{"CURL", "https://example.com"},
+			shouldBlock: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			blocker := CommandsBlocker(tt.banned)
+			result := blocker(tt.input)
+			require.Equal(t, tt.shouldBlock, result,
+				"Expected block=%v for input %v", tt.shouldBlock, tt.input)
+		})
+	}
+}
+
+func TestSplitArgsFlags(t *testing.T) {
+	tests := []struct {
+		name      string
+		input     []string
+		wantArgs  []string
+		wantFlags []string
+	}{
+		{
+			name:      "only args",
+			input:     []string{"install", "package", "another"},
+			wantArgs:  []string{"install", "package", "another"},
+			wantFlags: []string{},
+		},
+		{
+			name:      "only flags",
+			input:     []string{"-g", "--verbose", "-f"},
+			wantArgs:  []string{},
+			wantFlags: []string{"-g", "--verbose", "-f"},
+		},
+		{
+			name:      "mixed args and flags",
+			input:     []string{"install", "-g", "package", "--verbose"},
+			wantArgs:  []string{"install", "package"},
+			wantFlags: []string{"-g", "--verbose"},
+		},
+		{
+			name:      "empty input",
+			input:     []string{},
+			wantArgs:  []string{},
+			wantFlags: []string{},
+		},
+		{
+			name:      "single dash flag",
+			input:     []string{"-S", "package"},
+			wantArgs:  []string{"package"},
+			wantFlags: []string{"-S"},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			args, flags := splitArgsFlags(tt.input)
+			require.Equal(t, tt.wantArgs, args, "args mismatch")
+			require.Equal(t, tt.wantFlags, flags, "flags mismatch")
+		})
+	}
+}

internal/shell/shell.go 🔗

@@ -16,9 +16,11 @@ import (
 	"errors"
 	"fmt"
 	"os"
+	"slices"
 	"strings"
 	"sync"
 
+	"github.com/charmbracelet/crush/internal/slicesext"
 	"mvdan.cc/sh/moreinterp/coreutils"
 	"mvdan.cc/sh/v3/expand"
 	"mvdan.cc/sh/v3/interp"
@@ -171,25 +173,36 @@ func CommandsBlocker(cmds []string) BlockFunc {
 	}
 }
 
-// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
-func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
-	return func(args []string) bool {
-		for _, blocked := range blockedSubCommands {
-			if len(args) >= len(blocked) {
-				match := true
-				for i, part := range blocked {
-					if args[i] != part {
-						match = false
-						break
-					}
-				}
-				if match {
-					return true
-				}
-			}
+// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
+func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
+	return func(parts []string) bool {
+		if len(parts) == 0 || parts[0] != cmd {
+			return false
+		}
+
+		argParts, flagParts := splitArgsFlags(parts[1:])
+		if len(argParts) < len(args) || len(flagParts) < len(flags) {
+			return false
+		}
+
+		argsMatch := slices.Equal(argParts[:len(args)], args)
+		flagsMatch := slicesext.IsSubset(flags, flagParts)
+
+		return argsMatch && flagsMatch
+	}
+}
+
+func splitArgsFlags(parts []string) (args []string, flags []string) {
+	args = make([]string, 0, len(parts))
+	flags = make([]string, 0, len(parts))
+	for _, part := range parts {
+		if strings.HasPrefix(part, "-") {
+			flags = append(flags, part)
+		} else {
+			args = append(args, part)
 		}
-		return false
 	}
+	return
 }
 
 func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {