project_notifications_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::Result;
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{AnyWindowHandle, App, Entity, Task};
  5use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::fmt::Write as _;
 10use std::sync::Arc;
 11use ui::IconName;
 12
 13#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 14pub struct ProjectUpdatesToolInput {}
 15
 16pub struct ProjectNotificationsTool;
 17
 18impl Tool for ProjectNotificationsTool {
 19    fn name(&self) -> String {
 20        "project_notifications".to_string()
 21    }
 22
 23    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 24        false
 25    }
 26    fn may_perform_edits(&self) -> bool {
 27        false
 28    }
 29    fn description(&self) -> String {
 30        include_str!("./project_notifications_tool/description.md").to_string()
 31    }
 32
 33    fn icon(&self) -> IconName {
 34        IconName::Envelope
 35    }
 36
 37    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 38        json_schema_for::<ProjectUpdatesToolInput>(format)
 39    }
 40
 41    fn ui_text(&self, _input: &serde_json::Value) -> String {
 42        "Check project notifications".into()
 43    }
 44
 45    fn run(
 46        self: Arc<Self>,
 47        _input: serde_json::Value,
 48        _request: Arc<LanguageModelRequest>,
 49        _project: Entity<Project>,
 50        action_log: Entity<ActionLog>,
 51        _model: Arc<dyn LanguageModel>,
 52        _window: Option<AnyWindowHandle>,
 53        cx: &mut App,
 54    ) -> ToolResult {
 55        let mut stale_files = String::new();
 56        let mut notified_buffers = Vec::new();
 57
 58        for stale_file in action_log.read(cx).unnotified_stale_buffers(cx) {
 59            if let Some(file) = stale_file.read(cx).file() {
 60                writeln!(&mut stale_files, "- {}", file.path().display()).ok();
 61                notified_buffers.push(stale_file.clone());
 62            }
 63        }
 64
 65        if !notified_buffers.is_empty() {
 66            action_log.update(cx, |log, cx| {
 67                log.mark_buffers_as_notified(notified_buffers, cx);
 68            });
 69        }
 70
 71        let response = if stale_files.is_empty() {
 72            "No new notifications".to_string()
 73        } else {
 74            // NOTE: Changes to this prompt require a symmetric update in the LLM Worker
 75            const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
 76            format!("{HEADER}{stale_files}").replace("\r\n", "\n")
 77        };
 78
 79        Task::ready(Ok(response.into())).into()
 80    }
 81}
 82
 #[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::ToolResultContent;
    use gpui::{AppContext, TestAppContext};
    use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::sync::Arc;
    use util::path;

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path("test/code.rs", cx))
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path.clone(), cx))
            .await
            .unwrap();

        // Record the buffer as read so the action log starts tracking it.
        action_log.update(cx, |log, cx| {
            log.buffer_read(buffer.clone(), cx);
        });

        // Nothing has been edited yet, so there is nothing to report.
        expect_tool_output(
            cx,
            &project,
            &action_log,
            "No new notifications",
            "Tool should return 'No new notifications' when no stale buffers",
        )
        .await;

        // Editing the buffer after it was read makes it stale.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(1..1, "\nChange!\n")], None, cx);
        });

        expect_tool_output(
            cx,
            &project,
            &action_log,
            "[The following is an auto-generated notification; do not reply]\n\nThese files have changed since the last read:\n- code.rs\n",
            "Tool should return the stale buffer notification",
        )
        .await;

        // The previous run marked the buffer as notified, so nothing new is reported.
        expect_tool_output(
            cx,
            &project,
            &action_log,
            "No new notifications",
            "Tool should return 'No new notifications' when running again without changes",
        )
        .await;
    }

    /// Runs [`ProjectNotificationsTool`] once with an empty JSON input and
    /// asserts that it yields a text response equal to `expected`; `message`
    /// is reported if the assertion fails.
    async fn expect_tool_output(
        cx: &mut TestAppContext,
        project: &Entity<Project>,
        action_log: &Entity<ActionLog>,
        expected: &str,
        message: &str,
    ) {
        let tool = Arc::new(ProjectNotificationsTool);
        let provider = Arc::new(FakeLanguageModelProvider);
        let model: Arc<dyn LanguageModel> = Arc::new(provider.test_model());
        let request = Arc::new(LanguageModelRequest::default());

        let result = cx.update(|cx| {
            tool.run(
                json!({}),
                request,
                project.clone(),
                action_log.clone(),
                model,
                None,
                cx,
            )
        });

        let response = result.output.await.unwrap();
        match &response.content {
            ToolResultContent::Text(text) => assert_eq!(text.as_str(), expected, "{message}"),
            _ => panic!("Expected text response"),
        }
    }

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            assistant_tool::init(cx);
        });
    }
}