@@ -25,8 +25,8 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
- LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
- Role, SelectedModel, StopReason, TokenUsage,
+ LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
+ PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
};
use postage::stream::Stream as _;
use project::{
@@ -45,7 +45,7 @@ use std::{
time::{Duration, Instant},
};
use thiserror::Error;
-use util::{ResultExt as _, post_inc};
+use util::{ResultExt as _, debug_panic, post_inc};
use uuid::Uuid;
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
@@ -1248,6 +1248,8 @@ impl Thread {
self.remaining_turns -= 1;
+ self.flush_notifications(model.clone(), intent, cx);
+
let request = self.to_completion_request(model.clone(), intent, cx);
self.stream_completion(request, model, intent, window, cx);
@@ -1481,6 +1483,110 @@ impl Thread {
request
}
+ /// Insert auto-generated notifications (if any) to the thread
+ fn flush_notifications(
+ &mut self,
+ model: Arc<dyn LanguageModel>,
+ intent: CompletionIntent,
+ cx: &mut Context<Self>,
+ ) {
+ match intent {
+ CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
+ if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
+ cx.emit(ThreadEvent::ToolFinished {
+ tool_use_id: pending_tool_use.id.clone(),
+ pending_tool_use: Some(pending_tool_use),
+ });
+ }
+ }
+ CompletionIntent::ThreadSummarization
+ | CompletionIntent::ThreadContextSummarization
+ | CompletionIntent::CreateFile
+ | CompletionIntent::EditFile
+ | CompletionIntent::InlineAssist
+ | CompletionIntent::TerminalInlineAssist
+ | CompletionIntent::GenerateGitCommitMessage => {}
+ };
+ }
+
+ fn attach_tracked_files_state(
+ &mut self,
+ model: Arc<dyn LanguageModel>,
+ cx: &mut App,
+ ) -> Option<PendingToolUse> {
+ let action_log = self.action_log.read(cx);
+
+ action_log.stale_buffers(cx).next()?;
+
+ // Represent notification as a simulated `project_notifications` tool call
+ let tool_name = Arc::from("project_notifications");
+ let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
+ debug_panic!("`project_notifications` tool not found");
+ return None;
+ };
+
+ if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
+ return None;
+ }
+
+ let input = serde_json::json!({});
+ let request = Arc::new(LanguageModelRequest::default()); // unused
+ let window = None;
+ let tool_result = tool.run(
+ input,
+ request,
+ self.project.clone(),
+ self.action_log.clone(),
+ model.clone(),
+ window,
+ cx,
+ );
+
+ let tool_use_id =
+ LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));
+
+ let tool_use = LanguageModelToolUse {
+ id: tool_use_id.clone(),
+ name: tool_name.clone(),
+ raw_input: "{}".to_string(),
+ input: serde_json::json!({}),
+ is_input_complete: true,
+ };
+
+ let tool_output = cx.background_executor().block(tool_result.output);
+
+ // Attach a project_notification tool call to the latest existing
+ // Assistant message. We cannot create a new Assistant message
+ // because thinking models require a `thinking` block that we
+ // cannot mock. We cannot send a notification as a normal
+ // (non-tool-use) User message because this distracts Agent
+ // too much.
+ let tool_message_id = self
+ .messages
+ .iter()
+ .enumerate()
+ .rfind(|(_, message)| message.role == Role::Assistant)
+ .map(|(_, message)| message.id)?;
+
+ let tool_use_metadata = ToolUseMetadata {
+ model: model.clone(),
+ thread_id: self.id.clone(),
+ prompt_id: self.last_prompt_id.clone(),
+ };
+
+ self.tool_use
+ .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);
+
+ let pending_tool_use = self.tool_use.insert_tool_output(
+ tool_use_id.clone(),
+ tool_name,
+ tool_output,
+ self.configured_model.as_ref(),
+ );
+
+ pending_tool_use
+ }
+
pub fn stream_completion(
&mut self,
request: LanguageModelRequest,
@@ -3156,10 +3262,13 @@ mod tests {
const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
use assistant_tool::ToolRegistry;
+ use assistant_tools;
use futures::StreamExt;
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use gpui::TestAppContext;
+ use http_client;
+ use indoc::indoc;
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
use language_model::{
LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
@@ -3487,6 +3596,105 @@ fn main() {{
);
}
+ #[gpui::test]
+ async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
+ init_test_settings(cx);
+
+ let project = create_test_project(
+ cx,
+ json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
+ )
+ .await;
+
+ let (_workspace, _thread_store, thread, context_store, model) =
+ setup_test_environment(cx, project.clone()).await;
+
+ // Add a buffer to the context. This will be a tracked buffer
+ let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
+ .await
+ .unwrap();
+
+ let context = context_store
+ .read_with(cx, |store, _| store.context().next().cloned())
+ .unwrap();
+ let loaded_context = cx
+ .update(|cx| load_context(vec![context], &project, &None, cx))
+ .await;
+
+ // Insert user message and assistant response
+ thread.update(cx, |thread, cx| {
+ thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
+ thread.insert_assistant_message(
+ vec![MessageSegment::Text("This code prints 42.".into())],
+ cx,
+ );
+ });
+
+ // We shouldn't have a stale buffer notification yet
+ let notification = thread.read_with(cx, |thread, _| {
+ find_tool_use(thread, "project_notifications")
+ });
+ assert!(
+ notification.is_none(),
+ "Should not have stale buffer notification before buffer is modified"
+ );
+
+ // Modify the buffer
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(
+ [(1..1, "\n println!(\"Added a new line\");\n")],
+ None,
+ cx,
+ );
+ });
+
+ // Insert another user message
+ thread.update(cx, |thread, cx| {
+ thread.insert_user_message(
+ "What does the code do now?",
+ ContextLoadResult::default(),
+ None,
+ Vec::new(),
+ cx,
+ )
+ });
+
+ // Check for the stale buffer warning
+ thread.update(cx, |thread, cx| {
+ thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
+ });
+
+ let Some(notification_result) = thread.read_with(cx, |thread, _cx| {
+ find_tool_use(thread, "project_notifications")
+ }) else {
+ panic!("Should have a `project_notifications` tool use");
+ };
+
+ let Some(notification_content) = notification_result.content.to_str() else {
+ panic!("`project_notifications` should return text");
+ };
+
+ let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]
+
+ These files have changed since the last read:
+ - code.rs
+ "};
+ assert_eq!(notification_content, expected_content);
+ }
+
+ fn find_tool_use(thread: &Thread, tool_name: &str) -> Option<LanguageModelToolResult> {
+ thread
+ .messages()
+ .filter_map(|message| {
+ thread
+ .tool_results_for_message(message.id)
+ .into_iter()
+ .find(|result| result.tool_name == tool_name.into())
+ })
+ .next()
+ .cloned()
+ }
+
#[gpui::test]
async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
init_test_settings(cx);
@@ -5052,6 +5260,14 @@ fn main() {{
language_model::init_settings(cx);
ThemeSettings::register(cx);
ToolRegistry::default_global(cx);
+ assistant_tool::init(cx);
+
+ let http_client = Arc::new(http_client::HttpClientWithUrl::new(
+ http_client::FakeHttpClient::with_200_response(),
+ "http://localhost".to_string(),
+ None,
+ ));
+ assistant_tools::init(http_client, cx);
});
}
@@ -0,0 +1,193 @@
+use crate::schema::json_schema_for;
+use anyhow::Result;
+use assistant_tool::{ActionLog, Tool, ToolResult};
+use gpui::{AnyWindowHandle, App, Entity, Task};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
+use project::Project;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use std::fmt::Write as _;
+use std::sync::Arc;
+use ui::IconName;
+
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct ProjectUpdatesToolInput {}
+
+pub struct ProjectNotificationsTool;
+
+impl Tool for ProjectNotificationsTool {
+ fn name(&self) -> String {
+ "project_notifications".to_string()
+ }
+
+ fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
+ false
+ }
+ fn may_perform_edits(&self) -> bool {
+ false
+ }
+ fn description(&self) -> String {
+ include_str!("./project_notifications_tool/description.md").to_string()
+ }
+
+ fn icon(&self) -> IconName {
+ IconName::Envelope
+ }
+
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
+ json_schema_for::<ProjectUpdatesToolInput>(format)
+ }
+
+ fn ui_text(&self, _input: &serde_json::Value) -> String {
+ "Check project notifications".into()
+ }
+
+ fn run(
+ self: Arc<Self>,
+ _input: serde_json::Value,
+ _request: Arc<LanguageModelRequest>,
+ _project: Entity<Project>,
+ action_log: Entity<ActionLog>,
+ _model: Arc<dyn LanguageModel>,
+ _window: Option<AnyWindowHandle>,
+ cx: &mut App,
+ ) -> ToolResult {
+ let mut stale_files = String::new();
+
+ let action_log = action_log.read(cx);
+
+ for stale_file in action_log.stale_buffers(cx) {
+ if let Some(file) = stale_file.read(cx).file() {
+ writeln!(&mut stale_files, "- {}", file.path().display()).ok();
+ }
+ }
+
+ let response = if stale_files.is_empty() {
+ "No new notifications".to_string()
+ } else {
+ // NOTE: Changes to this prompt require a symmetric update in the LLM Worker
+ const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
+ format!("{HEADER}{stale_files}").replace("\r\n", "\n")
+ };
+
+ Task::ready(Ok(response.into())).into()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use assistant_tool::ToolResultContent;
+ use gpui::{AppContext, TestAppContext};
+ use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider};
+ use project::{FakeFs, Project};
+ use serde_json::json;
+ use settings::SettingsStore;
+ use std::sync::Arc;
+ use util::path;
+
+ #[gpui::test]
+ async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/test"),
+ json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+
+ let buffer_path = project
+ .read_with(cx, |project, cx| {
+ project.find_project_path("test/code.rs", cx)
+ })
+ .unwrap();
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_buffer(buffer_path.clone(), cx)
+ })
+ .await
+ .unwrap();
+
+ // Start tracking the buffer
+ action_log.update(cx, |log, cx| {
+ log.buffer_read(buffer.clone(), cx);
+ });
+
+ // Run the tool before any changes
+ let tool = Arc::new(ProjectNotificationsTool);
+ let provider = Arc::new(FakeLanguageModelProvider);
+ let model: Arc<dyn LanguageModel> = Arc::new(provider.test_model());
+ let request = Arc::new(LanguageModelRequest::default());
+ let tool_input = json!({});
+
+ let result = cx.update(|cx| {
+ tool.clone().run(
+ tool_input.clone(),
+ request.clone(),
+ project.clone(),
+ action_log.clone(),
+ model.clone(),
+ None,
+ cx,
+ )
+ });
+
+ let response = result.output.await.unwrap();
+ let response_text = match &response.content {
+ ToolResultContent::Text(text) => text.clone(),
+ _ => panic!("Expected text response"),
+ };
+ assert_eq!(
+ response_text.as_str(),
+ "No new notifications",
+ "Tool should return 'No new notifications' when no stale buffers"
+ );
+
+ // Modify the buffer (makes it stale)
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(1..1, "\nChange!\n")], None, cx);
+ });
+
+ // Run the tool again
+ let result = cx.update(|cx| {
+ tool.run(
+ tool_input.clone(),
+ request.clone(),
+ project.clone(),
+ action_log,
+ model.clone(),
+ None,
+ cx,
+ )
+ });
+
+ // This time the buffer is stale, so the tool should return a notification
+ let response = result.output.await.unwrap();
+ let response_text = match &response.content {
+ ToolResultContent::Text(text) => text.clone(),
+ _ => panic!("Expected text response"),
+ };
+
+ let expected_content = "[The following is an auto-generated notification; do not reply]\n\nThese files have changed since the last read:\n- code.rs\n";
+ assert_eq!(
+ response_text.as_str(),
+ expected_content,
+ "Tool should return the stale buffer notification"
+ );
+ }
+
+ fn init_test(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ language::init(cx);
+ Project::init_settings(cx);
+ assistant_tool::init(cx);
+ });
+ }
+}