1use acp_thread::PermissionPattern;
2use shell_command_parser::{extract_commands, extract_terminal_command_prefix};
3use std::path::{Path, PathBuf};
4use url::Url;
5
/// Rewrites every backslash in `path_str` as a forward slash, so that patterns
/// derived from Windows-style and Unix-style paths use a single separator.
/// All other characters are copied through unchanged.
fn normalize_separators(path_str: &str) -> String {
    path_str
        .chars()
        .map(|ch| if ch == '\\' { '/' } else { ch })
        .collect()
}
10
/// Returns true if the token looks like a command name or subcommand — i.e. it
/// is non-empty, contains only (Unicode) alphanumeric characters, hyphens, and
/// underscores, and does not start with a hyphen (which would make it a flag).
///
/// Tokens containing `/`, `\`, `.`, whitespace or shell metacharacters are
/// rejected, so path-like commands such as `./script.sh` never qualify.
fn is_plain_command_token(token: &str) -> bool {
    // `all` on an empty iterator is vacuously true, so an empty token has to be
    // rejected explicitly: it does not name any command.
    !token.is_empty()
        && !token.starts_with('-')
        && token
            .chars()
            .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
}
20
/// The leading part of a shell command from which permission patterns are built.
struct CommandPrefix {
    /// Human-readable form of the prefix shown to the user. It may contain text,
    /// such as redirects, that is absent from `normalized_tokens`.
    display: String,
    /// Tokens (leading env-var assignments, command name, optional subcommand)
    /// that are regex-escaped and joined to form the pattern.
    normalized_tokens: Vec<String>,
}
25
26/// Extracts the command name and optional subcommand from a shell command using
27/// the shell parser.
28///
29/// This parses the command properly to extract the command name and optional
30/// subcommand (e.g. "cargo" and "test" from "cargo test -p search"), handling shell
31/// syntax correctly. Returns `None` if parsing fails or if the command name
32/// contains path separators (for security reasons).
33fn extract_command_prefix(command: &str) -> Option<CommandPrefix> {
34 let prefix = extract_terminal_command_prefix(command)?;
35
36 if !is_plain_command_token(&prefix.command) {
37 return None;
38 }
39
40 Some(CommandPrefix {
41 normalized_tokens: prefix.tokens,
42 display: prefix.display,
43 })
44}
45
46/// Extracts a regex pattern and display name from a terminal command.
47///
48/// Returns `None` for commands starting with `./`, `/`, or other path-like prefixes.
49/// This is a deliberate security decision: we only allow pattern-based "always allow"
50/// rules for well-known command names (like `cargo`, `npm`, `git`), not for arbitrary
51/// scripts or absolute paths which could be manipulated by an attacker.
52pub fn extract_terminal_permission_pattern(command: &str) -> Option<PermissionPattern> {
53 let pattern = extract_terminal_pattern(command)?;
54 let display_name = extract_terminal_pattern_display(command)?;
55 Some(PermissionPattern {
56 pattern,
57 display_name,
58 })
59}
60
61pub fn extract_terminal_pattern(command: &str) -> Option<String> {
62 let prefix = extract_command_prefix(command)?;
63 let tokens = prefix.normalized_tokens;
64
65 match tokens.as_slice() {
66 [] => None,
67 [single] => Some(format!("^{}\\b", regex::escape(single))),
68 [rest @ .., last] => Some(format!(
69 "^{}\\s+{}(\\s|$)",
70 rest.iter()
71 .map(|token| regex::escape(token))
72 .collect::<Vec<_>>()
73 .join("\\s+"),
74 regex::escape(last)
75 )),
76 }
77}
78
79pub fn extract_terminal_pattern_display(command: &str) -> Option<String> {
80 let prefix = extract_command_prefix(command)?;
81 Some(prefix.display)
82}
83
84/// Extracts patterns for ALL commands in a pipeline, not just the first one.
85///
86/// For a command like `"cargo test 2>&1 | tail"`, this returns patterns for
87/// both `cargo` and `tail`. Path-based commands (e.g. `./script.sh`) are
88/// filtered out, and duplicate command names are deduplicated while preserving
89/// order.
90pub fn extract_all_terminal_patterns(command: &str) -> Vec<PermissionPattern> {
91 let commands = match extract_commands(command) {
92 Some(commands) => commands,
93 None => return Vec::new(),
94 };
95
96 let mut results = Vec::new();
97
98 for cmd in &commands {
99 let Some(permission_pattern) = extract_terminal_permission_pattern(cmd) else {
100 continue;
101 };
102
103 if results.contains(&permission_pattern) {
104 continue;
105 }
106
107 results.push(permission_pattern);
108 }
109
110 results
111}
112
113pub fn extract_path_pattern(path: &str) -> Option<String> {
114 let parent = Path::new(path).parent()?;
115 let parent_str = normalize_separators(parent.to_str()?);
116 if parent_str.is_empty() || parent_str == "/" {
117 return None;
118 }
119 Some(format!("^{}/", regex::escape(&parent_str)))
120}
121
122pub fn extract_path_pattern_display(path: &str) -> Option<String> {
123 let parent = Path::new(path).parent()?;
124 let parent_str = normalize_separators(parent.to_str()?);
125 if parent_str.is_empty() || parent_str == "/" {
126 return None;
127 }
128 Some(format!("{}/", parent_str))
129}
130
131fn common_parent_dir(path_a: &str, path_b: &str) -> Option<PathBuf> {
132 let parent_a = Path::new(path_a).parent()?;
133 let parent_b = Path::new(path_b).parent()?;
134
135 let components_a: Vec<_> = parent_a.components().collect();
136 let components_b: Vec<_> = parent_b.components().collect();
137
138 let common_count = components_a
139 .iter()
140 .zip(components_b.iter())
141 .take_while(|(a, b)| a == b)
142 .count();
143
144 if common_count == 0 {
145 return None;
146 }
147
148 let common: PathBuf = components_a[..common_count].iter().collect();
149 Some(common)
150}
151
152pub fn extract_copy_move_pattern(input: &str) -> Option<String> {
153 let (source, dest) = input.split_once('\n')?;
154 let common = common_parent_dir(source, dest)?;
155 let common_str = normalize_separators(common.to_str()?);
156 if common_str.is_empty() || common_str == "/" {
157 return None;
158 }
159 Some(format!("^{}/", regex::escape(&common_str)))
160}
161
162pub fn extract_copy_move_pattern_display(input: &str) -> Option<String> {
163 let (source, dest) = input.split_once('\n')?;
164 let common = common_parent_dir(source, dest)?;
165 let common_str = normalize_separators(common.to_str()?);
166 if common_str.is_empty() || common_str == "/" {
167 return None;
168 }
169 Some(format!("{}/", common_str))
170}
171
172pub fn extract_url_pattern(url: &str) -> Option<String> {
173 let parsed = Url::parse(url).ok()?;
174 let domain = parsed.host_str()?;
175 Some(format!("^https?://{}", regex::escape(domain)))
176}
177
178pub fn extract_url_pattern_display(url: &str) -> Option<String> {
179 let parsed = Url::parse(url).ok()?;
180 let domain = parsed.host_str()?;
181 Some(domain.to_string())
182}
183
184#[cfg(test)]
185mod tests {
186 use super::*;
187
188 #[test]
189 fn test_extract_terminal_pattern() {
190 assert_eq!(
191 extract_terminal_pattern("cargo build --release"),
192 Some("^cargo\\s+build(\\s|$)".to_string())
193 );
194 assert_eq!(
195 extract_terminal_pattern("cargo test -p search"),
196 Some("^cargo\\s+test(\\s|$)".to_string())
197 );
198 assert_eq!(
199 extract_terminal_pattern("npm install"),
200 Some("^npm\\s+install(\\s|$)".to_string())
201 );
202 assert_eq!(
203 extract_terminal_pattern("git-lfs pull"),
204 Some("^git\\-lfs\\s+pull(\\s|$)".to_string())
205 );
206 assert_eq!(
207 extract_terminal_pattern("my_script arg"),
208 Some("^my_script\\s+arg(\\s|$)".to_string())
209 );
210
211 // Flags as second token: only the command name is used
212 assert_eq!(
213 extract_terminal_pattern("ls -la"),
214 Some("^ls\\b".to_string())
215 );
216 assert_eq!(
217 extract_terminal_pattern("rm --force foo"),
218 Some("^rm\\b".to_string())
219 );
220
221 // Single-word commands
222 assert_eq!(extract_terminal_pattern("ls"), Some("^ls\\b".to_string()));
223
224 // Subcommand pattern does not match a hyphenated extension of the subcommand
225 // (e.g. approving "cargo build" should not approve "cargo build-foo")
226 assert_eq!(
227 extract_terminal_pattern("cargo build"),
228 Some("^cargo\\s+build(\\s|$)".to_string())
229 );
230 let pattern = regex::Regex::new(&extract_terminal_pattern("cargo build").unwrap()).unwrap();
231 assert!(pattern.is_match("cargo build --release"));
232 assert!(pattern.is_match("cargo build"));
233 assert!(!pattern.is_match("cargo build-foo"));
234 assert!(!pattern.is_match("cargo builder"));
235
236 // Env-var prefixes are included in generated patterns
237 assert_eq!(
238 extract_terminal_pattern("PAGER=blah git log --oneline"),
239 Some("^PAGER=blah\\s+git\\s+log(\\s|$)".to_string())
240 );
241 assert_eq!(
242 extract_terminal_pattern("A=1 B=2 git log"),
243 Some("^A=1\\s+B=2\\s+git\\s+log(\\s|$)".to_string())
244 );
245 assert_eq!(
246 extract_terminal_pattern("PAGER='less -R' git log"),
247 Some("^PAGER='less \\-R'\\s+git\\s+log(\\s|$)".to_string())
248 );
249
250 // Path-like commands are rejected
251 assert_eq!(extract_terminal_pattern("./script.sh arg"), None);
252 assert_eq!(extract_terminal_pattern("/usr/bin/python arg"), None);
253 assert_eq!(extract_terminal_pattern("PAGER=blah ./script.sh arg"), None);
254 }
255
256 #[test]
257 fn test_extract_terminal_pattern_display() {
258 assert_eq!(
259 extract_terminal_pattern_display("cargo build --release"),
260 Some("cargo build".to_string())
261 );
262 assert_eq!(
263 extract_terminal_pattern_display("cargo test -p search"),
264 Some("cargo test".to_string())
265 );
266 assert_eq!(
267 extract_terminal_pattern_display("npm install"),
268 Some("npm install".to_string())
269 );
270 assert_eq!(
271 extract_terminal_pattern_display("ls -la"),
272 Some("ls".to_string())
273 );
274 assert_eq!(
275 extract_terminal_pattern_display("ls"),
276 Some("ls".to_string())
277 );
278 assert_eq!(
279 extract_terminal_pattern_display("PAGER=blah git log --oneline"),
280 Some("PAGER=blah git log".to_string())
281 );
282 assert_eq!(
283 extract_terminal_pattern_display("PAGER='less -R' git log"),
284 Some("PAGER='less -R' git log".to_string())
285 );
286 }
287
288 #[test]
289 fn test_terminal_pattern_regex_normalizes_whitespace() {
290 let pattern = extract_terminal_pattern("PAGER=blah git log --oneline")
291 .expect("expected terminal pattern");
292 let regex = regex::Regex::new(&pattern).expect("expected valid regex");
293
294 assert!(regex.is_match("PAGER=blah git log"));
295 assert!(regex.is_match("PAGER=blah git log --stat"));
296 }
297
298 #[test]
299 fn test_extract_terminal_pattern_skips_redirects_before_subcommand() {
300 assert_eq!(
301 extract_terminal_pattern("git 2>/dev/null log --oneline"),
302 Some("^git\\s+log(\\s|$)".to_string())
303 );
304 assert_eq!(
305 extract_terminal_pattern_display("git 2>/dev/null log --oneline"),
306 Some("git 2>/dev/null log".to_string())
307 );
308
309 assert_eq!(
310 extract_terminal_pattern("rm --force foo"),
311 Some("^rm\\b".to_string())
312 );
313 }
314
315 #[test]
316 fn test_extract_all_terminal_patterns_pipeline() {
317 assert_eq!(
318 extract_all_terminal_patterns("cargo test 2>&1 | tail"),
319 vec![
320 PermissionPattern {
321 pattern: "^cargo\\s+test(\\s|$)".to_string(),
322 display_name: "cargo test".to_string(),
323 },
324 PermissionPattern {
325 pattern: "^tail\\b".to_string(),
326 display_name: "tail".to_string(),
327 },
328 ]
329 );
330 }
331
332 #[test]
333 fn test_extract_all_terminal_patterns_with_path_commands() {
334 assert_eq!(
335 extract_all_terminal_patterns("./script.sh | grep foo"),
336 vec![PermissionPattern {
337 pattern: "^grep\\s+foo(\\s|$)".to_string(),
338 display_name: "grep foo".to_string(),
339 }]
340 );
341 }
342
343 #[test]
344 fn test_extract_all_terminal_patterns_all_paths() {
345 assert_eq!(extract_all_terminal_patterns("./a.sh | /usr/bin/b"), vec![]);
346 }
347
348 #[test]
349 fn test_extract_path_pattern() {
350 assert_eq!(
351 extract_path_pattern("/Users/alice/project/src/main.rs"),
352 Some("^/Users/alice/project/src/".to_string())
353 );
354 assert_eq!(
355 extract_path_pattern("src/lib.rs"),
356 Some("^src/".to_string())
357 );
358 assert_eq!(extract_path_pattern("file.txt"), None);
359 assert_eq!(extract_path_pattern("/file.txt"), None);
360 }
361
362 #[test]
363 fn test_extract_path_pattern_display() {
364 assert_eq!(
365 extract_path_pattern_display("/Users/alice/project/src/main.rs"),
366 Some("/Users/alice/project/src/".to_string())
367 );
368 assert_eq!(
369 extract_path_pattern_display("src/lib.rs"),
370 Some("src/".to_string())
371 );
372 }
373
374 #[test]
375 fn test_extract_url_pattern() {
376 assert_eq!(
377 extract_url_pattern("https://github.com/user/repo"),
378 Some("^https?://github\\.com".to_string())
379 );
380 assert_eq!(
381 extract_url_pattern("http://example.com/path?query=1"),
382 Some("^https?://example\\.com".to_string())
383 );
384 assert_eq!(extract_url_pattern("not a url"), None);
385 }
386
387 #[test]
388 fn test_extract_url_pattern_display() {
389 assert_eq!(
390 extract_url_pattern_display("https://github.com/user/repo"),
391 Some("github.com".to_string())
392 );
393 assert_eq!(
394 extract_url_pattern_display("http://api.example.com/v1/users"),
395 Some("api.example.com".to_string())
396 );
397 }
398
399 #[test]
400 fn test_special_chars_are_escaped() {
401 assert_eq!(
402 extract_path_pattern("/path/with (parens)/file.txt"),
403 Some("^/path/with \\(parens\\)/".to_string())
404 );
405 assert_eq!(
406 extract_url_pattern("https://test.example.com/path"),
407 Some("^https?://test\\.example\\.com".to_string())
408 );
409 }
410
411 #[test]
412 fn test_extract_copy_move_pattern_same_directory() {
413 assert_eq!(
414 extract_copy_move_pattern(
415 "/Users/alice/project/src/old.rs\n/Users/alice/project/src/new.rs"
416 ),
417 Some("^/Users/alice/project/src/".to_string())
418 );
419 }
420
421 #[test]
422 fn test_extract_copy_move_pattern_sibling_directories() {
423 assert_eq!(
424 extract_copy_move_pattern(
425 "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
426 ),
427 Some("^/Users/alice/project/".to_string())
428 );
429 }
430
431 #[test]
432 fn test_extract_copy_move_pattern_no_common_prefix() {
433 assert_eq!(
434 extract_copy_move_pattern("/home/file.txt\n/tmp/file.txt"),
435 None
436 );
437 }
438
439 #[test]
440 fn test_extract_copy_move_pattern_relative_paths() {
441 assert_eq!(
442 extract_copy_move_pattern("src/old.rs\nsrc/new.rs"),
443 Some("^src/".to_string())
444 );
445 }
446
447 #[test]
448 fn test_extract_copy_move_pattern_display() {
449 assert_eq!(
450 extract_copy_move_pattern_display(
451 "/Users/alice/project/src/old.rs\n/Users/alice/project/dst/new.rs"
452 ),
453 Some("/Users/alice/project/".to_string())
454 );
455 }
456
457 #[test]
458 fn test_extract_copy_move_pattern_no_arrow() {
459 assert_eq!(extract_copy_move_pattern("just/a/path.rs"), None);
460 assert_eq!(extract_copy_move_pattern_display("just/a/path.rs"), None);
461 }
462}