project_notifications_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::Result;
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{AnyWindowHandle, App, Entity, Task};
  5use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::{fmt::Write, sync::Arc};
 10use ui::IconName;
 11
 12#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 13pub struct ProjectUpdatesToolInput {}
 14
 15pub struct ProjectNotificationsTool;
 16
 17impl Tool for ProjectNotificationsTool {
 18    fn name(&self) -> String {
 19        "project_notifications".to_string()
 20    }
 21
 22    fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
 23        false
 24    }
 25    fn may_perform_edits(&self) -> bool {
 26        false
 27    }
 28    fn description(&self) -> String {
 29        include_str!("./project_notifications_tool/description.md").to_string()
 30    }
 31
 32    fn icon(&self) -> IconName {
 33        IconName::ToolNotification
 34    }
 35
 36    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 37        json_schema_for::<ProjectUpdatesToolInput>(format)
 38    }
 39
 40    fn ui_text(&self, _input: &serde_json::Value) -> String {
 41        "Check project notifications".into()
 42    }
 43
 44    fn run(
 45        self: Arc<Self>,
 46        _input: serde_json::Value,
 47        _request: Arc<LanguageModelRequest>,
 48        _project: Entity<Project>,
 49        action_log: Entity<ActionLog>,
 50        _model: Arc<dyn LanguageModel>,
 51        _window: Option<AnyWindowHandle>,
 52        cx: &mut App,
 53    ) -> ToolResult {
 54        let Some(user_edits_diff) =
 55            action_log.update(cx, |log, cx| log.flush_unnotified_user_edits(cx))
 56        else {
 57            return result("No new notifications");
 58        };
 59
 60        // NOTE: Changes to this prompt require a symmetric update in the LLM Worker
 61        const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
 62        const MAX_BYTES: usize = 8000;
 63        let diff = fit_patch_to_size(&user_edits_diff, MAX_BYTES);
 64        result(&format!("{HEADER}\n\n```diff\n{diff}\n```\n").replace("\r\n", "\n"))
 65    }
 66}
 67
 68fn result(response: &str) -> ToolResult {
 69    Task::ready(Ok(response.to_string().into())).into()
 70}
 71
 72/// Make sure that the patch fits into the size limit (in bytes).
 73/// Compress the patch by omitting some parts if needed.
 74/// Unified diff format is assumed.
 75fn fit_patch_to_size(patch: &str, max_size: usize) -> String {
 76    if patch.len() <= max_size {
 77        return patch.to_string();
 78    }
 79
 80    // Compression level 1: remove context lines in diff bodies, but
 81    // leave the counts and positions of inserted/deleted lines
 82    let mut current_size = patch.len();
 83    let mut file_patches = split_patch(&patch);
 84    file_patches.sort_by_key(|patch| patch.len());
 85    let compressed_patches = file_patches
 86        .iter()
 87        .rev()
 88        .map(|patch| {
 89            if current_size > max_size {
 90                let compressed = compress_patch(patch).unwrap_or_else(|_| patch.to_string());
 91                current_size -= patch.len() - compressed.len();
 92                compressed
 93            } else {
 94                patch.to_string()
 95            }
 96        })
 97        .collect::<Vec<_>>();
 98
 99    if current_size <= max_size {
100        return compressed_patches.join("\n\n");
101    }
102
103    // Compression level 2: list paths of the changed files only
104    let filenames = file_patches
105        .iter()
106        .map(|patch| {
107            let patch = diffy::Patch::from_str(patch).unwrap();
108            let path = patch
109                .modified()
110                .and_then(|path| path.strip_prefix("b/"))
111                .unwrap_or_default();
112            format!("- {path}\n")
113        })
114        .collect::<Vec<_>>();
115
116    filenames.join("")
117}
118
/// Split a potentially multi-file patch into multiple single-file patches.
///
/// A new file patch starts at a `---` line that is immediately followed by a
/// `+++` line, i.e. a unified-diff file header. Checking the next line stops a
/// removed line whose content itself starts with `--` from being mistaken for
/// a header. For example, deleting a `---` YAML front-matter delimiter shows up
/// as `----` in the diff.
///
/// Each returned patch has its trailing newlines trimmed. Lines are re-joined
/// with `\n`, so `\r\n` line endings are normalized. Empty input yields an
/// empty vector.
fn split_patch(patch: &str) -> Vec<String> {
    let mut result = Vec::new();
    let mut current_patch = String::new();
    let mut lines = patch.lines().peekable();

    while let Some(line) = lines.next() {
        let is_file_header = line.starts_with("---")
            && lines.peek().is_some_and(|next| next.starts_with("+++"));
        if is_file_header && !current_patch.is_empty() {
            result.push(current_patch.trim_end_matches('\n').to_string());
            current_patch.clear();
        }
        current_patch.push_str(line);
        current_patch.push('\n');
    }

    if !current_patch.is_empty() {
        result.push(current_patch.trim_end_matches('\n').to_string());
    }

    result
}
139
140fn compress_patch(patch: &str) -> anyhow::Result<String> {
141    let patch = diffy::Patch::from_str(patch)?;
142    let mut out = String::new();
143
144    writeln!(out, "--- {}", patch.original().unwrap_or("a"))?;
145    writeln!(out, "+++ {}", patch.modified().unwrap_or("b"))?;
146
147    for hunk in patch.hunks() {
148        writeln!(out, "@@ -{} +{} @@", hunk.old_range(), hunk.new_range())?;
149        writeln!(out, "[...skipped...]")?;
150    }
151
152    Ok(out)
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::ToolResultContent;
    use gpui::{AppContext, TestAppContext};
    use indoc::indoc;
    use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::sync::Arc;
    use util::path;

    /// Runs the tool once with an empty JSON input, lets pending work settle,
    /// and returns the text content of its response.
    async fn run_tool(
        tool: Arc<ProjectNotificationsTool>,
        project: &Entity<Project>,
        action_log: &Entity<ActionLog>,
        model: &Arc<dyn LanguageModel>,
        cx: &mut TestAppContext,
    ) -> String {
        let request = Arc::new(LanguageModelRequest::default());
        let result = cx.update(|cx| {
            tool.run(
                json!({}),
                request,
                project.clone(),
                action_log.clone(),
                model.clone(),
                None,
                cx,
            )
        });
        cx.run_until_parked();

        let response = result.output.await.unwrap();
        match response.content {
            ToolResultContent::Text(text) => text,
            _ => panic!("Expected text response"),
        }
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path("test/code.rs", cx))
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path.clone(), cx))
            .await
            .unwrap();

        // Start tracking the buffer
        action_log.update(cx, |log, cx| {
            log.buffer_read(buffer.clone(), cx);
        });
        cx.run_until_parked();

        let tool = Arc::new(ProjectNotificationsTool);
        let provider = Arc::new(FakeLanguageModelProvider::default());
        let model: Arc<dyn LanguageModel> = Arc::new(provider.test_model());

        // Before any change there is nothing to report
        let first = run_tool(tool.clone(), &project, &action_log, &model, cx).await;
        assert_eq!(
            first.as_str(),
            "No new notifications",
            "Tool should return 'No new notifications' when no stale buffers"
        );

        // Modify the buffer (makes it stale)
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(1..1, "\nChange!\n")], None, cx);
        });
        cx.run_until_parked();

        // This time the buffer is stale, so the tool should return a notification
        let second = run_tool(tool.clone(), &project, &action_log, &model, cx).await;
        assert!(
            second.contains("These files have changed"),
            "Tool should return the stale buffer notification"
        );
        assert!(
            second.contains("test/code.rs"),
            "Tool should return the stale buffer notification"
        );

        // Without further changes the notification must not be repeated
        let third = run_tool(tool, &project, &action_log, &model, cx).await;
        assert_eq!(
            third.as_str(),
            "No new notifications",
            "Tool should return 'No new notifications' when running again without changes"
        );
    }

    #[test]
    fn test_patch_compression() {
        // Given a patch that doesn't fit into the size budget
        let patch = indoc! {"
        --- a/dir/test.txt
        +++ b/dir/test.txt
        @@ -1,3 +1,3 @@
         line 1
        -line 2
        +CHANGED
         line 3
        @@ -10,2 +10,2 @@
         line 10
        -line 11
        +line eleven


        --- a/dir/another.txt
        +++ b/dir/another.txt
        @@ -100,1 +1,1 @@
        -before
        +after
        "};

        // When the size deficit can be compensated by dropping the body,
        // then the body should be trimmed for larger files first
        let expected_bodies_trimmed = indoc! {"
        --- a/dir/test.txt
        +++ b/dir/test.txt
        @@ -1,3 +1,3 @@
        [...skipped...]
        @@ -10,2 +10,2 @@
        [...skipped...]


        --- a/dir/another.txt
        +++ b/dir/another.txt
        @@ -100,1 +1,1 @@
        -before
        +after"};
        let slightly_too_small = patch.len() - 10;
        assert_eq!(
            fit_patch_to_size(patch, slightly_too_small),
            expected_bodies_trimmed
        );

        // When the size deficit is too large, then only file paths
        // should be returned
        let expected_paths_only = indoc! {"
        - dir/another.txt
        - dir/test.txt
        "};
        assert_eq!(fit_patch_to_size(patch, 10), expected_paths_only);
    }

    /// Installs the test settings store and initializes the subsystems the
    /// tool depends on.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            assistant_tool::init(cx);
        });
    }
}