1mod agent_profile;
2
3use std::path::{Component, Path};
4use std::sync::{Arc, LazyLock};
5
6use agent_client_protocol::ModelId;
7use collections::{HashSet, IndexMap};
8use gpui::{App, Pixels, px};
9use language_model::LanguageModel;
10use project::DisableAiSettings;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use settings::{
14 DefaultAgentView, DockPosition, LanguageModelParameters, LanguageModelSelection,
15 NewThreadLocation, NotifyWhenAgentWaiting, RegisterSetting, Settings, SidebarDockPosition,
16 SidebarSide, ToolPermissionMode,
17};
18
19pub use crate::agent_profile::*;
20
/// Prompt text used to summarize a thread, embedded at compile time from
/// `prompts/summarize_thread_prompt.txt`.
pub const SUMMARIZE_THREAD_PROMPT: &str = include_str!("prompts/summarize_thread_prompt.txt");
/// Prompt text used to produce a detailed thread summary, embedded at compile
/// time from `prompts/summarize_thread_detailed_prompt.txt`.
pub const SUMMARIZE_THREAD_DETAILED_PROMPT: &str =
    include_str!("prompts/summarize_thread_detailed_prompt.txt");
24
/// Resolved agent settings, built from the merged settings content in
/// `Settings::from_settings` for this type.
#[derive(Clone, Debug, RegisterSetting)]
pub struct AgentSettings {
    /// Agent enabled flag from settings. Use [`AgentSettings::enabled`] to also
    /// honor the global `DisableAiSettings::disable_ai` switch.
    pub enabled: bool,
    pub button: bool,
    pub dock: DockPosition,
    /// Configured sidebar position; resolve it with [`AgentSettings::sidebar_side`].
    pub sidebar_side: SidebarDockPosition,
    pub default_width: Pixels,
    pub default_height: Pixels,
    pub default_model: Option<LanguageModelSelection>,
    pub inline_assistant_model: Option<LanguageModelSelection>,
    /// Defaults to `true` when not present in the settings content.
    pub inline_assistant_use_streaming_tools: bool,
    pub commit_message_model: Option<LanguageModelSelection>,
    pub thread_summary_model: Option<LanguageModelSelection>,
    pub inline_alternatives: Vec<LanguageModelSelection>,
    /// See [`AgentSettings::favorite_model_ids`] for the derived id set.
    pub favorite_models: Vec<LanguageModelSelection>,
    pub default_profile: AgentProfileId,
    pub default_view: DefaultAgentView,
    pub profiles: IndexMap<AgentProfileId, AgentProfileSettings>,

    pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
    pub play_sound_when_agent_done: bool,
    pub single_file_review: bool,
    /// Per-provider/per-model parameter overrides; consulted by
    /// [`AgentSettings::temperature_for_model`], where later entries win.
    pub model_parameters: Vec<LanguageModelParameters>,
    pub enable_feedback: bool,
    pub expand_edit_card: bool,
    pub expand_terminal_card: bool,
    pub cancel_generation_on_terminal_stop: bool,
    pub use_modifier_to_send: bool,
    /// Minimum height of the message editor in lines; the maximum is derived
    /// from it by `set_message_editor_max_lines`.
    pub message_editor_min_lines: usize,
    pub show_turn_stats: bool,
    /// Tool permission rules with their regex patterns already compiled.
    pub tool_permissions: ToolPermissions,
    pub new_thread_location: NewThreadLocation,
}
58
59impl AgentSettings {
60 pub fn enabled(&self, cx: &App) -> bool {
61 self.enabled && !DisableAiSettings::get_global(cx).disable_ai
62 }
63
64 pub fn temperature_for_model(model: &Arc<dyn LanguageModel>, cx: &App) -> Option<f32> {
65 let settings = Self::get_global(cx);
66 for setting in settings.model_parameters.iter().rev() {
67 if let Some(provider) = &setting.provider
68 && provider.0 != model.provider_id().0
69 {
70 continue;
71 }
72 if let Some(setting_model) = &setting.model
73 && *setting_model != model.id().0
74 {
75 continue;
76 }
77 return setting.temperature;
78 }
79 return None;
80 }
81
82 pub fn sidebar_side(&self) -> SidebarSide {
83 match self.sidebar_side {
84 SidebarDockPosition::Left => SidebarSide::Left,
85 SidebarDockPosition::Right => SidebarSide::Right,
86 SidebarDockPosition::FollowAgent => match self.dock {
87 DockPosition::Right => SidebarSide::Right,
88 _ => SidebarSide::Left,
89 },
90 }
91 }
92
93 pub fn set_message_editor_max_lines(&self) -> usize {
94 self.message_editor_min_lines * 2
95 }
96
97 pub fn favorite_model_ids(&self) -> HashSet<ModelId> {
98 self.favorite_models
99 .iter()
100 .map(|sel| ModelId::new(format!("{}/{}", sel.provider.0, sel.model)))
101 .collect()
102 }
103}
104
105#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize, JsonSchema)]
106pub struct AgentProfileId(pub Arc<str>);
107
108impl AgentProfileId {
109 pub fn as_str(&self) -> &str {
110 &self.0
111 }
112}
113
114impl std::fmt::Display for AgentProfileId {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 write!(f, "{}", self.0)
117 }
118}
119
120impl Default for AgentProfileId {
121 fn default() -> Self {
122 Self("write".into())
123 }
124}
125
126#[derive(Clone, Debug, Default)]
127pub struct ToolPermissions {
128 /// Global default permission when no tool-specific rules or patterns match.
129 pub default: ToolPermissionMode,
130 pub tools: collections::HashMap<Arc<str>, ToolRules>,
131}
132
133impl ToolPermissions {
134 /// Returns all invalid regex patterns across all tools.
135 pub fn invalid_patterns(&self) -> Vec<&InvalidRegexPattern> {
136 self.tools
137 .values()
138 .flat_map(|rules| rules.invalid_patterns.iter())
139 .collect()
140 }
141
142 /// Returns true if any tool has invalid regex patterns.
143 pub fn has_invalid_patterns(&self) -> bool {
144 self.tools
145 .values()
146 .any(|rules| !rules.invalid_patterns.is_empty())
147 }
148}
149
/// Represents a regex pattern that failed to compile.
///
/// Produced by `compile_regex_rules` both for patterns the regex engine
/// rejects and for empty patterns, which are refused outright.
#[derive(Clone, Debug)]
pub struct InvalidRegexPattern {
    /// The pattern string that failed to compile.
    pub pattern: String,
    /// Which rule list this pattern was in (e.g., "always_deny", "always_allow", "always_confirm").
    pub rule_type: String,
    /// The error message from the regex compiler, or a fixed message for empty patterns.
    pub error: String,
}
160
/// Compiled permission rules for a single tool.
#[derive(Clone, Debug, Default)]
pub struct ToolRules {
    /// Tool-specific default; `None` means fall back to the global
    /// `ToolPermissions::default` at decision time.
    pub default: Option<ToolPermissionMode>,
    /// Compiled patterns from the tool's `always_allow` list.
    pub always_allow: Vec<CompiledRegex>,
    /// Compiled patterns from the tool's `always_deny` list.
    pub always_deny: Vec<CompiledRegex>,
    /// Compiled patterns from the tool's `always_confirm` list.
    pub always_confirm: Vec<CompiledRegex>,
    /// Patterns that failed to compile. If non-empty, tool calls should be blocked.
    pub invalid_patterns: Vec<InvalidRegexPattern>,
}
170
171#[derive(Clone)]
172pub struct CompiledRegex {
173 pub pattern: String,
174 pub case_sensitive: bool,
175 pub regex: regex::Regex,
176}
177
178impl std::fmt::Debug for CompiledRegex {
179 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180 f.debug_struct("CompiledRegex")
181 .field("pattern", &self.pattern)
182 .field("case_sensitive", &self.case_sensitive)
183 .finish()
184 }
185}
186
187impl CompiledRegex {
188 pub fn new(pattern: &str, case_sensitive: bool) -> Option<Self> {
189 Self::try_new(pattern, case_sensitive).ok()
190 }
191
192 pub fn try_new(pattern: &str, case_sensitive: bool) -> Result<Self, regex::Error> {
193 let regex = regex::RegexBuilder::new(pattern)
194 .case_insensitive(!case_sensitive)
195 .build()?;
196 Ok(Self {
197 pattern: pattern.to_string(),
198 case_sensitive,
199 regex,
200 })
201 }
202
203 pub fn is_match(&self, input: &str) -> bool {
204 self.regex.is_match(input)
205 }
206}
207
/// Message returned by `check_hardcoded_security_rules` when a command is
/// rejected by one of the built-in rules.
pub const HARDCODED_SECURITY_DENIAL_MESSAGE: &str = "Blocked by built-in security rule. This operation is considered too \
    harmful to be allowed, and cannot be overridden by settings.";
