pattern_extraction.rs

  1use acp_thread::PermissionPattern;
  2use shell_command_parser::{extract_commands, extract_terminal_command_prefix};
  3use std::path::{Path, PathBuf};
  4use url::Url;
  5
  6/// Normalize path separators to forward slashes for consistent cross-platform patterns.
  7fn normalize_separators(path_str: &str) -> String {
  8    path_str.replace('\\', "/")
  9}
 10
 11/// Returns true if the token looks like a command name or subcommand — i.e. it
 12/// contains only alphanumeric characters, hyphens, and underscores, and does not
 13/// start with a hyphen (which would make it a flag).
 14fn is_plain_command_token(token: &str) -> bool {
 15    !token.starts_with('-')
 16        && token
 17            .chars()
 18            .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
 19}
 20
 21struct CommandPrefix {
 22    normalized_tokens: Vec<String>,
 23    display: String,
 24}
 25
 26/// Extracts the command name and optional subcommand from a shell command using
 27/// the shell parser.
 28///
 29/// This parses the command properly to extract the command name and optional
 30/// subcommand (e.g. "cargo" and "test" from "cargo test -p search"), handling shell
 31/// syntax correctly. Returns `None` if parsing fails or if the command name
 32/// contains path separators (for security reasons).
 33fn extract_command_prefix(command: &str) -> Option<CommandPrefix> {
 34    let prefix = extract_terminal_command_prefix(command)?;
 35
 36    if !is_plain_command_token(&prefix.command) {
 37        return None;
 38    }
 39
 40    Some(CommandPrefix {
 41        normalized_tokens: prefix.tokens,
 42        display: prefix.display,
 43    })
 44}
 45
 46/// Extracts a regex pattern and display name from a terminal command.
 47///
 48/// Returns `None` for commands starting with `./`, `/`, or other path-like prefixes.
 49/// This is a deliberate security decision: we only allow pattern-based "always allow"
 50/// rules for well-known command names (like `cargo`, `npm`, `git`), not for arbitrary
 51/// scripts or absolute paths which could be manipulated by an attacker.
 52pub fn extract_terminal_permission_pattern(command: &str) -> Option<PermissionPattern> {
 53    let pattern = extract_terminal_pattern(command)?;
 54    let display_name = extract_terminal_pattern_display(command)?;
 55    Some(PermissionPattern {
 56        pattern,
 57        display_name,
 58    })
 59}
 60
 61pub fn extract_terminal_pattern(command: &str) -> Option<String> {
 62    let prefix = extract_command_prefix(command)?;
 63    let tokens = prefix.normalized_tokens;
 64
 65    match tokens.as_slice() {
 66        [] => None,
 67        [single] => Some(format!("^{}\\b", regex::escape(single))),
 68        [rest @ .., last] => Some(format!(
 69            "^{}\\s+{}(\\s|$)",
 70            rest.iter()
 71                .map(|token| regex::escape(token))
 72                .collect::<Vec<_>>()
 73                .join("\\s+"),
 74            regex::escape(last)
 75        )),
 76    }
 77}
 78
 79pub fn extract_terminal_pattern_display(command: &str) -> Option<String> {
 80    let prefix = extract_command_prefix(command)?;
 81    Some(prefix.display)
 82}
 83
 84/// Extracts patterns for ALL commands in a pipeline, not just the first one.
 85///
 86/// For a command like `"cargo test 2>&1 | tail"`, this returns patterns for
 87/// both `cargo` and `tail`. Path-based commands (e.g. `./script.sh`) are
 88/// filtered out, and duplicate command names are deduplicated while preserving
 89/// order.
 90pub fn extract_all_terminal_patterns(command: &str) -> Vec<PermissionPattern> {
 91    let commands = match extract_commands(command) {
 92        Some(commands) => commands,
 93        None => return Vec::new(),
 94    };
 95
 96    let mut results = Vec::new();
 97
 98    for cmd in &commands {
 99        let Some(permission_pattern) = extract_terminal_permission_pattern(cmd) else {
100            continue;
101        };
102
103        if results.contains(&permission_pattern) {
104            continue;
105        }
106
107        results.push(permission_pattern);
108    }
109
110    results
111}
112
113pub fn extract_path_pattern(path: &str) -> Option<String> {
114    let parent = Path::new(path).parent()?;
115    let parent_str = normalize_separators(parent.to_str()?);
116    if parent_str.is_empty() || parent_str == "/" {
117        return None;
118    }
119    Some(format!("^{}/", regex::escape(&parent_str)))
120}
121
122pub fn extract_path_pattern_display(path: &str) -> Option<String> {
123    let parent = Path::new(path).parent()?;
124    let parent_str = normalize_separators(parent.to_str()?);
125    if parent_str.is_empty() || parent_str == "/" {
126        return None;
127    }
128    Some(format!("{}/", parent_str))
129}
130
131fn common_parent_dir(path_a: &str, path_b: &str) -> Option<PathBuf> {
132    let parent_a = Path::new(path_a).parent()?;
133    let parent_b = Path::new(path_b).parent()?;
134
135    let components_a: Vec<_> = parent_a.components().collect();
136    let components_b: Vec<_> = parent_b.components().collect();
137
138    let common_count = components_a
139        .iter()
140        .zip(components_b.iter())
141        .take_while(|(a, b)| a == b)
142        .count();
143
144    if common_count == 0 {
145        return None;
146    }
147
148    let common: PathBuf = components_a[..common_count].iter().collect();
149    Some(common)
150}
151
152pub fn extract_copy_move_pattern(input: &str) -> Option<String> {
153    let (source, dest) = input.split_once('\n')?;
154    let common = common_parent_dir(source, dest)?;
155    let common_str = normalize_separators(common.to_str()?);
156    if common_str.is_empty() || common_str == "/" {
157        return None;
158    }
159    Some(format!("^{}/", regex::escape(&common_str)))
160}
161
162pub fn extract_copy_move_pattern_display(input: &str) -> Option<String> {
163    let (source, dest) = input.split_once('\n')?;
164    let common = common_parent_dir(source, dest)?;
165    let common_str = normalize_separators(common.to_str()?);
166    if common_str.is_empty() || common_str == "/" {
167        return None;
168    }
169    Some(format!("{}/", common_str))
170}
171
172pub fn extract_url_pattern(url: &str) -> Option<String> {
173    let parsed = Url::parse(url).ok()?;
174    let domain = parsed.host_str()?;
175    Some(format!("^https?://{}", regex::escape(domain)))
176}
177
178pub fn extract_url_pattern_display(url: &str) -> Option<String> {
179    let parsed = Url::parse(url).ok()?;
180    let domain = parsed.host_str()?;
181    Some(domain.to_string())
182}
183
184#[cfg(test)]
185mod tests {
186    use super::*;
187
188    #[test]
189    fn test_extract_terminal_pattern() {
190        assert_eq!(
191            extract_terminal_pattern("cargo build --release"),
192            Some("^cargo\\s+build(\\s|$)".to_string())
193        );
194        assert_eq!(
195            extract_terminal_pattern("cargo test -p search"),
196            Some("^cargo\\s+test(\\s|$)".to_string())
197        );
198        assert_eq!(
199            extract_terminal_pattern("npm install"),
200            Some("^npm\\s+install(\\s|$)".to_string())
201        );
202        assert_eq!(
203            extract_terminal_pattern("git-lfs pull"),
204            Some("^git\\-lfs\\s+pull(\\s|$)".to_string())
205        );
206        assert_eq!(
207            extract_terminal_pattern("my_script arg"),
208            Some("^my_script\\s+arg(\\s|$)".to_string())
209        );
210
211        // Flags as second token: only the command name is used
212        assert_eq!(
213            extract_terminal_pattern("ls -la"),
214            Some("^ls\\b".to_string())
215        );
216        assert_eq!(
217            extract_terminal_pattern("rm --force foo"),
218            Some("^rm\\b".to_string())
219        );
220
221        // Single-word commands
222        assert_eq!(extract_terminal_pattern("ls"), Some("^ls\\b".to_string()));
223
224        // Subcommand pattern does not match a hyphenated extension of the subcommand
225        // (e.g. approving "cargo build" should not approve "cargo build-foo")
226        assert_eq!(
227            extract_terminal_pattern("cargo build"),
228            Some("^cargo\\s+build(\\s|$)".to_string())
229        );
230        let pattern = regex::Regex::new(&extract_terminal_pattern("cargo build").unwrap()).unwrap();
231        assert!(pattern.is_match("cargo build --release"));
232        assert!(pattern.is_match("cargo build"));
233        assert!(!pattern.is_match("cargo build-foo"));
234        assert!(!pattern.is_match("cargo builder"));
235
236        // Env-var prefixes are included in generated patterns
237        assert_eq!(
238            extract_terminal_pattern("PAGER=blah git log --oneline"),
239            Some("^PAGER=blah\\s+git\\s+log(\\s|$)".to_string())
240        );
241        assert_eq!(
242            extract_terminal_pattern("A=1 B=2 git log"),
243            Some("^A=1\\s+B=2\\s+git\\s+log(\\s|$)".to_string())
244        );
245        assert_eq!(
246            extract_terminal_pattern("PAGER='less -R' git log"),
247            Some("^PAGER='less \\-R'\\s+git\\s+log(\\s|$)".to_string())
248        );
249
250        // Path-like commands are rejected
251        assert_eq!(extract_terminal_pattern("./script.sh arg"), None);
252        assert_eq!(extract_terminal_pattern("/usr/bin/python arg"), None);
253        assert_eq!(extract_terminal_pattern("PAGER=blah ./script.sh arg"), None);
254    }
255
256    #[test]
257    fn test_extract_terminal_pattern_display() {
258        assert_eq!(
259            extract_terminal_pattern_display("cargo build --release"),
260            Some("cargo build".to_string())
261        );
262        assert_eq!(
263            extract_terminal_pattern_display("cargo test -p search"),
264            Some("cargo test".to_string())
265        );
266        assert_eq!(
267            extract_terminal_pattern_display("npm install"),
268            Some("npm install".to_string())
269        );
270        assert_eq!(
271            extract_terminal_pattern_display("ls -la"),
272            Some("ls".to_string())
273        );
274        assert_eq!(
275            extract_terminal_pattern_display("ls"),
276            Some("ls".to_string())
277        );
278        assert_eq!(
279            extract_terminal_pattern_display("PAGER=blah   git   log --oneline"),
280            Some("PAGER=blah   git   log".to_string())
281        );
282        assert_eq!(
283            extract_terminal_pattern_display("PAGER='less -R' git log"),
284            Some("PAGER='less -R' git log".to_string())
285        );
286    }
287
288    #[test]
289    fn test_terminal_pattern_regex_normalizes_whitespace() {
290        let pattern = extract_terminal_pattern("PAGER=blah   git   log --oneline")
291            .expect("expected terminal pattern");
292        let regex = regex::Regex::new(&pattern).expect("expected valid regex");
293
294        assert!(regex.is_match("PAGER=blah git log"));
295        assert!(regex.is_match("PAGER=blah    git    log --stat"));
296    }
297
298    #[test]
299    fn test_extract_terminal_pattern_skips_redirects_before_subcommand() {
300        assert_eq!(
301            extract_terminal_pattern("git 2>/dev/null log --oneline"),
302            Some("^git\\s+log(\\s|$)".to_string())
303        );
304        assert_eq!(
305            extract_terminal_pattern_display("git 2>/dev/null log --oneline"),
306            Some("git 2>/dev/null log".to_string())
307        );
308
309        assert_eq!(
310            extract_terminal_pattern("rm --force foo"),
311            Some("^rm\\b".to_string())
312        );
313    }
314
315    #[test]
316    fn test_extract_all_terminal_patterns_pipeline() {
317        assert_eq!(
318            extract_all_terminal_patterns("cargo test 2>&1 | tail"),
319            vec![
320                PermissionPattern {
321                    pattern: "^cargo\\s+test(\\s|$)".to_string(),
322                    display_name: "cargo test".to_string(),
323                },
324                PermissionPattern {
325                    pattern: "^tail\\b".to_string(),
326                    display_name: "tail".to_string(),
327                },
328            ]
329        );
330    }
331
332    #[test]
333    fn test_extract_all_terminal_patterns_with_path_commands() {
334        assert_eq!(
335            extract_all_terminal_patterns("./script.sh | grep foo"),
336            vec![PermissionPattern {
337                pattern: "^grep\\s+foo(\\s|$)".to_string(),
338                display_name: "grep foo".to_string(),
339            }]
340        );
341    }
342
343    #[test]
344    fn test_extract_all_terminal_patterns_all_paths() {
345        assert_eq!(extract_all_terminal_patterns("./a.sh | /usr/bin/b"), vec![]);
346    }
347
348    #[test]
349    fn test_extract_path_pattern() {
350        assert_eq!(
351            extract_path_pattern("/Users/alice/project/src/main.rs"),
352            Some("^/Users/alice/project/src/".to_string())
353        );
354        assert_eq!(
355            extract_path_pattern("src/lib.rs"),
356            Some("^src/".to_string())
357        );
358        assert_eq!(extract_path_pattern("file.txt"), None);
359        assert_eq!(extract_path_pattern("/file.txt"), None);
360    }
361
362    #[test]
363    fn test_extract_path_pattern_display() {
364        assert_eq!(
365            extract_path_pattern_display("/Users/alice/project/src/main.rs"),
366            Some("/Users/alice/project/src/".to_string())
367        );
368        assert_eq!(
369            extract_path_pattern_display("src/lib.rs"),
370            Some("src/".to_string())
371        );
372    }
373
374    #[test]
375    fn test_extract_url_pattern() {
376        assert_eq!(
377            extract_url_pattern("https://github.com/user/repo"),
378            Some("^https?://github\\.com".to_string())
379        );
380        assert_eq!(
381            extract_url_pattern("http://example.com/path?query=1"),
382            Some("^https?://example\\.com".to_string())
383        );
384        assert_eq!(extract_url_pattern("not a url"), None);
385    }
386
387    #[test]
388    fn test_extract_url_pattern_display() {
389        assert_eq!(
390            extract_url_pattern_display("https://github.com/user/repo"),
391            Some("github.com".to_string())
392        );
393        assert_eq!(
394            extract_url_pattern_display("http://api.example.com/v1/users"),
395            Some("api.example.com".to_string())
396        );
397    }
398
399    #[test]
400    fn test_special_chars_are_escaped() {
401        assert_eq!(
402            extract_path_pattern("/path/with (parens)/file.txt"),
403            Some("^/path/with \\(parens\\)/".to_string())
404        );
405        assert_eq!(
406            extract_url_pattern("https://test.example.com/path"),
407            Some("^https?://test\\.example\\.com".to_string())
408        );
409    }
410
411    #[test]
412    fn test_extract_copy_move_pattern_same_directory() {
413        assert_eq!(
414            extract_copy_move_pattern(
415                "/Users/alice/project/src/old.rs\n/Users/alice/project/src/new.rs"
416            ),
417            Some("^/Users/alice/project/src/".to_string())
418        );
419    }
420
421    #[test]
422    fn test_extract_copy_move_pattern_sibling_directories() {
423        assert_eq!(
424            extract_copy_move_pattern(
425                "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
426            ),
427            Some("^/Users/alice/project/".to_string())
428        );
429    }
430
431    #[test]
432    fn test_extract_copy_move_pattern_no_common_prefix() {
433        assert_eq!(
434            extract_copy_move_pattern("/home/file.txt\n/tmp/file.txt"),
435            None
436        );
437    }
438
439    #[test]
440    fn test_extract_copy_move_pattern_relative_paths() {
441        assert_eq!(
442            extract_copy_move_pattern("src/old.rs\nsrc/new.rs"),
443            Some("^src/".to_string())
444        );
445    }
446
447    #[test]
448    fn test_extract_copy_move_pattern_display() {
449        assert_eq!(
450            extract_copy_move_pattern_display(
451                "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
452            ),
453            Some("/Users/alice/project/".to_string())
454        );
455    }
456
457    #[test]
458    fn test_extract_copy_move_pattern_no_arrow() {
459        assert_eq!(extract_copy_move_pattern("just/a/path.rs"), None);
460        assert_eq!(extract_copy_move_pattern_display("just/a/path.rs"), None);
461    }
462}