1use crate::shell_parser::extract_commands;
2use url::Url;
3
4/// Extracts the command name from a shell command using the shell parser.
5///
6/// This parses the command properly to extract just the command name (first word),
7/// handling shell syntax correctly. Returns `None` if parsing fails or if the
8/// command name contains path separators (for security reasons).
9fn extract_command_name(command: &str) -> Option<String> {
10 let commands = extract_commands(command)?;
11 let first_command = commands.first()?;
12
13 let first_token = first_command.split_whitespace().next()?;
14
15 // Only allow alphanumeric commands with hyphens/underscores.
16 // Reject paths like "./script.sh" or "/usr/bin/python" to prevent
17 // users from accidentally allowing arbitrary script execution.
18 if first_token
19 .chars()
20 .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
21 {
22 Some(first_token.to_string())
23 } else {
24 None
25 }
26}
27
28/// Extracts a regex pattern from a terminal command based on the first token (command name).
29///
30/// Returns `None` for commands starting with `./`, `/`, or other path-like prefixes.
31/// This is a deliberate security decision: we only allow pattern-based "always allow"
32/// rules for well-known command names (like `cargo`, `npm`, `git`), not for arbitrary
33/// scripts or absolute paths which could be manipulated by an attacker.
34pub fn extract_terminal_pattern(command: &str) -> Option<String> {
35 let command_name = extract_command_name(command)?;
36 Some(format!("^{}\\b", regex::escape(&command_name)))
37}
38
39pub fn extract_terminal_pattern_display(command: &str) -> Option<String> {
40 extract_command_name(command)
41}
42
43pub fn extract_path_pattern(path: &str) -> Option<String> {
44 let parent = std::path::Path::new(path).parent()?;
45 let parent_str = parent.to_str()?;
46 if parent_str.is_empty() || parent_str == "/" {
47 return None;
48 }
49 Some(format!("^{}/", regex::escape(parent_str)))
50}
51
/// Returns the human-readable counterpart of the path pattern: the parent directory
/// of `path` followed by a trailing `/`, e.g. `src/` for `src/lib.rs`.
///
/// Returns `None` when `path` has no parent, or when the parent is empty (a bare
/// file name) or the root `/`.
pub fn extract_path_pattern_display(path: &str) -> Option<String> {
    let dir = std::path::Path::new(path)
        .parent()
        .and_then(|parent| parent.to_str())?;

    match dir {
        // Bare file names and files directly under `/` have nothing useful to show.
        "" | "/" => None,
        _ => Some(format!("{}/", dir)),
    }
}
60
61pub fn extract_url_pattern(url: &str) -> Option<String> {
62 let parsed = Url::parse(url).ok()?;
63 let domain = parsed.host_str()?;
64 Some(format!("^https?://{}", regex::escape(domain)))
65}
66
67pub fn extract_url_pattern_display(url: &str) -> Option<String> {
68 let parsed = Url::parse(url).ok()?;
69 let domain = parsed.host_str()?;
70 Some(domain.to_string())
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `extract` over each `(input, expected)` pair and asserts the result matches.
    fn check(extract: fn(&str) -> Option<String>, cases: &[(&str, Option<&str>)]) {
        for &(input, expected) in cases {
            assert_eq!(
                extract(input),
                expected.map(|s| s.to_string()),
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn test_extract_terminal_pattern() {
        check(
            extract_terminal_pattern,
            &[
                ("cargo build --release", Some("^cargo\\b")),
                ("npm install", Some("^npm\\b")),
                ("git-lfs pull", Some("^git\\-lfs\\b")),
                ("my_script arg", Some("^my_script\\b")),
                // Path-like commands must never produce a pattern.
                ("./script.sh arg", None),
                ("/usr/bin/python arg", None),
            ],
        );
    }

    #[test]
    fn test_extract_terminal_pattern_display() {
        check(
            extract_terminal_pattern_display,
            &[
                ("cargo build --release", Some("cargo")),
                ("npm install", Some("npm")),
            ],
        );
    }

    #[test]
    fn test_extract_path_pattern() {
        check(
            extract_path_pattern,
            &[
                (
                    "/Users/alice/project/src/main.rs",
                    Some("^/Users/alice/project/src/"),
                ),
                ("src/lib.rs", Some("^src/")),
                // No usable parent directory: bare file name, or file directly under root.
                ("file.txt", None),
                ("/file.txt", None),
            ],
        );
    }

    #[test]
    fn test_extract_path_pattern_display() {
        check(
            extract_path_pattern_display,
            &[
                (
                    "/Users/alice/project/src/main.rs",
                    Some("/Users/alice/project/src/"),
                ),
                ("src/lib.rs", Some("src/")),
            ],
        );
    }

    #[test]
    fn test_extract_url_pattern() {
        check(
            extract_url_pattern,
            &[
                (
                    "https://github.com/user/repo",
                    Some("^https?://github\\.com"),
                ),
                (
                    "http://example.com/path?query=1",
                    Some("^https?://example\\.com"),
                ),
                ("not a url", None),
            ],
        );
    }

    #[test]
    fn test_extract_url_pattern_display() {
        check(
            extract_url_pattern_display,
            &[
                ("https://github.com/user/repo", Some("github.com")),
                ("http://api.example.com/v1/users", Some("api.example.com")),
            ],
        );
    }

    #[test]
    fn test_special_chars_are_escaped() {
        check(
            extract_path_pattern,
            &[(
                "/path/with (parens)/file.txt",
                Some("^/path/with \\(parens\\)/"),
            )],
        );
        check(
            extract_url_pattern,
            &[(
                "https://test.example.com/path",
                Some("^https?://test\\.example\\.com"),
            )],
        );
    }
}