pattern_extraction.rs

  1use shell_command_parser::extract_commands;
  2use std::path::{Path, PathBuf};
  3use url::Url;
  4
  5/// Normalize path separators to forward slashes for consistent cross-platform patterns.
  6fn normalize_separators(path_str: &str) -> String {
  7    path_str.replace('\\', "/")
  8}
  9
 10/// Returns true if the token looks like a command name or subcommand — i.e. it
 11/// contains only alphanumeric characters, hyphens, and underscores, and does not
 12/// start with a hyphen (which would make it a flag).
 13fn is_plain_command_token(token: &str) -> bool {
 14    !token.starts_with('-')
 15        && token
 16            .chars()
 17            .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
 18}
 19
 20struct CommandPrefix {
 21    command: String,
 22    subcommand: Option<String>,
 23}
 24
 25/// Extracts the command name and optional subcommand from a shell command using
 26/// the shell parser.
 27///
 28/// This parses the command properly to extract the command name and optional
 29/// subcommand (e.g. "cargo" and "test" from "cargo test -p search"), handling shell
 30/// syntax correctly. Returns `None` if parsing fails or if the command name
 31/// contains path separators (for security reasons).
 32fn extract_command_prefix(command: &str) -> Option<CommandPrefix> {
 33    let commands = extract_commands(command)?;
 34    let first_command = commands.first()?;
 35
 36    let mut tokens = first_command.split_whitespace();
 37    let first_token = tokens.next()?;
 38
 39    // Only allow alphanumeric commands with hyphens/underscores.
 40    // Reject paths like "./script.sh" or "/usr/bin/python" to prevent
 41    // users from accidentally allowing arbitrary script execution.
 42    if !is_plain_command_token(first_token) {
 43        return None;
 44    }
 45
 46    // Include the subcommand (second non-flag token) when present, to produce
 47    // more specific patterns like "cargo test" instead of just "cargo".
 48    let subcommand = tokens
 49        .next()
 50        .filter(|second_token| is_plain_command_token(second_token))
 51        .map(|second_token| second_token.to_string());
 52
 53    Some(CommandPrefix {
 54        command: first_token.to_string(),
 55        subcommand,
 56    })
 57}
 58
 59/// Extracts a regex pattern from a terminal command based on the first token (command name).
 60///
 61/// Returns `None` for commands starting with `./`, `/`, or other path-like prefixes.
 62/// This is a deliberate security decision: we only allow pattern-based "always allow"
 63/// rules for well-known command names (like `cargo`, `npm`, `git`), not for arbitrary
 64/// scripts or absolute paths which could be manipulated by an attacker.
 65pub fn extract_terminal_pattern(command: &str) -> Option<String> {
 66    let prefix = extract_command_prefix(command)?;
 67    let escaped_command = regex::escape(&prefix.command);
 68    Some(match &prefix.subcommand {
 69        Some(subcommand) => {
 70            format!(
 71                "^{}\\s+{}(\\s|$)",
 72                escaped_command,
 73                regex::escape(subcommand)
 74            )
 75        }
 76        None => format!("^{}\\b", escaped_command),
 77    })
 78}
 79
 80pub fn extract_terminal_pattern_display(command: &str) -> Option<String> {
 81    let prefix = extract_command_prefix(command)?;
 82    match prefix.subcommand {
 83        Some(subcommand) => Some(format!("{} {}", prefix.command, subcommand)),
 84        None => Some(prefix.command),
 85    }
 86}
 87
 88pub fn extract_path_pattern(path: &str) -> Option<String> {
 89    let parent = Path::new(path).parent()?;
 90    let parent_str = normalize_separators(parent.to_str()?);
 91    if parent_str.is_empty() || parent_str == "/" {
 92        return None;
 93    }
 94    Some(format!("^{}/", regex::escape(&parent_str)))
 95}
 96
 97pub fn extract_path_pattern_display(path: &str) -> Option<String> {
 98    let parent = Path::new(path).parent()?;
 99    let parent_str = normalize_separators(parent.to_str()?);
100    if parent_str.is_empty() || parent_str == "/" {
101        return None;
102    }
103    Some(format!("{}/", parent_str))
104}
105
106fn common_parent_dir(path_a: &str, path_b: &str) -> Option<PathBuf> {
107    let parent_a = Path::new(path_a).parent()?;
108    let parent_b = Path::new(path_b).parent()?;
109
110    let components_a: Vec<_> = parent_a.components().collect();
111    let components_b: Vec<_> = parent_b.components().collect();
112
113    let common_count = components_a
114        .iter()
115        .zip(components_b.iter())
116        .take_while(|(a, b)| a == b)
117        .count();
118
119    if common_count == 0 {
120        return None;
121    }
122
123    let common: PathBuf = components_a[..common_count].iter().collect();
124    Some(common)
125}
126
127pub fn extract_copy_move_pattern(input: &str) -> Option<String> {
128    let (source, dest) = input.split_once('\n')?;
129    let common = common_parent_dir(source, dest)?;
130    let common_str = normalize_separators(common.to_str()?);
131    if common_str.is_empty() || common_str == "/" {
132        return None;
133    }
134    Some(format!("^{}/", regex::escape(&common_str)))
135}
136
137pub fn extract_copy_move_pattern_display(input: &str) -> Option<String> {
138    let (source, dest) = input.split_once('\n')?;
139    let common = common_parent_dir(source, dest)?;
140    let common_str = normalize_separators(common.to_str()?);
141    if common_str.is_empty() || common_str == "/" {
142        return None;
143    }
144    Some(format!("{}/", common_str))
145}
146
147pub fn extract_url_pattern(url: &str) -> Option<String> {
148    let parsed = Url::parse(url).ok()?;
149    let domain = parsed.host_str()?;
150    Some(format!("^https?://{}", regex::escape(domain)))
151}
152
153pub fn extract_url_pattern_display(url: &str) -> Option<String> {
154    let parsed = Url::parse(url).ok()?;
155    let domain = parsed.host_str()?;
156    Some(domain.to_string())
157}
158
159#[cfg(test)]
160mod tests {
161    use super::*;
162
163    #[test]
164    fn test_extract_terminal_pattern() {
165        assert_eq!(
166            extract_terminal_pattern("cargo build --release"),
167            Some("^cargo\\s+build(\\s|$)".to_string())
168        );
169        assert_eq!(
170            extract_terminal_pattern("cargo test -p search"),
171            Some("^cargo\\s+test(\\s|$)".to_string())
172        );
173        assert_eq!(
174            extract_terminal_pattern("npm install"),
175            Some("^npm\\s+install(\\s|$)".to_string())
176        );
177        assert_eq!(
178            extract_terminal_pattern("git-lfs pull"),
179            Some("^git\\-lfs\\s+pull(\\s|$)".to_string())
180        );
181        assert_eq!(
182            extract_terminal_pattern("my_script arg"),
183            Some("^my_script\\s+arg(\\s|$)".to_string())
184        );
185
186        // Flags as second token: only the command name is used
187        assert_eq!(
188            extract_terminal_pattern("ls -la"),
189            Some("^ls\\b".to_string())
190        );
191        assert_eq!(
192            extract_terminal_pattern("rm --force foo"),
193            Some("^rm\\b".to_string())
194        );
195
196        // Single-word commands
197        assert_eq!(extract_terminal_pattern("ls"), Some("^ls\\b".to_string()));
198
199        // Subcommand pattern does not match a hyphenated extension of the subcommand
200        // (e.g. approving "cargo build" should not approve "cargo build-foo")
201        assert_eq!(
202            extract_terminal_pattern("cargo build"),
203            Some("^cargo\\s+build(\\s|$)".to_string())
204        );
205        let pattern = regex::Regex::new(&extract_terminal_pattern("cargo build").unwrap()).unwrap();
206        assert!(pattern.is_match("cargo build --release"));
207        assert!(pattern.is_match("cargo build"));
208        assert!(!pattern.is_match("cargo build-foo"));
209        assert!(!pattern.is_match("cargo builder"));
210
211        // Path-like commands are rejected
212        assert_eq!(extract_terminal_pattern("./script.sh arg"), None);
213        assert_eq!(extract_terminal_pattern("/usr/bin/python arg"), None);
214    }
215
216    #[test]
217    fn test_extract_terminal_pattern_display() {
218        assert_eq!(
219            extract_terminal_pattern_display("cargo build --release"),
220            Some("cargo build".to_string())
221        );
222        assert_eq!(
223            extract_terminal_pattern_display("cargo test -p search"),
224            Some("cargo test".to_string())
225        );
226        assert_eq!(
227            extract_terminal_pattern_display("npm install"),
228            Some("npm install".to_string())
229        );
230        assert_eq!(
231            extract_terminal_pattern_display("ls -la"),
232            Some("ls".to_string())
233        );
234        assert_eq!(
235            extract_terminal_pattern_display("ls"),
236            Some("ls".to_string())
237        );
238    }
239
240    #[test]
241    fn test_extract_path_pattern() {
242        assert_eq!(
243            extract_path_pattern("/Users/alice/project/src/main.rs"),
244            Some("^/Users/alice/project/src/".to_string())
245        );
246        assert_eq!(
247            extract_path_pattern("src/lib.rs"),
248            Some("^src/".to_string())
249        );
250        assert_eq!(extract_path_pattern("file.txt"), None);
251        assert_eq!(extract_path_pattern("/file.txt"), None);
252    }
253
254    #[test]
255    fn test_extract_path_pattern_display() {
256        assert_eq!(
257            extract_path_pattern_display("/Users/alice/project/src/main.rs"),
258            Some("/Users/alice/project/src/".to_string())
259        );
260        assert_eq!(
261            extract_path_pattern_display("src/lib.rs"),
262            Some("src/".to_string())
263        );
264    }
265
266    #[test]
267    fn test_extract_url_pattern() {
268        assert_eq!(
269            extract_url_pattern("https://github.com/user/repo"),
270            Some("^https?://github\\.com".to_string())
271        );
272        assert_eq!(
273            extract_url_pattern("http://example.com/path?query=1"),
274            Some("^https?://example\\.com".to_string())
275        );
276        assert_eq!(extract_url_pattern("not a url"), None);
277    }
278
279    #[test]
280    fn test_extract_url_pattern_display() {
281        assert_eq!(
282            extract_url_pattern_display("https://github.com/user/repo"),
283            Some("github.com".to_string())
284        );
285        assert_eq!(
286            extract_url_pattern_display("http://api.example.com/v1/users"),
287            Some("api.example.com".to_string())
288        );
289    }
290
291    #[test]
292    fn test_special_chars_are_escaped() {
293        assert_eq!(
294            extract_path_pattern("/path/with (parens)/file.txt"),
295            Some("^/path/with \\(parens\\)/".to_string())
296        );
297        assert_eq!(
298            extract_url_pattern("https://test.example.com/path"),
299            Some("^https?://test\\.example\\.com".to_string())
300        );
301    }
302
303    #[test]
304    fn test_extract_copy_move_pattern_same_directory() {
305        assert_eq!(
306            extract_copy_move_pattern(
307                "/Users/alice/project/src/old.rs\n/Users/alice/project/src/new.rs"
308            ),
309            Some("^/Users/alice/project/src/".to_string())
310        );
311    }
312
313    #[test]
314    fn test_extract_copy_move_pattern_sibling_directories() {
315        assert_eq!(
316            extract_copy_move_pattern(
317                "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
318            ),
319            Some("^/Users/alice/project/".to_string())
320        );
321    }
322
323    #[test]
324    fn test_extract_copy_move_pattern_no_common_prefix() {
325        assert_eq!(
326            extract_copy_move_pattern("/home/file.txt\n/tmp/file.txt"),
327            None
328        );
329    }
330
331    #[test]
332    fn test_extract_copy_move_pattern_relative_paths() {
333        assert_eq!(
334            extract_copy_move_pattern("src/old.rs\nsrc/new.rs"),
335            Some("^src/".to_string())
336        );
337    }
338
339    #[test]
340    fn test_extract_copy_move_pattern_display() {
341        assert_eq!(
342            extract_copy_move_pattern_display(
343                "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
344            ),
345            Some("/Users/alice/project/".to_string())
346        );
347    }
348
349    #[test]
350    fn test_extract_copy_move_pattern_no_arrow() {
351        assert_eq!(extract_copy_move_pattern("just/a/path.rs"), None);
352        assert_eq!(extract_copy_move_pattern_display("just/a/path.rs"), None);
353    }
354}