diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index ddddbfc5279ca23fb95527892e929b23b8cefbf6..20fc40f242831552630f1e15f59917fd80b1ecdb 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -160,6 +160,42 @@ async fn test_system_prompt(cx: &mut TestAppContext) { ); } +#[gpui::test] +async fn test_system_prompt_without_tools(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["abc"], cx) + }) + .unwrap(); + cx.run_until_parked(); + let mut pending_completions = fake_model.pending_completions(); + assert_eq!( + pending_completions.len(), + 1, + "unexpected pending completions: {:?}", + pending_completions + ); + + let pending_completion = pending_completions.pop().unwrap(); + assert_eq!(pending_completion.messages[0].role, Role::System); + + let system_message = &pending_completion.messages[0]; + let system_prompt = system_message.content[0].to_str().unwrap(); + assert!( + !system_prompt.contains("## Tool Use"), + "unexpected system message: {:?}", + system_message + ); + assert!( + !system_prompt.contains("## Fixing Diagnostics"), + "unexpected system message: {:?}", + system_message + ); +} + #[gpui::test] async fn test_prompt_caching(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 4016f3a5f53da95c0adca80ebfc5808addd55e09..64e512690beeaebd4a343bc5f2df473c795aed3f 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1816,9 +1816,15 @@ impl Thread { log::debug!("Completion intent: {:?}", completion_intent); log::debug!("Completion mode: {:?}", self.completion_mode); - let messages = self.build_request_messages(cx); + let available_tools: Vec<_> = self + .running_turn + .as_ref() + .map(|turn| turn.tools.keys().cloned().collect()) + .unwrap_or_default(); + + log::debug!("Request includes {} tools", available_tools.len()); + let messages = self.build_request_messages(available_tools, cx); log::debug!("Request will include {} messages", messages.len()); - log::debug!("Request includes {} tools", tools.len()); let request = LanguageModelRequest { thread_id: Some(self.id.to_string()), @@ -1909,7 +1915,11 @@ impl Thread { self.running_turn.as_ref()?.tools.get(name).cloned() } - fn build_request_messages(&self, cx: &App) -> Vec { + fn build_request_messages( + &self, + available_tools: Vec, + cx: &App, + ) -> Vec { log::trace!( "Building request messages from {} thread messages", self.messages.len() @@ -1917,7 +1927,7 @@ impl Thread { let system_prompt = SystemPromptTemplate { project: self.project_context.read(cx), - available_tools: self.tools.keys().cloned().collect(), + available_tools, } .render(&self.templates) .context("failed to build system prompt")