agent: Include grep-related instructions in the prompt only if the tool is available (#29536)

Oleksiy Syvokon created

This change updates the system prompt to conditionally include
`grep`-related instructions based on whether the `grep` tool is enabled.

Implementation details:
1. Add a `has_tool` handlebars helper.
2. Pass the `model` to all locations where the prompt is built.
3. Use `{{#if has_tool "grep"}}` in the system prompt to gate
`grep`-specific instructions.

Testing:
- Unit tests for the `hasTool` helper.
- Unit tests to verify that `grep`-related instructions are included /
omitted from the prompt as appropriate.
- Manual agent evaluation:
- Setup: Asked the Agent "List all impls of MyTrait in the project"
using a custom "No tools" profile (all tools disabled).
- Before the change: The Agent attempted to call `grep`, encountered an
error, then realized the tool was unavailable.
- After the change: The Agent immediately asked to enable a search tool.

Note: in principle, `grep`/`read_file` tool descriptions alone might be
enough, but to confirm this we need more evaluation. If it turns out to
be true, we'll be able to remove grep-specific instructions from the
system prompt and undo this change.

Release Notes:

- N/A

Change summary

Cargo.lock                                 |   1 
assets/prompts/assistant_system_prompt.hbs |   7 
crates/agent/src/thread.rs                 | 124 ++++++++++++++++-------
crates/prompt_store/Cargo.toml             |   1 
crates/prompt_store/src/prompts.rs         | 122 +++++++++++++++++++++++
5 files changed, 209 insertions(+), 46 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -11114,6 +11114,7 @@ dependencies = [
  "paths",
  "rope",
  "serde",
+ "serde_json",
  "text",
  "util",
  "uuid",

assets/prompts/assistant_system_prompt.hbs 🔗

@@ -27,13 +27,14 @@ If appropriate, use tool calls to explore the current project, which contains th
 - `{{root_name}}`
 {{/each}}
 
+- Bias towards not asking the user for help if you can find the answer yourself.
 - When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
+- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
+{{# if (has_tool 'grep') }}
 - When looking for symbols in the project, prefer the `grep` tool.
 - As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
-- Bias towards not asking the user for help if you can find the answer yourself.
-{{! TODO: Only mention tools if they are enabled }}
 - The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
-- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
+{{/if}}
 
 ## Fixing Diagnostics
 

crates/agent/src/thread.rs 🔗

@@ -26,7 +26,7 @@ use language_model::{
 };
 use project::Project;
 use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
-use prompt_store::PromptBuilder;
+use prompt_store::{ModelContext, PromptBuilder};
 use proto::Plan;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -740,6 +740,32 @@ impl Thread {
         self.tool_use.tool_result_card(id).cloned()
     }
 
+    /// Return tools that are both enabled and supported by the model
+    pub fn available_tools(
+        &self,
+        cx: &App,
+        model: Arc<dyn LanguageModel>,
+    ) -> Vec<LanguageModelRequestTool> {
+        if model.supports_tools() {
+            self.tools()
+                .read(cx)
+                .enabled_tools(cx)
+                .into_iter()
+                .filter_map(|tool| {
+                    // Skip tools that cannot be supported
+                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
+                    Some(LanguageModelRequestTool {
+                        name: tool.name(),
+                        description: tool.description(),
+                        input_schema,
+                    })
+                })
+                .collect()
+        } else {
+            Vec::default()
+        }
+    }
+
     pub fn insert_user_message(
         &mut self,
         text: impl Into<String>,
@@ -941,30 +967,7 @@ impl Thread {
 
         self.remaining_turns -= 1;
 
-        let mut request = self.to_completion_request(cx);
-        request.mode = if model.supports_max_mode() {
-            self.completion_mode
-        } else {
-            None
-        };
-
-        if model.supports_tools() {
-            request.tools = self
-                .tools()
-                .read(cx)
-                .enabled_tools(cx)
-                .into_iter()
-                .filter_map(|tool| {
-                    // Skip tools that cannot be supported
-                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
-                    Some(LanguageModelRequestTool {
-                        name: tool.name(),
-                        description: tool.description(),
-                        input_schema,
-                    })
-                })
-                .collect();
-        }
+        let request = self.to_completion_request(model.clone(), cx);
 
         self.stream_completion(request, model, window, cx);
     }
@@ -981,7 +984,11 @@ impl Thread {
         false
     }
 
-    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
+    pub fn to_completion_request(
+        &self,
+        model: Arc<dyn LanguageModel>,
+        cx: &mut Context<Self>,
+    ) -> LanguageModelRequest {
         let mut request = LanguageModelRequest {
             thread_id: Some(self.id.to_string()),
             prompt_id: Some(self.last_prompt_id.to_string()),
@@ -992,10 +999,20 @@ impl Thread {
             temperature: None,
         };
 
+        let available_tools = self.available_tools(cx, model.clone());
+        let available_tool_names = available_tools
+            .iter()
+            .map(|tool| tool.name.clone())
+            .collect();
+
+        let model_context = &ModelContext {
+            available_tools: available_tool_names,
+        };
+
         if let Some(project_context) = self.project_context.borrow().as_ref() {
             match self
                 .prompt_builder
-                .generate_assistant_system_prompt(project_context)
+                .generate_assistant_system_prompt(project_context, model_context)
             {
                 Err(err) => {
                     let message = format!("{err:?}").into();
@@ -1075,6 +1092,13 @@ impl Thread {
 
         self.attached_tracked_files_state(&mut request.messages, cx);
 
+        request.tools = available_tools;
+        request.mode = if model.supports_max_mode() {
+            self.completion_mode
+        } else {
+            None
+        };
+
         request
     }
 
@@ -1376,7 +1400,7 @@ impl Thread {
                     match result.as_ref() {
                         Ok(stop_reason) => match stop_reason {
                             StopReason::ToolUse => {
-                                let tool_uses = thread.use_pending_tools(window, cx);
+                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                             }
                             StopReason::EndTurn => {}
@@ -1594,9 +1618,10 @@ impl Thread {
         &mut self,
         window: Option<AnyWindowHandle>,
         cx: &mut Context<Self>,
+        model: Arc<dyn LanguageModel>,
     ) -> Vec<PendingToolUse> {
         self.auto_capture_telemetry(cx);
-        let request = self.to_completion_request(cx);
+        let request = self.to_completion_request(model, cx);
         let messages = Arc::new(request.messages);
         let pending_tool_uses = self
             .tool_use
@@ -2316,9 +2341,11 @@ mod tests {
     use super::*;
     use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
     use assistant_settings::AssistantSettings;
+    use assistant_tool::ToolRegistry;
     use context_server::ContextServerSettings;
     use editor::EditorSettings;
     use gpui::TestAppContext;
+    use language_model::fake_provider::FakeLanguageModel;
     use project::{FakeFs, Project};
     use prompt_store::PromptBuilder;
     use serde_json::json;
@@ -2338,7 +2365,7 @@ mod tests {
         )
         .await;
 
-        let (_workspace, _thread_store, thread, context_store) =
+        let (_workspace, _thread_store, thread, context_store, model) =
             setup_test_environment(cx, project.clone()).await;
 
         add_file_to_context(&project, &context_store, "test/code.rs", cx)
@@ -2389,7 +2416,9 @@ fn main() {{
         assert_eq!(message.loaded_context.text, expected_context);
 
         // Check message in request
-        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
 
         assert_eq!(request.messages.len(), 2);
         let expected_full_message = format!("{}Please explain this code", expected_context);
@@ -2410,7 +2439,7 @@ fn main() {{
         )
         .await;
 
-        let (_, _thread_store, thread, context_store) =
+        let (_, _thread_store, thread, context_store, model) =
             setup_test_environment(cx, project.clone()).await;
 
         // First message with context 1
@@ -2481,7 +2510,9 @@ fn main() {{
         assert!(message3.loaded_context.text.contains("file3.rs"));
 
         // Check entire request to make sure all contexts are properly included
-        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
 
         // The request should contain all 3 messages
         assert_eq!(request.messages.len(), 4);
@@ -2510,7 +2541,7 @@ fn main() {{
         )
         .await;
 
-        let (_, _thread_store, thread, _context_store) =
+        let (_, _thread_store, thread, _context_store, model) =
             setup_test_environment(cx, project.clone()).await;
 
         // Insert user message without any context (empty context vector)
@@ -2536,7 +2567,9 @@ fn main() {{
         assert_eq!(message.loaded_context.text, "");
 
         // Check message in request
-        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
 
         assert_eq!(request.messages.len(), 2);
         assert_eq!(
@@ -2559,7 +2592,9 @@ fn main() {{
         assert_eq!(message2.loaded_context.text, "");
 
         // Check that both messages appear in the request
-        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
 
         assert_eq!(request.messages.len(), 3);
         assert_eq!(
@@ -2582,7 +2617,7 @@ fn main() {{
         )
         .await;
 
-        let (_workspace, _thread_store, thread, context_store) =
+        let (_workspace, _thread_store, thread, context_store, model) =
             setup_test_environment(cx, project.clone()).await;
 
         // Open buffer and add it to context
@@ -2601,7 +2636,9 @@ fn main() {{
         });
 
         // Create a request and check that it doesn't have a stale buffer warning yet
-        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
+        let initial_request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
 
         // Make sure we don't have a stale file warning yet
         let has_stale_warning = initial_request.messages.iter().any(|msg| {
@@ -2634,7 +2671,9 @@ fn main() {{
         });
 
         // Create a new request and check for the stale buffer warning
-        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
+        let new_request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
 
         // We should have a stale file warning as the last message
         let last_message = new_request
@@ -2667,6 +2706,7 @@ fn main() {{
             ThemeSettings::register(cx);
             ContextServerSettings::register(cx);
             EditorSettings::register(cx);
+            ToolRegistry::default_global(cx);
         });
     }
 
@@ -2688,6 +2728,7 @@ fn main() {{
         Entity<ThreadStore>,
         Entity<Thread>,
         Entity<ContextStore>,
+        Arc<dyn LanguageModel>,
     ) {
         let (workspace, cx) =
             cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
@@ -2708,7 +2749,10 @@ fn main() {{
         let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
         let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
 
-        (workspace, thread_store, thread, context_store)
+        let model = FakeLanguageModel::default();
+        let model: Arc<dyn LanguageModel> = Arc::new(model);
+
+        (workspace, thread_store, thread, context_store, model)
     }
 
     async fn add_file_to_context(

crates/prompt_store/Cargo.toml 🔗

@@ -28,6 +28,7 @@ parking_lot.workspace = true
 paths.workspace = true
 rope.workspace = true
 serde.workspace = true
+serde_json.workspace = true
 text.workspace = true
 util.workspace = true
 uuid.workspace = true

crates/prompt_store/src/prompts.rs 🔗

@@ -48,6 +48,20 @@ impl ProjectContext {
     }
 }
 
+#[derive(Debug, Clone, Serialize)]
+pub struct ModelContext {
+    pub available_tools: Vec<String>,
+}
+
+#[derive(Serialize)]
+struct PromptTemplateContext {
+    #[serde(flatten)]
+    project: ProjectContext,
+
+    #[serde(flatten)]
+    model: ModelContext,
+}
+
 #[derive(Debug, Clone, Serialize)]
 pub struct UserRulesContext {
     pub uuid: UserPromptId,
@@ -124,9 +138,40 @@ impl PromptBuilder {
         .unwrap_or_else(|| Arc::new(Self::new(None).unwrap()))
     }
 
+    /// Helper function for handlebars templates to check if a specific tool is enabled
+    fn has_tool_helper(
+        h: &handlebars::Helper,
+        _: &Handlebars,
+        ctx: &handlebars::Context,
+        _: &mut handlebars::RenderContext,
+        out: &mut dyn handlebars::Output,
+    ) -> handlebars::HelperResult {
+        let tool_name = h.param(0).and_then(|v| v.value().as_str()).ok_or_else(|| {
+            handlebars::RenderError::new("has_tool helper: missing or invalid tool name parameter")
+        })?;
+
+        let enabled_tools = ctx
+            .data()
+            .get("available_tools")
+            .and_then(|v| v.as_array())
+            .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect::<Vec<&str>>())
+            .ok_or_else(|| {
+                handlebars::RenderError::new(
+                    "has_tool handlebars helper: available_tools not found or not an array",
+                )
+            })?;
+
+        if enabled_tools.contains(&tool_name) {
+            out.write("true")?;
+        }
+
+        Ok(())
+    }
+
     pub fn new(loading_params: Option<PromptLoadingParams>) -> Result<Self> {
         let mut handlebars = Handlebars::new();
         Self::register_built_in_templates(&mut handlebars)?;
+        handlebars.register_helper("has_tool", Box::new(Self::has_tool_helper));
 
         let handlebars = Arc::new(Mutex::new(handlebars));
 
@@ -278,10 +323,16 @@ impl PromptBuilder {
     pub fn generate_assistant_system_prompt(
         &self,
         context: &ProjectContext,
+        model_context: &ModelContext,
     ) -> Result<String, RenderError> {
+        let template_context = PromptTemplateContext {
+            project: context.clone(),
+            model: model_context.clone(),
+        };
+
         self.handlebars
             .lock()
-            .render("assistant_system_prompt", context)
+            .render("assistant_system_prompt", &template_context)
     }
 
     pub fn generate_inline_transformation_prompt(
@@ -398,6 +449,7 @@ impl PromptBuilder {
 #[cfg(test)]
 mod test {
     use super::*;
+    use serde_json;
     use uuid::Uuid;
 
     #[test]
@@ -416,9 +468,73 @@ mod test {
             contents: "Rules contents".into(),
         }];
         let project_context = ProjectContext::new(worktrees, default_user_rules);
-        PromptBuilder::new(None)
+        let model_context = ModelContext {
+            available_tools: ["grep".into()].to_vec(),
+        };
+        let prompt = PromptBuilder::new(None)
             .unwrap()
-            .generate_assistant_system_prompt(&project_context)
+            .generate_assistant_system_prompt(&project_context, &model_context)
+            .unwrap();
+        assert!(
+            prompt.contains("Rules contents"),
+            "Expected default user rules to be in rendered prompt"
+        );
+    }
+
+    #[test]
+    fn test_assistant_system_prompt_depends_on_enabled_tools() {
+        let worktrees = vec![WorktreeContext {
+            root_name: "path".into(),
+            rules_file: None,
+        }];
+        let default_user_rules = vec![];
+        let project_context = ProjectContext::new(worktrees, default_user_rules);
+        let prompt_builder = PromptBuilder::new(None).unwrap();
+
+        // When the `grep` tool is enabled, it should be mentioned in the prompt
+        let model_context = ModelContext {
+            available_tools: ["grep".into()].to_vec(),
+        };
+        let prompt_with_grep = prompt_builder
+            .generate_assistant_system_prompt(&project_context, &model_context)
             .unwrap();
+        assert!(
+            prompt_with_grep.contains("grep"),
+            "`grep` tool should be mentioned in prompt when the tool is enabled"
+        );
+
+        // When the `grep` tool is disabled, it should not be mentioned in the prompt
+        let model_context = ModelContext {
+            available_tools: [].to_vec(),
+        };
+        let prompt_without_grep = prompt_builder
+            .generate_assistant_system_prompt(&project_context, &model_context)
+            .unwrap();
+        assert!(
+            !prompt_without_grep.contains("grep"),
+            "`grep` tool should not be mentioned in prompt when the tool is disabled"
+        );
+    }
+
+    #[test]
+    fn test_has_tool_helper() {
+        let mut handlebars = Handlebars::new();
+        handlebars.register_helper("has_tool", Box::new(PromptBuilder::has_tool_helper));
+        handlebars
+            .register_template_string(
+                "test_template",
+                "{{#if (has_tool 'grep')}}grep is enabled{{else}}grep is disabled{{/if}}",
+            )
+            .unwrap();
+
+        // grep available
+        let data = serde_json::json!({"available_tools": ["grep", "fetch"]});
+        let result = handlebars.render("test_template", &data).unwrap();
+        assert_eq!(result, "grep is enabled");
+
+        // grep not available
+        let data = serde_json::json!({"available_tools": ["terminal", "fetch"]});
+        let result = handlebars.render("test_template", &data).unwrap();
+        assert_eq!(result, "grep is disabled");
     }
 }