210
/// Security rules that are always enforced and cannot be overridden by any setting.
/// These protect against catastrophic operations like wiping filesystems.
pub struct HardcodedSecurityRules {
    /// Patterns checked against terminal commands; in `HARDCODED_SECURITY_RULES`
    /// they are compiled case-insensitively.
    pub terminal_deny: Vec<CompiledRegex>,
}
216
/// The built-in rules, compiled on first use.
///
/// Every pattern starts with `\brm` (a word boundary, not `^`), so `rm` may be
/// preceded by other text such as `sudo `. Every pattern ends with `$`, so the
/// dangerous path must be the last non-flag argument. The `false` argument
/// compiles each pattern case-insensitively.
pub static HARDCODED_SECURITY_RULES: LazyLock<HardcodedSecurityRules> = LazyLock::new(|| {
    // Zero or more options before the path: long options (`--name` or
    // `--name=value`) or short option clusters (`-rf`), each followed by whitespace.
    const FLAGS: &str = r"(--[a-zA-Z0-9][-a-zA-Z0-9_]*(=[^\s]*)?\s+|-[a-zA-Z]+\s+)*";
    // Zero or more options after the path, then optional trailing whitespace.
    const TRAILING_FLAGS: &str = r"(\s+--[a-zA-Z0-9][-a-zA-Z0-9_]*(=[^\s]*)?|\s+-[a-zA-Z]+)*\s*";

    HardcodedSecurityRules {
        terminal_deny: vec![
            // Recursive deletion of root - "rm -rf /", "rm -rf /*"
            // `(--\s+)?` allows an explicit end-of-options marker before the path.
            CompiledRegex::new(
                &format!(r"\brm\s+{FLAGS}(--\s+)?/\*?{TRAILING_FLAGS}$"),
                false,
            )
            .expect("hardcoded regex should compile"),
            // Recursive deletion of home via tilde - "rm -rf ~", "rm -rf ~/"
            CompiledRegex::new(
                &format!(r"\brm\s+{FLAGS}(--\s+)?~/?\*?{TRAILING_FLAGS}$"),
                false,
            )
            .expect("hardcoded regex should compile"),
            // Recursive deletion of home via env var - "rm -rf $HOME", "rm -rf ${HOME}"
            // (`{{`/`}}` are format! escapes for literal braces.)
            CompiledRegex::new(
                &format!(r"\brm\s+{FLAGS}(--\s+)?(\$HOME|\$\{{HOME\}})/?(\*)?{TRAILING_FLAGS}$"),
                false,
            )
            .expect("hardcoded regex should compile"),
            // Recursive deletion of current directory - "rm -rf .", "rm -rf ./"
            CompiledRegex::new(
                &format!(r"\brm\s+{FLAGS}(--\s+)?\./?\*?{TRAILING_FLAGS}$"),
                false,
            )
            .expect("hardcoded regex should compile"),
            // Recursive deletion of parent directory - "rm -rf ..", "rm -rf ../"
            CompiledRegex::new(
                &format!(r"\brm\s+{FLAGS}(--\s+)?\.\./?\*?{TRAILING_FLAGS}$"),
                false,
            )
            .expect("hardcoded regex should compile"),
        ],
    }
});
256
257/// Checks if input matches any hardcoded security rules that cannot be bypassed.
258/// Returns the denial reason string if blocked, None otherwise.
259///
260/// `terminal_tool_name` should be the tool name used for the terminal tool
261/// (e.g. `"terminal"`). `extracted_commands` can optionally provide parsed
262/// sub-commands for chained command checking; callers with access to a shell
263/// parser should extract sub-commands and pass them here.
264pub fn check_hardcoded_security_rules(
265 tool_name: &str,
266 terminal_tool_name: &str,
267 input: &str,
268 extracted_commands: Option<&[String]>,
269) -> Option<String> {
270 if tool_name != terminal_tool_name {
271 return None;
272 }
273
274 let rules = &*HARDCODED_SECURITY_RULES;
275 let terminal_patterns = &rules.terminal_deny;
276
277 if matches_hardcoded_patterns(input, terminal_patterns) {
278 return Some(HARDCODED_SECURITY_DENIAL_MESSAGE.into());
279 }
280
281 if let Some(commands) = extracted_commands {
282 for command in commands {
283 if matches_hardcoded_patterns(command, terminal_patterns) {
284 return Some(HARDCODED_SECURITY_DENIAL_MESSAGE.into());
285 }
286 }
287 }
288
289 None
290}
291
292fn matches_hardcoded_patterns(command: &str, patterns: &[CompiledRegex]) -> bool {
293 for pattern in patterns {
294 if pattern.is_match(command) {
295 return true;
296 }
297 }
298
299 for expanded in expand_rm_to_single_path_commands(command) {
300 for pattern in patterns {
301 if pattern.is_match(&expanded) {
302 return true;
303 }
304 }
305 }
306
307 false
308}
309
/// Rewrites an `rm` command into one `rm <flags> <path>` command per path
/// argument, with each path lexically normalized. This lets the single-path
/// hardcoded patterns catch forms such as `rm -rf /tmp/..` or `rm -rf a /`.
///
/// Returns an empty vector when the first whitespace-separated token is not
/// `rm` (compared ASCII-case-insensitively). The rebuilt commands always start
/// with lowercase `rm`, and all flags are placed before the path.
///
/// NOTE(review): tokens are split on whitespace only. Quotes and other shell
/// syntax are not interpreted, and commands whose first token is a wrapper or a
/// full path to the binary (e.g. `sudo rm …`, `/bin/rm …`) are not expanded
/// here, so only the direct regex match protects those forms.
fn expand_rm_to_single_path_commands(command: &str) -> Vec<String> {
    let trimmed = command.trim();

    let first_token = trimmed.split_whitespace().next();
    if !first_token.is_some_and(|t| t.eq_ignore_ascii_case("rm")) {
        return vec![];
    }

    let parts: Vec<&str> = trimmed.split_whitespace().collect();
    let mut flags = Vec::new();
    let mut paths = Vec::new();
    let mut past_double_dash = false;

    for part in parts.iter().skip(1) {
        // `--` ends option parsing: it is kept with the flags, and every later
        // token is treated as a path even if it starts with `-`.
        if !past_double_dash && *part == "--" {
            past_double_dash = true;
            flags.push(*part);
            continue;
        }
        if !past_double_dash && part.starts_with('-') {
            flags.push(*part);
        } else {
            paths.push(*part);
        }
    }

    let flags_str = if flags.is_empty() {
        String::new()
    } else {
        format!("{} ", flags.join(" "))
    };

    let mut results = Vec::new();
    for path in &paths {
        if path.starts_with('$') {
            // Only `$HOME` / `${HOME}` get special handling; other variables
            // are passed through unchanged.
            let home_prefix = if path.starts_with("${HOME}") {
                Some("${HOME}")
            } else if path.starts_with("$HOME") {
                Some("$HOME")
            } else {
                None
            };

            if let Some(prefix) = home_prefix {
                let suffix = &path[prefix.len()..];
                if suffix.is_empty() {
                    results.push(format!("rm {flags_str}{path}"));
                } else if suffix.starts_with('/') {
                    // Normalize only the part after the variable, so e.g.
                    // `$HOME/foo/..` collapses back to plain `$HOME`.
                    let normalized_suffix = normalize_path(suffix);
                    let reconstructed = if normalized_suffix == "/" {
                        prefix.to_string()
                    } else {
                        format!("{prefix}{normalized_suffix}")
                    };
                    results.push(format!("rm {flags_str}{reconstructed}"));
                } else {
                    // e.g. `$HOMEDIR`: a different variable that merely shares the prefix.
                    results.push(format!("rm {flags_str}{path}"));
                }
            } else {
                results.push(format!("rm {flags_str}{path}"));
            }
            continue;
        }

        let mut normalized = normalize_path(path);
        // A relative path that normalizes to nothing (e.g. `./`, `foo/..`)
        // refers to the current directory.
        if normalized.is_empty() && !Path::new(path).has_root() {
            normalized = ".".to_string();
        }

        results.push(format!("rm {flags_str}{normalized}"));
    }

    results
}
384
/// Lexically normalizes `raw` without touching the filesystem.
///
/// `.` segments and repeated separators are dropped. Each `..` cancels the
/// preceding normal segment. A leading `..` is kept for relative paths and
/// discarded at the root of absolute paths. Segments are re-joined with `/`,
/// and absolute paths get a single leading `/`. Windows path prefixes are
/// dropped.
///
/// ```text
/// "/tmp/../.."  -> "/"
/// "a/./b/../c"  -> "a/c"
/// "a/../.."     -> ".."
/// "./"          -> ""
/// ```
pub fn normalize_path(raw: &str) -> String {
    let path = Path::new(raw);
    let rooted = path.has_root();
    let mut stack: Vec<&str> = Vec::new();

    for part in path.components() {
        match part {
            Component::Normal(name) => {
                if let Some(name) = name.to_str() {
                    stack.push(name);
                }
            }
            Component::ParentDir => match stack.last().copied() {
                // Already climbing above the start of a relative path.
                Some("..") => stack.push(".."),
                Some(_) => {
                    stack.pop();
                }
                None if !rooted => stack.push(".."),
                // `..` at the root stays at the root.
                None => {}
            },
            Component::CurDir | Component::RootDir | Component::Prefix(_) => {}
        }
    }

    let body = stack.join("/");
    if rooted { format!("/{body}") } else { body }
}
415
impl Settings for AgentSettings {
    /// Builds `AgentSettings` from the merged settings content.
    ///
    /// Most fields are `unwrap`ped, so a missing value panics here.
    /// NOTE(review): this presumably relies on the default settings file
    /// supplying every unwrapped field. Confirm against `default.json` when
    /// adding new fields.
    fn from_settings(content: &settings::SettingsContent) -> Self {
        let agent = content.agent.clone().unwrap();
        Self {
            enabled: agent.enabled.unwrap(),
            button: agent.button.unwrap(),
            dock: agent.dock.unwrap(),
            sidebar_side: agent.sidebar_side.unwrap(),
            default_width: px(agent.default_width.unwrap()),
            default_height: px(agent.default_height.unwrap()),
            default_model: Some(agent.default_model.unwrap()),
            inline_assistant_model: agent.inline_assistant_model,
            inline_assistant_use_streaming_tools: agent
                .inline_assistant_use_streaming_tools
                .unwrap_or(true),
            commit_message_model: agent.commit_message_model,
            thread_summary_model: agent.thread_summary_model,
            inline_alternatives: agent.inline_alternatives.unwrap_or_default(),
            favorite_models: agent.favorite_models,
            default_profile: AgentProfileId(agent.default_profile.unwrap()),
            default_view: agent.default_view.unwrap(),
            profiles: agent
                .profiles
                .unwrap()
                .into_iter()
                .map(|(key, val)| (AgentProfileId(key), val.into()))
                .collect(),

            notify_when_agent_waiting: agent.notify_when_agent_waiting.unwrap(),
            play_sound_when_agent_done: agent.play_sound_when_agent_done.unwrap(),
            single_file_review: agent.single_file_review.unwrap(),
            model_parameters: agent.model_parameters,
            enable_feedback: agent.enable_feedback.unwrap(),
            expand_edit_card: agent.expand_edit_card.unwrap(),
            expand_terminal_card: agent.expand_terminal_card.unwrap(),
            cancel_generation_on_terminal_stop: agent.cancel_generation_on_terminal_stop.unwrap(),
            use_modifier_to_send: agent.use_modifier_to_send.unwrap(),
            message_editor_min_lines: agent.message_editor_min_lines.unwrap(),
            show_turn_stats: agent.show_turn_stats.unwrap(),
            // Regex patterns are compiled once here; invalid ones are recorded, not fatal.
            tool_permissions: compile_tool_permissions(agent.tool_permissions),
            new_thread_location: agent.new_thread_location.unwrap_or_default(),
        }
    }
}
460
461fn compile_tool_permissions(content: Option<settings::ToolPermissionsContent>) -> ToolPermissions {
462 let Some(content) = content else {
463 return ToolPermissions::default();
464 };
465
466 let tools = content
467 .tools
468 .into_iter()
469 .map(|(tool_name, rules_content)| {
470 let mut invalid_patterns = Vec::new();
471
472 let (always_allow, allow_errors) = compile_regex_rules(
473 rules_content.always_allow.map(|v| v.0).unwrap_or_default(),
474 "always_allow",
475 );
476 invalid_patterns.extend(allow_errors);
477
478 let (always_deny, deny_errors) = compile_regex_rules(
479 rules_content.always_deny.map(|v| v.0).unwrap_or_default(),
480 "always_deny",
481 );
482 invalid_patterns.extend(deny_errors);
483
484 let (always_confirm, confirm_errors) = compile_regex_rules(
485 rules_content
486 .always_confirm
487 .map(|v| v.0)
488 .unwrap_or_default(),
489 "always_confirm",
490 );
491 invalid_patterns.extend(confirm_errors);
492
493 // Log invalid patterns for debugging. Users will see an error when they
494 // attempt to use a tool with invalid patterns in their settings.
495 for invalid in &invalid_patterns {
496 log::error!(
497 "Invalid regex pattern in tool_permissions for '{}' tool ({}): '{}' - {}",
498 tool_name,
499 invalid.rule_type,
500 invalid.pattern,
501 invalid.error,
502 );
503 }
504
505 let rules = ToolRules {
506 // Preserve tool-specific default; None means fall back to global default at decision time
507 default: rules_content.default,
508 always_allow,
509 always_deny,
510 always_confirm,
511 invalid_patterns,
512 };
513 (tool_name, rules)
514 })
515 .collect();
516
517 ToolPermissions {
518 default: content.default.unwrap_or_default(),
519 tools,
520 }
521}
522
523fn compile_regex_rules(
524 rules: Vec<settings::ToolRegexRule>,
525 rule_type: &str,
526) -> (Vec<CompiledRegex>, Vec<InvalidRegexPattern>) {
527 let mut compiled = Vec::new();
528 let mut errors = Vec::new();
529
530 for rule in rules {
531 if rule.pattern.is_empty() {
532 errors.push(InvalidRegexPattern {
533 pattern: rule.pattern,
534 rule_type: rule_type.to_string(),
535 error: "empty regex patterns are not allowed".to_string(),
536 });
537 continue;
538 }
539 let case_sensitive = rule.case_sensitive.unwrap_or(false);
540 match CompiledRegex::try_new(&rule.pattern, case_sensitive) {
541 Ok(regex) => compiled.push(regex),
542 Err(error) => {
543 errors.push(InvalidRegexPattern {
544 pattern: rule.pattern,
545 rule_type: rule_type.to_string(),
546 error: error.to_string(),
547 });
548 }
549 }
550 }
551
552 (compiled, errors)
553}
554
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use settings::ToolPermissionMode;
    use settings::ToolPermissionsContent;

    #[test]
    fn test_compiled_regex_case_insensitive() {
        let regex = CompiledRegex::new("rm\\s+-rf", false).unwrap();
        assert!(regex.is_match("rm -rf /"));
        assert!(regex.is_match("RM -RF /"));
        assert!(regex.is_match("Rm -Rf /"));
    }

    #[test]
    fn test_compiled_regex_case_sensitive() {
        let regex = CompiledRegex::new("DROP\\s+TABLE", true).unwrap();
        assert!(regex.is_match("DROP TABLE users"));
        assert!(!regex.is_match("drop table users"));
    }

    #[test]
    fn test_invalid_regex_returns_none() {
        let result = CompiledRegex::new("[invalid(regex", false);
        assert!(result.is_none());
    }

    #[test]
    fn test_tool_permissions_parsing() {
        let json = json!({
            "tools": {
                "terminal": {
                    "default": "allow",
                    "always_deny": [
                        { "pattern": "rm\\s+-rf" }
                    ],
                    "always_allow": [
                        { "pattern": "^git\\s" }
                    ]
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));

        let terminal_rules = permissions.tools.get("terminal").unwrap();
        assert_eq!(terminal_rules.default, Some(ToolPermissionMode::Allow));
        assert_eq!(terminal_rules.always_deny.len(), 1);
        assert_eq!(terminal_rules.always_allow.len(), 1);
        assert!(terminal_rules.always_deny[0].is_match("rm -rf /"));
        assert!(terminal_rules.always_allow[0].is_match("git status"));
    }

    #[test]
    fn test_tool_rules_default() {
        let json = json!({
            "tools": {
                "edit_file": {
                    "default": "deny"
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));

        let rules = permissions.tools.get("edit_file").unwrap();
        assert_eq!(rules.default, Some(ToolPermissionMode::Deny));
    }

    #[test]
    fn test_tool_permissions_empty() {
        let permissions = compile_tool_permissions(None);
        assert!(permissions.tools.is_empty());
        assert_eq!(permissions.default, ToolPermissionMode::Confirm);
    }

    #[test]
    fn test_tool_rules_default_returns_confirm() {
        let default_rules = ToolRules::default();
        assert_eq!(default_rules.default, None);
        assert!(default_rules.always_allow.is_empty());
        assert!(default_rules.always_deny.is_empty());
        assert!(default_rules.always_confirm.is_empty());
    }

    #[test]
    fn test_tool_permissions_with_multiple_tools() {
        let json = json!({
            "tools": {
                "terminal": {
                    "default": "allow",
                    "always_deny": [{ "pattern": "rm\\s+-rf" }]
                },
                "edit_file": {
                    "default": "confirm",
                    "always_deny": [{ "pattern": "\\.env$" }]
                },
                "delete_path": {
                    "default": "deny"
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));

        assert_eq!(permissions.tools.len(), 3);

        let terminal = permissions.tools.get("terminal").unwrap();
        assert_eq!(terminal.default, Some(ToolPermissionMode::Allow));
        assert_eq!(terminal.always_deny.len(), 1);

        let edit_file = permissions.tools.get("edit_file").unwrap();
        assert_eq!(edit_file.default, Some(ToolPermissionMode::Confirm));
        assert!(edit_file.always_deny[0].is_match("secrets.env"));

        let delete_path = permissions.tools.get("delete_path").unwrap();
        assert_eq!(delete_path.default, Some(ToolPermissionMode::Deny));
    }

    #[test]
    fn test_tool_permissions_with_all_rule_types() {
        let json = json!({
            "tools": {
                "terminal": {
                    "always_deny": [{ "pattern": "rm\\s+-rf" }],
                    "always_confirm": [{ "pattern": "sudo\\s" }],
                    "always_allow": [{ "pattern": "^git\\s+status" }]
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));

        let terminal = permissions.tools.get("terminal").unwrap();
        assert_eq!(terminal.always_deny.len(), 1);
        assert_eq!(terminal.always_confirm.len(), 1);
        assert_eq!(terminal.always_allow.len(), 1);

        assert!(terminal.always_deny[0].is_match("rm -rf /"));
        assert!(terminal.always_confirm[0].is_match("sudo apt install"));
        assert!(terminal.always_allow[0].is_match("git status"));
    }

    #[test]
    fn test_invalid_regex_is_tracked_and_valid_ones_still_compile() {
        let json = json!({
            "tools": {
                "terminal": {
                    "always_deny": [
                        { "pattern": "[invalid(regex" },
                        { "pattern": "valid_pattern" }
                    ],
                    "always_allow": [
                        { "pattern": "[another_bad" }
                    ]
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));

        let terminal = permissions.tools.get("terminal").unwrap();

        // Valid patterns should still be compiled
        assert_eq!(terminal.always_deny.len(), 1);
        assert!(terminal.always_deny[0].is_match("valid_pattern"));

        // Invalid patterns should be tracked (order depends on processing order)
        assert_eq!(terminal.invalid_patterns.len(), 2);

        let deny_invalid = terminal
            .invalid_patterns
            .iter()
            .find(|p| p.rule_type == "always_deny")
            .expect("should have invalid pattern from always_deny");
        assert_eq!(deny_invalid.pattern, "[invalid(regex");
        assert!(!deny_invalid.error.is_empty());

        let allow_invalid = terminal
            .invalid_patterns
            .iter()
            .find(|p| p.rule_type == "always_allow")
            .expect("should have invalid pattern from always_allow");
        assert_eq!(allow_invalid.pattern, "[another_bad");

        // ToolPermissions helper methods should work
        assert!(permissions.has_invalid_patterns());
        assert_eq!(permissions.invalid_patterns().len(), 2);
    }

    #[test]
    fn test_deny_takes_precedence_over_allow_and_confirm() {
        let json = json!({
            "tools": {
                "terminal": {
                    "default": "allow",
                    "always_deny": [{ "pattern": "dangerous" }],
                    "always_confirm": [{ "pattern": "dangerous" }],
                    "always_allow": [{ "pattern": "dangerous" }]
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));
        let terminal = permissions.tools.get("terminal").unwrap();

        assert!(
            terminal.always_deny[0].is_match("run dangerous command"),
            "Deny rule should match"
        );
        assert!(
            terminal.always_allow[0].is_match("run dangerous command"),
            "Allow rule should also match (but deny takes precedence at evaluation time)"
        );
        assert!(
            terminal.always_confirm[0].is_match("run dangerous command"),
            "Confirm rule should also match (but deny takes precedence at evaluation time)"
        );
    }

    #[test]
    fn test_confirm_takes_precedence_over_allow() {
        let json = json!({
            "tools": {
                "terminal": {
                    "default": "allow",
                    "always_confirm": [{ "pattern": "risky" }],
                    "always_allow": [{ "pattern": "risky" }]
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));
        let terminal = permissions.tools.get("terminal").unwrap();

        assert!(
            terminal.always_confirm[0].is_match("do risky thing"),
            "Confirm rule should match"
        );
        assert!(
            terminal.always_allow[0].is_match("do risky thing"),
            "Allow rule should also match (but confirm takes precedence at evaluation time)"
        );
    }

    #[test]
    fn test_regex_matches_anywhere_in_string_not_just_anchored() {
        let json = json!({
            "tools": {
                "terminal": {
                    "always_deny": [
                        { "pattern": "rm\\s+-rf" },
                        { "pattern": "/etc/passwd" }
                    ]
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));
        let terminal = permissions.tools.get("terminal").unwrap();

        assert!(
            terminal.always_deny[0].is_match("echo hello && rm -rf /"),
            "Should match rm -rf in the middle of a command chain"
        );
        assert!(
            terminal.always_deny[0].is_match("cd /tmp; rm -rf *"),
            "Should match rm -rf after semicolon"
        );
        assert!(
            terminal.always_deny[1].is_match("cat /etc/passwd | grep root"),
            "Should match /etc/passwd in a pipeline"
        );
        assert!(
            terminal.always_deny[1].is_match("vim /etc/passwd"),
            "Should match /etc/passwd as argument"
        );
    }

    #[test]
    fn test_fork_bomb_pattern_matches() {
        let fork_bomb_regex = CompiledRegex::new(r":\(\)\{\s*:\|:&\s*\};:", false).unwrap();
        assert!(
            fork_bomb_regex.is_match(":(){ :|:& };:"),
            "Should match the classic fork bomb"
        );
        assert!(
            fork_bomb_regex.is_match(":(){ :|:&};:"),
            "Should match fork bomb without spaces"
        );
    }

    #[test]
    fn test_compiled_regex_stores_case_sensitivity() {
        let case_sensitive = CompiledRegex::new("test", true).unwrap();
        let case_insensitive = CompiledRegex::new("test", false).unwrap();

        assert!(case_sensitive.case_sensitive);
        assert!(!case_insensitive.case_sensitive);
    }

    #[test]
    fn test_invalid_regex_is_skipped_not_fail() {
        let json = json!({
            "tools": {
                "terminal": {
                    "always_deny": [
                        { "pattern": "[invalid(regex" },
                        { "pattern": "valid_pattern" }
                    ]
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));

        let terminal = permissions.tools.get("terminal").unwrap();
        assert_eq!(terminal.always_deny.len(), 1);
        assert!(terminal.always_deny[0].is_match("valid_pattern"));
    }

    #[test]
    fn test_unconfigured_tool_not_in_permissions() {
        let json = json!({
            "tools": {
                "terminal": {
                    "default": "allow"
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));

        assert!(permissions.tools.contains_key("terminal"));
        assert!(!permissions.tools.contains_key("edit_file"));
        assert!(!permissions.tools.contains_key("fetch"));
    }

    #[test]
    fn test_always_allow_pattern_only_matches_specified_commands() {
        // Reproduces user-reported bug: when always_allow has pattern "^echo\s",
        // only "echo hello" should be allowed, not "git status".
        //
        // User config:
        //   always_allow_tool_actions: false
        //   tool_permissions.tools.terminal.always_allow: [{ pattern: "^echo\\s" }]
        let json = json!({
            "tools": {
                "terminal": {
                    "always_allow": [
                        { "pattern": "^echo\\s" }
                    ]
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));

        let terminal = permissions.tools.get("terminal").unwrap();

        // Verify the pattern was compiled
        assert_eq!(
            terminal.always_allow.len(),
            1,
            "Should have one always_allow pattern"
        );

        // Verify the pattern matches "echo hello"
        assert!(
            terminal.always_allow[0].is_match("echo hello"),
            "Pattern ^echo\\s should match 'echo hello'"
        );

        // Verify the pattern does NOT match "git status"
        assert!(
            !terminal.always_allow[0].is_match("git status"),
            "Pattern ^echo\\s should NOT match 'git status'"
        );

        // Verify the pattern does NOT match "echoHello" (no space)
        assert!(
            !terminal.always_allow[0].is_match("echoHello"),
            "Pattern ^echo\\s should NOT match 'echoHello' (requires whitespace)"
        );

        assert_eq!(
            terminal.default, None,
            "default should be None when not specified"
        );
    }

    #[test]
    fn test_empty_regex_pattern_is_invalid() {
        let json = json!({
            "tools": {
                "terminal": {
                    "always_allow": [
                        { "pattern": "" }
                    ],
                    "always_deny": [
                        { "case_sensitive": true }
                    ],
                    "always_confirm": [
                        { "pattern": "" },
                        { "pattern": "valid_pattern" }
                    ]
                }
            }
        });

        let content: ToolPermissionsContent = serde_json::from_value(json).unwrap();
        let permissions = compile_tool_permissions(Some(content));

        let terminal = permissions.tools.get("terminal").unwrap();

        assert_eq!(terminal.always_allow.len(), 0);
        assert_eq!(terminal.always_deny.len(), 0);
        assert_eq!(terminal.always_confirm.len(), 1);
        assert!(terminal.always_confirm[0].is_match("valid_pattern"));

        assert_eq!(terminal.invalid_patterns.len(), 3);
        for invalid in &terminal.invalid_patterns {
            assert_eq!(invalid.pattern, "");
            assert!(invalid.error.contains("empty"));
        }
    }

    #[test]
    fn test_default_json_tool_permissions_parse() {
        let default_json = include_str!("../../../assets/settings/default.json");
        let value: serde_json_lenient::Value = serde_json_lenient::from_str(default_json).unwrap();
        let agent = value
            .get("agent")
            .expect("default.json should have 'agent' key");
        let tool_permissions_value = agent
            .get("tool_permissions")
            .expect("agent should have 'tool_permissions' key");

        let content: ToolPermissionsContent =
            serde_json_lenient::from_value(tool_permissions_value.clone()).unwrap();
        let permissions = compile_tool_permissions(Some(content));

        assert_eq!(permissions.default, ToolPermissionMode::Confirm);

        assert!(
            permissions.tools.is_empty(),
            "default.json should not have any active tool-specific rules, found: {:?}",
            permissions.tools.keys().collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_tool_permissions_explicit_global_default() {
        let json_allow = json!({
            "default": "allow"
        });
        let content: ToolPermissionsContent = serde_json::from_value(json_allow).unwrap();
        let permissions = compile_tool_permissions(Some(content));
        assert_eq!(permissions.default, ToolPermissionMode::Allow);

        let json_deny = json!({
            "default": "deny"
        });
        let content: ToolPermissionsContent = serde_json::from_value(json_deny).unwrap();
        let permissions = compile_tool_permissions(Some(content));
        assert_eq!(permissions.default, ToolPermissionMode::Deny);
    }

    // Coverage for the built-in (non-overridable) security rules and the path
    // normalization they rely on.

    #[test]
    fn test_normalize_path_resolves_dot_segments() {
        assert_eq!(normalize_path("/tmp/.."), "/");
        assert_eq!(normalize_path("/tmp/../.."), "/");
        assert_eq!(normalize_path("a/./b/../c"), "a/c");
        assert_eq!(normalize_path("a/../.."), "..");
        assert_eq!(normalize_path("../../x"), "../../x");
        assert_eq!(normalize_path("./"), "");
        assert_eq!(normalize_path(""), "");
    }

    #[test]
    fn test_hardcoded_rules_block_destructive_rm() {
        for command in [
            "rm -rf /",
            "rm -rf /*",
            "RM -RF /",
            "rm -rf ~",
            "rm -rf .",
            "rm -rf ..",
            // Only caught after path normalization collapses it to `/`.
            "rm -rf /tmp/..",
        ] {
            assert_eq!(
                check_hardcoded_security_rules("terminal", "terminal", command, None).as_deref(),
                Some(HARDCODED_SECURITY_DENIAL_MESSAGE),
                "expected {command:?} to be blocked"
            );
        }
    }

    #[test]
    fn test_hardcoded_rules_allow_targeted_commands() {
        for command in [
            "rm -rf /tmp/build",
            "rm -rf ./target",
            "rm file.txt",
            "ls -la /",
        ] {
            assert_eq!(
                check_hardcoded_security_rules("terminal", "terminal", command, None),
                None,
                "expected {command:?} to be allowed"
            );
        }
    }

    #[test]
    fn test_hardcoded_rules_only_apply_to_terminal_tool() {
        assert_eq!(
            check_hardcoded_security_rules("edit_file", "terminal", "rm -rf /", None),
            None
        );
    }

    #[test]
    fn test_hardcoded_rules_check_extracted_commands() {
        let input = "true; rm -rf /; true";
        let extracted = vec![
            "true".to_string(),
            "rm -rf /".to_string(),
            "true".to_string(),
        ];
        assert_eq!(
            check_hardcoded_security_rules(
                "terminal",
                "terminal",
                input,
                Some(extracted.as_slice())
            )
            .as_deref(),
            Some(HARDCODED_SECURITY_DENIAL_MESSAGE)
        );
    }
}