1use shell_command_parser::extract_terminal_command_prefix;
2use std::path::{Path, PathBuf};
3use url::Url;
4
/// Rewrites every backslash in `path_str` as a forward slash so that patterns
/// built from Windows-style and Unix-style paths share one separator.
fn normalize_separators(path_str: &str) -> String {
    path_str
        .chars()
        .map(|ch| if ch == '\\' { '/' } else { ch })
        .collect()
}
9
/// Decides whether `token` can be treated as a bare command or subcommand name.
///
/// A token qualifies when it does not begin with `-` (a leading hyphen marks a
/// flag) and every character is alphanumeric, a hyphen, or an underscore.
/// Anything containing other characters, such as `/`, `.`, `=`, or whitespace,
/// is rejected.
fn is_plain_command_token(token: &str) -> bool {
    if token.starts_with('-') {
        return false;
    }

    token
        .chars()
        .all(|ch| matches!(ch, '-' | '_') || ch.is_alphanumeric())
}
19
/// The leading portion of a terminal command that an "always allow" rule is
/// built from.
///
/// Both fields are copied from the value returned by
/// `extract_terminal_command_prefix` once its command name has passed
/// `is_plain_command_token`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct CommandPrefix {
    /// Prefix tokens, in order; each one is regex-escaped and joined with
    /// `\s+` when the match pattern is generated.
    normalized_tokens: Vec<String>,
    /// Human-readable form of the prefix, returned as-is to callers that want
    /// to show the rule to the user.
    display: String,
}
24
25/// Extracts the command name and optional subcommand from a shell command using
26/// the shell parser.
27///
28/// This parses the command properly to extract the command name and optional
29/// subcommand (e.g. "cargo" and "test" from "cargo test -p search"), handling shell
30/// syntax correctly. Returns `None` if parsing fails or if the command name
31/// contains path separators (for security reasons).
32fn extract_command_prefix(command: &str) -> Option<CommandPrefix> {
33 let prefix = extract_terminal_command_prefix(command)?;
34
35 if !is_plain_command_token(&prefix.command) {
36 return None;
37 }
38
39 Some(CommandPrefix {
40 normalized_tokens: prefix.tokens,
41 display: prefix.display,
42 })
43}
44
45/// Extracts a regex pattern from a terminal command based on the first token (command name).
46///
47/// Returns `None` for commands starting with `./`, `/`, or other path-like prefixes.
48/// This is a deliberate security decision: we only allow pattern-based "always allow"
49/// rules for well-known command names (like `cargo`, `npm`, `git`), not for arbitrary
50/// scripts or absolute paths which could be manipulated by an attacker.
51pub fn extract_terminal_pattern(command: &str) -> Option<String> {
52 let prefix = extract_command_prefix(command)?;
53 let tokens = prefix.normalized_tokens;
54
55 match tokens.as_slice() {
56 [] => None,
57 [single] => Some(format!("^{}\\b", regex::escape(single))),
58 [rest @ .., last] => Some(format!(
59 "^{}\\s+{}(\\s|$)",
60 rest.iter()
61 .map(|token| regex::escape(token))
62 .collect::<Vec<_>>()
63 .join("\\s+"),
64 regex::escape(last)
65 )),
66 }
67}
68
69pub fn extract_terminal_pattern_display(command: &str) -> Option<String> {
70 let prefix = extract_command_prefix(command)?;
71 Some(prefix.display)
72}
73
74pub fn extract_path_pattern(path: &str) -> Option<String> {
75 let parent = Path::new(path).parent()?;
76 let parent_str = normalize_separators(parent.to_str()?);
77 if parent_str.is_empty() || parent_str == "/" {
78 return None;
79 }
80 Some(format!("^{}/", regex::escape(&parent_str)))
81}
82
83pub fn extract_path_pattern_display(path: &str) -> Option<String> {
84 let parent = Path::new(path).parent()?;
85 let parent_str = normalize_separators(parent.to_str()?);
86 if parent_str.is_empty() || parent_str == "/" {
87 return None;
88 }
89 Some(format!("{}/", parent_str))
90}
91
/// Finds the deepest directory shared by the parent directories of `path_a`
/// and `path_b`.
///
/// The two parents are compared component by component from the front, and
/// the matching leading components are reassembled into a `PathBuf`.
///
/// Returns `None` when either path has no parent or when the parents do not
/// share even their first component. A shared root alone still counts, so two
/// absolute paths with nothing else in common yield `/`.
fn common_parent_dir(path_a: &str, path_b: &str) -> Option<PathBuf> {
    let dir_a = Path::new(path_a).parent()?;
    let dir_b = Path::new(path_b).parent()?;

    let mut shared = PathBuf::new();
    let mut matched_any = false;
    for (left, right) in dir_a.components().zip(dir_b.components()) {
        if left != right {
            break;
        }
        shared.push(left);
        matched_any = true;
    }

    matched_any.then_some(shared)
}
112
113pub fn extract_copy_move_pattern(input: &str) -> Option<String> {
114 let (source, dest) = input.split_once('\n')?;
115 let common = common_parent_dir(source, dest)?;
116 let common_str = normalize_separators(common.to_str()?);
117 if common_str.is_empty() || common_str == "/" {
118 return None;
119 }
120 Some(format!("^{}/", regex::escape(&common_str)))
121}
122
123pub fn extract_copy_move_pattern_display(input: &str) -> Option<String> {
124 let (source, dest) = input.split_once('\n')?;
125 let common = common_parent_dir(source, dest)?;
126 let common_str = normalize_separators(common.to_str()?);
127 if common_str.is_empty() || common_str == "/" {
128 return None;
129 }
130 Some(format!("{}/", common_str))
131}
132
133pub fn extract_url_pattern(url: &str) -> Option<String> {
134 let parsed = Url::parse(url).ok()?;
135 let domain = parsed.host_str()?;
136 Some(format!("^https?://{}", regex::escape(domain)))
137}
138
139pub fn extract_url_pattern_display(url: &str) -> Option<String> {
140 let parsed = Url::parse(url).ok()?;
141 let domain = parsed.host_str()?;
142 Some(domain.to_string())
143}
144
/// Unit tests for the pattern-extraction helpers.
///
/// Each `extract_*_pattern` function is checked for the exact regex string it
/// emits, and each `*_display` companion for its human-readable output and for
/// the inputs it must reject.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_terminal_pattern() {
        assert_eq!(
            extract_terminal_pattern("cargo build --release"),
            Some("^cargo\\s+build(\\s|$)".to_string())
        );
        assert_eq!(
            extract_terminal_pattern("cargo test -p search"),
            Some("^cargo\\s+test(\\s|$)".to_string())
        );
        assert_eq!(
            extract_terminal_pattern("npm install"),
            Some("^npm\\s+install(\\s|$)".to_string())
        );
        assert_eq!(
            extract_terminal_pattern("git-lfs pull"),
            Some("^git\\-lfs\\s+pull(\\s|$)".to_string())
        );
        assert_eq!(
            extract_terminal_pattern("my_script arg"),
            Some("^my_script\\s+arg(\\s|$)".to_string())
        );

        // Flags as second token: only the command name is used
        assert_eq!(
            extract_terminal_pattern("ls -la"),
            Some("^ls\\b".to_string())
        );
        assert_eq!(
            extract_terminal_pattern("rm --force foo"),
            Some("^rm\\b".to_string())
        );

        // Single-word commands
        assert_eq!(extract_terminal_pattern("ls"), Some("^ls\\b".to_string()));

        // Subcommand pattern does not match a hyphenated extension of the subcommand
        // (e.g. approving "cargo build" should not approve "cargo build-foo")
        assert_eq!(
            extract_terminal_pattern("cargo build"),
            Some("^cargo\\s+build(\\s|$)".to_string())
        );
        let pattern = regex::Regex::new(&extract_terminal_pattern("cargo build").unwrap()).unwrap();
        assert!(pattern.is_match("cargo build --release"));
        assert!(pattern.is_match("cargo build"));
        assert!(!pattern.is_match("cargo build-foo"));
        assert!(!pattern.is_match("cargo builder"));

        // Env-var prefixes are included in generated patterns
        assert_eq!(
            extract_terminal_pattern("PAGER=blah git log --oneline"),
            Some("^PAGER=blah\\s+git\\s+log(\\s|$)".to_string())
        );
        assert_eq!(
            extract_terminal_pattern("A=1 B=2 git log"),
            Some("^A=1\\s+B=2\\s+git\\s+log(\\s|$)".to_string())
        );
        assert_eq!(
            extract_terminal_pattern("PAGER='less -R' git log"),
            Some("^PAGER='less \\-R'\\s+git\\s+log(\\s|$)".to_string())
        );

        // Path-like commands are rejected
        assert_eq!(extract_terminal_pattern("./script.sh arg"), None);
        assert_eq!(extract_terminal_pattern("/usr/bin/python arg"), None);
        assert_eq!(extract_terminal_pattern("PAGER=blah ./script.sh arg"), None);
    }

    #[test]
    fn test_extract_terminal_pattern_display() {
        assert_eq!(
            extract_terminal_pattern_display("cargo build --release"),
            Some("cargo build".to_string())
        );
        assert_eq!(
            extract_terminal_pattern_display("cargo test -p search"),
            Some("cargo test".to_string())
        );
        assert_eq!(
            extract_terminal_pattern_display("npm install"),
            Some("npm install".to_string())
        );
        assert_eq!(
            extract_terminal_pattern_display("ls -la"),
            Some("ls".to_string())
        );
        assert_eq!(
            extract_terminal_pattern_display("ls"),
            Some("ls".to_string())
        );
        assert_eq!(
            extract_terminal_pattern_display("PAGER=blah git log --oneline"),
            Some("PAGER=blah git log".to_string())
        );
        assert_eq!(
            extract_terminal_pattern_display("PAGER='less -R' git log"),
            Some("PAGER='less -R' git log".to_string())
        );

        // Display shares the pattern's rejection of path-like commands
        assert_eq!(extract_terminal_pattern_display("./script.sh arg"), None);
    }

    #[test]
    fn test_terminal_pattern_regex_normalizes_whitespace() {
        let pattern = extract_terminal_pattern("PAGER=blah git log --oneline")
            .expect("expected terminal pattern");
        let regex = regex::Regex::new(&pattern).expect("expected valid regex");

        assert!(regex.is_match("PAGER=blah git log"));
        assert!(regex.is_match("PAGER=blah git log --stat"));

        // Runs of spaces and tabs between tokens are treated like a single space
        assert!(regex.is_match("PAGER=blah   git   log"));
        assert!(regex.is_match("PAGER=blah\tgit\tlog --stat"));

        // A longer subcommand or a missing env-var prefix must not match
        assert!(!regex.is_match("PAGER=blah git logs"));
        assert!(!regex.is_match("git log"));
    }

    #[test]
    fn test_extract_terminal_pattern_skips_redirects_before_subcommand() {
        assert_eq!(
            extract_terminal_pattern("git 2>/dev/null log --oneline"),
            Some("^git\\s+log(\\s|$)".to_string())
        );
        assert_eq!(
            extract_terminal_pattern_display("git 2>/dev/null log --oneline"),
            Some("git 2>/dev/null log".to_string())
        );

        assert_eq!(
            extract_terminal_pattern("rm --force foo"),
            Some("^rm\\b".to_string())
        );
    }

    #[test]
    fn test_extract_path_pattern() {
        assert_eq!(
            extract_path_pattern("/Users/alice/project/src/main.rs"),
            Some("^/Users/alice/project/src/".to_string())
        );
        assert_eq!(
            extract_path_pattern("src/lib.rs"),
            Some("^src/".to_string())
        );
        assert_eq!(extract_path_pattern("file.txt"), None);
        assert_eq!(extract_path_pattern("/file.txt"), None);
    }

    #[test]
    fn test_extract_path_pattern_display() {
        assert_eq!(
            extract_path_pattern_display("/Users/alice/project/src/main.rs"),
            Some("/Users/alice/project/src/".to_string())
        );
        assert_eq!(
            extract_path_pattern_display("src/lib.rs"),
            Some("src/".to_string())
        );

        // Bare file names and files directly under the root have no usable parent
        assert_eq!(extract_path_pattern_display("file.txt"), None);
        assert_eq!(extract_path_pattern_display("/file.txt"), None);
    }

    #[test]
    fn test_extract_url_pattern() {
        assert_eq!(
            extract_url_pattern("https://github.com/user/repo"),
            Some("^https?://github\\.com".to_string())
        );
        assert_eq!(
            extract_url_pattern("http://example.com/path?query=1"),
            Some("^https?://example\\.com".to_string())
        );
        assert_eq!(extract_url_pattern("not a url"), None);
    }

    #[test]
    fn test_extract_url_pattern_display() {
        assert_eq!(
            extract_url_pattern_display("https://github.com/user/repo"),
            Some("github.com".to_string())
        );
        assert_eq!(
            extract_url_pattern_display("http://api.example.com/v1/users"),
            Some("api.example.com".to_string())
        );
        assert_eq!(extract_url_pattern_display("not a url"), None);
    }

    #[test]
    fn test_special_chars_are_escaped() {
        assert_eq!(
            extract_path_pattern("/path/with (parens)/file.txt"),
            Some("^/path/with \\(parens\\)/".to_string())
        );
        assert_eq!(
            extract_url_pattern("https://test.example.com/path"),
            Some("^https?://test\\.example\\.com".to_string())
        );
    }

    #[test]
    fn test_extract_copy_move_pattern_same_directory() {
        assert_eq!(
            extract_copy_move_pattern(
                "/Users/alice/project/src/old.rs\n/Users/alice/project/src/new.rs"
            ),
            Some("^/Users/alice/project/src/".to_string())
        );
    }

    #[test]
    fn test_extract_copy_move_pattern_sibling_directories() {
        assert_eq!(
            extract_copy_move_pattern(
                "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
            ),
            Some("^/Users/alice/project/".to_string())
        );
    }

    #[test]
    fn test_extract_copy_move_pattern_no_common_prefix() {
        assert_eq!(
            extract_copy_move_pattern("/home/file.txt\n/tmp/file.txt"),
            None
        );
        assert_eq!(
            extract_copy_move_pattern_display("/home/file.txt\n/tmp/file.txt"),
            None
        );
    }

    #[test]
    fn test_extract_copy_move_pattern_relative_paths() {
        assert_eq!(
            extract_copy_move_pattern("src/old.rs\nsrc/new.rs"),
            Some("^src/".to_string())
        );
    }

    #[test]
    fn test_extract_copy_move_pattern_display() {
        assert_eq!(
            extract_copy_move_pattern_display(
                "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
            ),
            Some("/Users/alice/project/".to_string())
        );
    }

    // Source and destination are separated by a newline; input without one
    // cannot be split into two paths.
    #[test]
    fn test_extract_copy_move_pattern_missing_separator() {
        assert_eq!(extract_copy_move_pattern("just/a/path.rs"), None);
        assert_eq!(extract_copy_move_pattern_display("just/a/path.rs"), None);
    }
}