pattern_extraction.rs

  1use shell_command_parser::extract_terminal_command_prefix;
  2use std::path::{Path, PathBuf};
  3use url::Url;
  4
  5/// Normalize path separators to forward slashes for consistent cross-platform patterns.
  6fn normalize_separators(path_str: &str) -> String {
  7    path_str.replace('\\', "/")
  8}
  9
 10/// Returns true if the token looks like a command name or subcommand — i.e. it
 11/// contains only alphanumeric characters, hyphens, and underscores, and does not
 12/// start with a hyphen (which would make it a flag).
 13fn is_plain_command_token(token: &str) -> bool {
 14    !token.starts_with('-')
 15        && token
 16            .chars()
 17            .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
 18}
 19
 20struct CommandPrefix {
 21    normalized_tokens: Vec<String>,
 22    display: String,
 23}
 24
 25/// Extracts the command name and optional subcommand from a shell command using
 26/// the shell parser.
 27///
 28/// This parses the command properly to extract the command name and optional
 29/// subcommand (e.g. "cargo" and "test" from "cargo test -p search"), handling shell
 30/// syntax correctly. Returns `None` if parsing fails or if the command name
 31/// contains path separators (for security reasons).
 32fn extract_command_prefix(command: &str) -> Option<CommandPrefix> {
 33    let prefix = extract_terminal_command_prefix(command)?;
 34
 35    if !is_plain_command_token(&prefix.command) {
 36        return None;
 37    }
 38
 39    Some(CommandPrefix {
 40        normalized_tokens: prefix.tokens,
 41        display: prefix.display,
 42    })
 43}
 44
 45/// Extracts a regex pattern from a terminal command based on the first token (command name).
 46///
 47/// Returns `None` for commands starting with `./`, `/`, or other path-like prefixes.
 48/// This is a deliberate security decision: we only allow pattern-based "always allow"
 49/// rules for well-known command names (like `cargo`, `npm`, `git`), not for arbitrary
 50/// scripts or absolute paths which could be manipulated by an attacker.
 51pub fn extract_terminal_pattern(command: &str) -> Option<String> {
 52    let prefix = extract_command_prefix(command)?;
 53    let tokens = prefix.normalized_tokens;
 54
 55    match tokens.as_slice() {
 56        [] => None,
 57        [single] => Some(format!("^{}\\b", regex::escape(single))),
 58        [rest @ .., last] => Some(format!(
 59            "^{}\\s+{}(\\s|$)",
 60            rest.iter()
 61                .map(|token| regex::escape(token))
 62                .collect::<Vec<_>>()
 63                .join("\\s+"),
 64            regex::escape(last)
 65        )),
 66    }
 67}
 68
 69pub fn extract_terminal_pattern_display(command: &str) -> Option<String> {
 70    let prefix = extract_command_prefix(command)?;
 71    Some(prefix.display)
 72}
 73
 74pub fn extract_path_pattern(path: &str) -> Option<String> {
 75    let parent = Path::new(path).parent()?;
 76    let parent_str = normalize_separators(parent.to_str()?);
 77    if parent_str.is_empty() || parent_str == "/" {
 78        return None;
 79    }
 80    Some(format!("^{}/", regex::escape(&parent_str)))
 81}
 82
 83pub fn extract_path_pattern_display(path: &str) -> Option<String> {
 84    let parent = Path::new(path).parent()?;
 85    let parent_str = normalize_separators(parent.to_str()?);
 86    if parent_str.is_empty() || parent_str == "/" {
 87        return None;
 88    }
 89    Some(format!("{}/", parent_str))
 90}
 91
 92fn common_parent_dir(path_a: &str, path_b: &str) -> Option<PathBuf> {
 93    let parent_a = Path::new(path_a).parent()?;
 94    let parent_b = Path::new(path_b).parent()?;
 95
 96    let components_a: Vec<_> = parent_a.components().collect();
 97    let components_b: Vec<_> = parent_b.components().collect();
 98
 99    let common_count = components_a
100        .iter()
101        .zip(components_b.iter())
102        .take_while(|(a, b)| a == b)
103        .count();
104
105    if common_count == 0 {
106        return None;
107    }
108
109    let common: PathBuf = components_a[..common_count].iter().collect();
110    Some(common)
111}
112
113pub fn extract_copy_move_pattern(input: &str) -> Option<String> {
114    let (source, dest) = input.split_once('\n')?;
115    let common = common_parent_dir(source, dest)?;
116    let common_str = normalize_separators(common.to_str()?);
117    if common_str.is_empty() || common_str == "/" {
118        return None;
119    }
120    Some(format!("^{}/", regex::escape(&common_str)))
121}
122
123pub fn extract_copy_move_pattern_display(input: &str) -> Option<String> {
124    let (source, dest) = input.split_once('\n')?;
125    let common = common_parent_dir(source, dest)?;
126    let common_str = normalize_separators(common.to_str()?);
127    if common_str.is_empty() || common_str == "/" {
128        return None;
129    }
130    Some(format!("{}/", common_str))
131}
132
133pub fn extract_url_pattern(url: &str) -> Option<String> {
134    let parsed = Url::parse(url).ok()?;
135    let domain = parsed.host_str()?;
136    Some(format!("^https?://{}", regex::escape(domain)))
137}
138
139pub fn extract_url_pattern_display(url: &str) -> Option<String> {
140    let parsed = Url::parse(url).ok()?;
141    let domain = parsed.host_str()?;
142    Some(domain.to_string())
143}
144
145#[cfg(test)]
146mod tests {
147    use super::*;
148
149    #[test]
150    fn test_extract_terminal_pattern() {
151        assert_eq!(
152            extract_terminal_pattern("cargo build --release"),
153            Some("^cargo\\s+build(\\s|$)".to_string())
154        );
155        assert_eq!(
156            extract_terminal_pattern("cargo test -p search"),
157            Some("^cargo\\s+test(\\s|$)".to_string())
158        );
159        assert_eq!(
160            extract_terminal_pattern("npm install"),
161            Some("^npm\\s+install(\\s|$)".to_string())
162        );
163        assert_eq!(
164            extract_terminal_pattern("git-lfs pull"),
165            Some("^git\\-lfs\\s+pull(\\s|$)".to_string())
166        );
167        assert_eq!(
168            extract_terminal_pattern("my_script arg"),
169            Some("^my_script\\s+arg(\\s|$)".to_string())
170        );
171
172        // Flags as second token: only the command name is used
173        assert_eq!(
174            extract_terminal_pattern("ls -la"),
175            Some("^ls\\b".to_string())
176        );
177        assert_eq!(
178            extract_terminal_pattern("rm --force foo"),
179            Some("^rm\\b".to_string())
180        );
181
182        // Single-word commands
183        assert_eq!(extract_terminal_pattern("ls"), Some("^ls\\b".to_string()));
184
185        // Subcommand pattern does not match a hyphenated extension of the subcommand
186        // (e.g. approving "cargo build" should not approve "cargo build-foo")
187        assert_eq!(
188            extract_terminal_pattern("cargo build"),
189            Some("^cargo\\s+build(\\s|$)".to_string())
190        );
191        let pattern = regex::Regex::new(&extract_terminal_pattern("cargo build").unwrap()).unwrap();
192        assert!(pattern.is_match("cargo build --release"));
193        assert!(pattern.is_match("cargo build"));
194        assert!(!pattern.is_match("cargo build-foo"));
195        assert!(!pattern.is_match("cargo builder"));
196
197        // Env-var prefixes are included in generated patterns
198        assert_eq!(
199            extract_terminal_pattern("PAGER=blah git log --oneline"),
200            Some("^PAGER=blah\\s+git\\s+log(\\s|$)".to_string())
201        );
202        assert_eq!(
203            extract_terminal_pattern("A=1 B=2 git log"),
204            Some("^A=1\\s+B=2\\s+git\\s+log(\\s|$)".to_string())
205        );
206        assert_eq!(
207            extract_terminal_pattern("PAGER='less -R' git log"),
208            Some("^PAGER='less \\-R'\\s+git\\s+log(\\s|$)".to_string())
209        );
210
211        // Path-like commands are rejected
212        assert_eq!(extract_terminal_pattern("./script.sh arg"), None);
213        assert_eq!(extract_terminal_pattern("/usr/bin/python arg"), None);
214        assert_eq!(extract_terminal_pattern("PAGER=blah ./script.sh arg"), None);
215    }
216
217    #[test]
218    fn test_extract_terminal_pattern_display() {
219        assert_eq!(
220            extract_terminal_pattern_display("cargo build --release"),
221            Some("cargo build".to_string())
222        );
223        assert_eq!(
224            extract_terminal_pattern_display("cargo test -p search"),
225            Some("cargo test".to_string())
226        );
227        assert_eq!(
228            extract_terminal_pattern_display("npm install"),
229            Some("npm install".to_string())
230        );
231        assert_eq!(
232            extract_terminal_pattern_display("ls -la"),
233            Some("ls".to_string())
234        );
235        assert_eq!(
236            extract_terminal_pattern_display("ls"),
237            Some("ls".to_string())
238        );
239        assert_eq!(
240            extract_terminal_pattern_display("PAGER=blah   git   log --oneline"),
241            Some("PAGER=blah   git   log".to_string())
242        );
243        assert_eq!(
244            extract_terminal_pattern_display("PAGER='less -R' git log"),
245            Some("PAGER='less -R' git log".to_string())
246        );
247    }
248
249    #[test]
250    fn test_terminal_pattern_regex_normalizes_whitespace() {
251        let pattern = extract_terminal_pattern("PAGER=blah   git   log --oneline")
252            .expect("expected terminal pattern");
253        let regex = regex::Regex::new(&pattern).expect("expected valid regex");
254
255        assert!(regex.is_match("PAGER=blah git log"));
256        assert!(regex.is_match("PAGER=blah    git    log --stat"));
257    }
258
259    #[test]
260    fn test_extract_terminal_pattern_skips_redirects_before_subcommand() {
261        assert_eq!(
262            extract_terminal_pattern("git 2>/dev/null log --oneline"),
263            Some("^git\\s+log(\\s|$)".to_string())
264        );
265        assert_eq!(
266            extract_terminal_pattern_display("git 2>/dev/null log --oneline"),
267            Some("git 2>/dev/null log".to_string())
268        );
269
270        assert_eq!(
271            extract_terminal_pattern("rm --force foo"),
272            Some("^rm\\b".to_string())
273        );
274    }
275
276    #[test]
277    fn test_extract_path_pattern() {
278        assert_eq!(
279            extract_path_pattern("/Users/alice/project/src/main.rs"),
280            Some("^/Users/alice/project/src/".to_string())
281        );
282        assert_eq!(
283            extract_path_pattern("src/lib.rs"),
284            Some("^src/".to_string())
285        );
286        assert_eq!(extract_path_pattern("file.txt"), None);
287        assert_eq!(extract_path_pattern("/file.txt"), None);
288    }
289
290    #[test]
291    fn test_extract_path_pattern_display() {
292        assert_eq!(
293            extract_path_pattern_display("/Users/alice/project/src/main.rs"),
294            Some("/Users/alice/project/src/".to_string())
295        );
296        assert_eq!(
297            extract_path_pattern_display("src/lib.rs"),
298            Some("src/".to_string())
299        );
300    }
301
302    #[test]
303    fn test_extract_url_pattern() {
304        assert_eq!(
305            extract_url_pattern("https://github.com/user/repo"),
306            Some("^https?://github\\.com".to_string())
307        );
308        assert_eq!(
309            extract_url_pattern("http://example.com/path?query=1"),
310            Some("^https?://example\\.com".to_string())
311        );
312        assert_eq!(extract_url_pattern("not a url"), None);
313    }
314
315    #[test]
316    fn test_extract_url_pattern_display() {
317        assert_eq!(
318            extract_url_pattern_display("https://github.com/user/repo"),
319            Some("github.com".to_string())
320        );
321        assert_eq!(
322            extract_url_pattern_display("http://api.example.com/v1/users"),
323            Some("api.example.com".to_string())
324        );
325    }
326
327    #[test]
328    fn test_special_chars_are_escaped() {
329        assert_eq!(
330            extract_path_pattern("/path/with (parens)/file.txt"),
331            Some("^/path/with \\(parens\\)/".to_string())
332        );
333        assert_eq!(
334            extract_url_pattern("https://test.example.com/path"),
335            Some("^https?://test\\.example\\.com".to_string())
336        );
337    }
338
339    #[test]
340    fn test_extract_copy_move_pattern_same_directory() {
341        assert_eq!(
342            extract_copy_move_pattern(
343                "/Users/alice/project/src/old.rs\n/Users/alice/project/src/new.rs"
344            ),
345            Some("^/Users/alice/project/src/".to_string())
346        );
347    }
348
349    #[test]
350    fn test_extract_copy_move_pattern_sibling_directories() {
351        assert_eq!(
352            extract_copy_move_pattern(
353                "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
354            ),
355            Some("^/Users/alice/project/".to_string())
356        );
357    }
358
359    #[test]
360    fn test_extract_copy_move_pattern_no_common_prefix() {
361        assert_eq!(
362            extract_copy_move_pattern("/home/file.txt\n/tmp/file.txt"),
363            None
364        );
365    }
366
367    #[test]
368    fn test_extract_copy_move_pattern_relative_paths() {
369        assert_eq!(
370            extract_copy_move_pattern("src/old.rs\nsrc/new.rs"),
371            Some("^src/".to_string())
372        );
373    }
374
375    #[test]
376    fn test_extract_copy_move_pattern_display() {
377        assert_eq!(
378            extract_copy_move_pattern_display(
379                "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
380            ),
381            Some("/Users/alice/project/".to_string())
382        );
383    }
384
385    #[test]
386    fn test_extract_copy_move_pattern_no_arrow() {
387        assert_eq!(extract_copy_move_pattern("just/a/path.rs"), None);
388        assert_eq!(extract_copy_move_pattern_display("just/a/path.rs"), None);
389    }
390}