pattern_extraction.rs

  1use shell_command_parser::extract_commands;
  2use std::path::{Path, PathBuf};
  3use url::Url;
  4
  5/// Normalize path separators to forward slashes for consistent cross-platform patterns.
  6fn normalize_separators(path_str: &str) -> String {
  7    path_str.replace('\\', "/")
  8}
  9
 10/// Extracts the command name from a shell command using the shell parser.
 11///
 12/// This parses the command properly to extract just the command name (first word),
 13/// handling shell syntax correctly. Returns `None` if parsing fails or if the
 14/// command name contains path separators (for security reasons).
 15fn extract_command_name(command: &str) -> Option<String> {
 16    let commands = extract_commands(command)?;
 17    let first_command = commands.first()?;
 18
 19    let first_token = first_command.split_whitespace().next()?;
 20
 21    // Only allow alphanumeric commands with hyphens/underscores.
 22    // Reject paths like "./script.sh" or "/usr/bin/python" to prevent
 23    // users from accidentally allowing arbitrary script execution.
 24    if first_token
 25        .chars()
 26        .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
 27    {
 28        Some(first_token.to_string())
 29    } else {
 30        None
 31    }
 32}
 33
 34/// Extracts a regex pattern from a terminal command based on the first token (command name).
 35///
 36/// Returns `None` for commands starting with `./`, `/`, or other path-like prefixes.
 37/// This is a deliberate security decision: we only allow pattern-based "always allow"
 38/// rules for well-known command names (like `cargo`, `npm`, `git`), not for arbitrary
 39/// scripts or absolute paths which could be manipulated by an attacker.
 40pub fn extract_terminal_pattern(command: &str) -> Option<String> {
 41    let command_name = extract_command_name(command)?;
 42    Some(format!("^{}\\b", regex::escape(&command_name)))
 43}
 44
 45pub fn extract_terminal_pattern_display(command: &str) -> Option<String> {
 46    extract_command_name(command)
 47}
 48
 49pub fn extract_path_pattern(path: &str) -> Option<String> {
 50    let parent = Path::new(path).parent()?;
 51    let parent_str = normalize_separators(parent.to_str()?);
 52    if parent_str.is_empty() || parent_str == "/" {
 53        return None;
 54    }
 55    Some(format!("^{}/", regex::escape(&parent_str)))
 56}
 57
 58pub fn extract_path_pattern_display(path: &str) -> Option<String> {
 59    let parent = Path::new(path).parent()?;
 60    let parent_str = normalize_separators(parent.to_str()?);
 61    if parent_str.is_empty() || parent_str == "/" {
 62        return None;
 63    }
 64    Some(format!("{}/", parent_str))
 65}
 66
 67fn common_parent_dir(path_a: &str, path_b: &str) -> Option<PathBuf> {
 68    let parent_a = Path::new(path_a).parent()?;
 69    let parent_b = Path::new(path_b).parent()?;
 70
 71    let components_a: Vec<_> = parent_a.components().collect();
 72    let components_b: Vec<_> = parent_b.components().collect();
 73
 74    let common_count = components_a
 75        .iter()
 76        .zip(components_b.iter())
 77        .take_while(|(a, b)| a == b)
 78        .count();
 79
 80    if common_count == 0 {
 81        return None;
 82    }
 83
 84    let common: PathBuf = components_a[..common_count].iter().collect();
 85    Some(common)
 86}
 87
 88pub fn extract_copy_move_pattern(input: &str) -> Option<String> {
 89    let (source, dest) = input.split_once('\n')?;
 90    let common = common_parent_dir(source, dest)?;
 91    let common_str = normalize_separators(common.to_str()?);
 92    if common_str.is_empty() || common_str == "/" {
 93        return None;
 94    }
 95    Some(format!("^{}/", regex::escape(&common_str)))
 96}
 97
 98pub fn extract_copy_move_pattern_display(input: &str) -> Option<String> {
 99    let (source, dest) = input.split_once('\n')?;
100    let common = common_parent_dir(source, dest)?;
101    let common_str = normalize_separators(common.to_str()?);
102    if common_str.is_empty() || common_str == "/" {
103        return None;
104    }
105    Some(format!("{}/", common_str))
106}
107
108pub fn extract_url_pattern(url: &str) -> Option<String> {
109    let parsed = Url::parse(url).ok()?;
110    let domain = parsed.host_str()?;
111    Some(format!("^https?://{}", regex::escape(domain)))
112}
113
114pub fn extract_url_pattern_display(url: &str) -> Option<String> {
115    let parsed = Url::parse(url).ok()?;
116    let domain = parsed.host_str()?;
117    Some(domain.to_string())
118}
119
/// Unit tests for the pattern extraction helpers.
///
/// Each `extract_*_pattern` function is checked together with its
/// `*_display` counterpart, including the inputs for which no pattern may be
/// offered (`None`).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_terminal_pattern() {
        assert_eq!(
            extract_terminal_pattern("cargo build --release"),
            Some("^cargo\\b".to_string())
        );
        assert_eq!(
            extract_terminal_pattern("npm install"),
            Some("^npm\\b".to_string())
        );
        assert_eq!(
            extract_terminal_pattern("git-lfs pull"),
            Some("^git\\-lfs\\b".to_string())
        );
        assert_eq!(
            extract_terminal_pattern("my_script arg"),
            Some("^my_script\\b".to_string())
        );
        // Path-like commands must never produce an "always allow" pattern.
        assert_eq!(extract_terminal_pattern("./script.sh arg"), None);
        assert_eq!(extract_terminal_pattern("/usr/bin/python arg"), None);
    }

    #[test]
    fn test_extract_terminal_pattern_display() {
        assert_eq!(
            extract_terminal_pattern_display("cargo build --release"),
            Some("cargo".to_string())
        );
        assert_eq!(
            extract_terminal_pattern_display("npm install"),
            Some("npm".to_string())
        );
    }

    #[test]
    fn test_extract_terminal_pattern_display_rejects_paths() {
        // The display form must be unavailable whenever the pattern is.
        assert_eq!(extract_terminal_pattern_display("./script.sh arg"), None);
        assert_eq!(
            extract_terminal_pattern_display("/usr/bin/python arg"),
            None
        );
    }

    #[test]
    fn test_extract_path_pattern() {
        assert_eq!(
            extract_path_pattern("/Users/alice/project/src/main.rs"),
            Some("^/Users/alice/project/src/".to_string())
        );
        assert_eq!(
            extract_path_pattern("src/lib.rs"),
            Some("^src/".to_string())
        );
        assert_eq!(extract_path_pattern("file.txt"), None);
        assert_eq!(extract_path_pattern("/file.txt"), None);
    }

    #[test]
    fn test_extract_path_pattern_display() {
        assert_eq!(
            extract_path_pattern_display("/Users/alice/project/src/main.rs"),
            Some("/Users/alice/project/src/".to_string())
        );
        assert_eq!(
            extract_path_pattern_display("src/lib.rs"),
            Some("src/".to_string())
        );
    }

    #[test]
    fn test_extract_path_pattern_display_without_usable_parent() {
        // An empty parent and the root directory are both too broad to offer.
        assert_eq!(extract_path_pattern_display("file.txt"), None);
        assert_eq!(extract_path_pattern_display("/file.txt"), None);
    }

    #[test]
    fn test_extract_url_pattern() {
        assert_eq!(
            extract_url_pattern("https://github.com/user/repo"),
            Some("^https?://github\\.com".to_string())
        );
        assert_eq!(
            extract_url_pattern("http://example.com/path?query=1"),
            Some("^https?://example\\.com".to_string())
        );
        assert_eq!(extract_url_pattern("not a url"), None);
    }

    #[test]
    fn test_extract_url_pattern_display() {
        assert_eq!(
            extract_url_pattern_display("https://github.com/user/repo"),
            Some("github.com".to_string())
        );
        assert_eq!(
            extract_url_pattern_display("http://api.example.com/v1/users"),
            Some("api.example.com".to_string())
        );
        assert_eq!(extract_url_pattern_display("not a url"), None);
    }

    #[test]
    fn test_special_chars_are_escaped() {
        assert_eq!(
            extract_path_pattern("/path/with (parens)/file.txt"),
            Some("^/path/with \\(parens\\)/".to_string())
        );
        assert_eq!(
            extract_url_pattern("https://test.example.com/path"),
            Some("^https?://test\\.example\\.com".to_string())
        );
    }

    #[test]
    fn test_extract_copy_move_pattern_same_directory() {
        assert_eq!(
            extract_copy_move_pattern(
                "/Users/alice/project/src/old.rs\n/Users/alice/project/src/new.rs"
            ),
            Some("^/Users/alice/project/src/".to_string())
        );
    }

    #[test]
    fn test_extract_copy_move_pattern_sibling_directories() {
        assert_eq!(
            extract_copy_move_pattern(
                "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
            ),
            Some("^/Users/alice/project/".to_string())
        );
    }

    #[test]
    fn test_extract_copy_move_pattern_no_common_prefix() {
        // Only the root is shared, which is too broad to offer.
        assert_eq!(
            extract_copy_move_pattern("/home/file.txt\n/tmp/file.txt"),
            None
        );
        assert_eq!(
            extract_copy_move_pattern_display("/home/file.txt\n/tmp/file.txt"),
            None
        );
    }

    #[test]
    fn test_extract_copy_move_pattern_relative_paths() {
        assert_eq!(
            extract_copy_move_pattern("src/old.rs\nsrc/new.rs"),
            Some("^src/".to_string())
        );
        assert_eq!(
            extract_copy_move_pattern_display("src/old.rs\nsrc/new.rs"),
            Some("src/".to_string())
        );
    }

    #[test]
    fn test_extract_copy_move_pattern_display() {
        assert_eq!(
            extract_copy_move_pattern_display(
                "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
            ),
            Some("/Users/alice/project/".to_string())
        );
    }

    #[test]
    fn test_extract_copy_move_pattern_no_newline_separator() {
        // Source and destination are separated by `\n`; a lone path is rejected.
        assert_eq!(extract_copy_move_pattern("just/a/path.rs"), None);
        assert_eq!(extract_copy_move_pattern_display("just/a/path.rs"), None);
    }
}
272}