project_notifications_tool.rs

  1use crate::schema::json_schema_for;
  2use action_log::ActionLog;
  3use anyhow::Result;
  4use assistant_tool::{Tool, ToolResult};
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
  7use project::Project;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::{fmt::Write, sync::Arc};
 11use ui::IconName;
 12
 13#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 14pub struct ProjectUpdatesToolInput {}
 15
 16pub struct ProjectNotificationsTool;
 17
 18impl Tool for ProjectNotificationsTool {
 19    fn name(&self) -> String {
 20        "project_notifications".to_string()
 21    }
 22
 23    fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
 24        false
 25    }
 26    fn may_perform_edits(&self) -> bool {
 27        false
 28    }
 29    fn description(&self) -> String {
 30        include_str!("./project_notifications_tool/description.md").to_string()
 31    }
 32
 33    fn icon(&self) -> IconName {
 34        IconName::ToolNotification
 35    }
 36
 37    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 38        json_schema_for::<ProjectUpdatesToolInput>(format)
 39    }
 40
 41    fn ui_text(&self, _input: &serde_json::Value) -> String {
 42        "Check project notifications".into()
 43    }
 44
 45    fn run(
 46        self: Arc<Self>,
 47        _input: serde_json::Value,
 48        _request: Arc<LanguageModelRequest>,
 49        _project: Entity<Project>,
 50        action_log: Entity<ActionLog>,
 51        _model: Arc<dyn LanguageModel>,
 52        _window: Option<AnyWindowHandle>,
 53        cx: &mut App,
 54    ) -> ToolResult {
 55        let Some(user_edits_diff) =
 56            action_log.update(cx, |log, cx| log.flush_unnotified_user_edits(cx))
 57        else {
 58            return result("No new notifications");
 59        };
 60
 61        // NOTE: Changes to this prompt require a symmetric update in the LLM Worker
 62        const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
 63        const MAX_BYTES: usize = 8000;
 64        let diff = fit_patch_to_size(&user_edits_diff, MAX_BYTES);
 65        result(&format!("{HEADER}\n\n```diff\n{diff}\n```\n").replace("\r\n", "\n"))
 66    }
 67}
 68
 69fn result(response: &str) -> ToolResult {
 70    Task::ready(Ok(response.to_string().into())).into()
 71}
 72
 73/// Make sure that the patch fits into the size limit (in bytes).
 74/// Compress the patch by omitting some parts if needed.
 75/// Unified diff format is assumed.
 76fn fit_patch_to_size(patch: &str, max_size: usize) -> String {
 77    if patch.len() <= max_size {
 78        return patch.to_string();
 79    }
 80
 81    // Compression level 1: remove context lines in diff bodies, but
 82    // leave the counts and positions of inserted/deleted lines
 83    let mut current_size = patch.len();
 84    let mut file_patches = split_patch(patch);
 85    file_patches.sort_by_key(|patch| patch.len());
 86    let compressed_patches = file_patches
 87        .iter()
 88        .rev()
 89        .map(|patch| {
 90            if current_size > max_size {
 91                let compressed = compress_patch(patch).unwrap_or_else(|_| patch.to_string());
 92                current_size -= patch.len() - compressed.len();
 93                compressed
 94            } else {
 95                patch.to_string()
 96            }
 97        })
 98        .collect::<Vec<_>>();
 99
100    if current_size <= max_size {
101        return compressed_patches.join("\n\n");
102    }
103
104    // Compression level 2: list paths of the changed files only
105    let filenames = file_patches
106        .iter()
107        .map(|patch| {
108            let patch = diffy::Patch::from_str(patch).unwrap();
109            let path = patch
110                .modified()
111                .and_then(|path| path.strip_prefix("b/"))
112                .unwrap_or_default();
113            format!("- {path}\n")
114        })
115        .collect::<Vec<_>>();
116
117    filenames.join("")
118}
119
/// Split a potentially multi-file patch into multiple single-file patches.
///
/// A new file patch starts at every line beginning with `---` that lies
/// outside a hunk body. Hunk bodies are tracked using the old/new line counts
/// from their `@@` headers. This stops a removed line whose content starts with
/// `--` (rendered as `---…`) from being mistaken for a file header.
///
/// Trailing newlines are trimmed from every returned patch.
fn split_patch(patch: &str) -> Vec<String> {
    let mut file_patches = Vec::new();
    let mut current_patch = String::new();
    // Old-side / new-side lines still expected in the current hunk body.
    let mut remaining_old = 0usize;
    let mut remaining_new = 0usize;

    for line in patch.lines() {
        if let Some((old_len, new_len)) = hunk_line_counts(line) {
            remaining_old = old_len;
            remaining_new = new_len;
        } else if remaining_old > 0 || remaining_new > 0 {
            match line.as_bytes().first() {
                Some(b'-') => remaining_old = remaining_old.saturating_sub(1),
                Some(b'+') => remaining_new = remaining_new.saturating_sub(1),
                // "\ No newline at end of file" counts towards neither side.
                Some(b'\\') => {}
                // Context line. Some tools emit an empty line instead of " ".
                _ => {
                    remaining_old = remaining_old.saturating_sub(1);
                    remaining_new = remaining_new.saturating_sub(1);
                }
            }
        } else if line.starts_with("---") && !current_patch.is_empty() {
            file_patches.push(current_patch.trim_end_matches('\n').to_string());
            current_patch = String::new();
        }
        current_patch.push_str(line);
        current_patch.push('\n');
    }

    if !current_patch.is_empty() {
        file_patches.push(current_patch.trim_end_matches('\n').to_string());
    }

    file_patches
}

/// Parse the old/new line counts from a unified diff hunk header such as
/// `@@ -10,2 +10,3 @@ fn foo()`.
/// Returns `None` if `line` is not a well-formed hunk header.
fn hunk_line_counts(line: &str) -> Option<(usize, usize)> {
    let ranges = line.strip_prefix("@@ -")?;
    let (old_range, rest) = ranges.split_once(" +")?;
    let (new_range, _) = rest.split_once(" @@")?;
    Some((hunk_range_len(old_range)?, hunk_range_len(new_range)?))
}

/// Line count of a hunk range `start[,len]`; an omitted `len` means one line.
fn hunk_range_len(range: &str) -> Option<usize> {
    match range.split_once(',') {
        Some((start, len)) => {
            start.parse::<usize>().ok()?;
            len.parse().ok()
        }
        None => range.parse::<usize>().ok().map(|_| 1),
    }
}
140
141fn compress_patch(patch: &str) -> anyhow::Result<String> {
142    let patch = diffy::Patch::from_str(patch)?;
143    let mut out = String::new();
144
145    writeln!(out, "--- {}", patch.original().unwrap_or("a"))?;
146    writeln!(out, "+++ {}", patch.modified().unwrap_or("b"))?;
147
148    for hunk in patch.hunks() {
149        writeln!(out, "@@ -{} +{} @@", hunk.old_range(), hunk.new_range())?;
150        writeln!(out, "[...skipped...]")?;
151    }
152
153    Ok(out)
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::ToolResultContent;
    use gpui::{AppContext, TestAppContext};
    use indoc::indoc;
    use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::sync::Arc;
    use util::path;

    /// Runs `tool` once with an empty JSON input and waits for the executor to
    /// settle. Returns the text of the produced output; panics if it is not text.
    async fn run_tool_for_text(
        tool: &Arc<ProjectNotificationsTool>,
        request: &Arc<LanguageModelRequest>,
        project: &Entity<Project>,
        action_log: &Entity<ActionLog>,
        model: &Arc<dyn LanguageModel>,
        cx: &mut TestAppContext,
    ) -> String {
        let tool_run = cx.update(|cx| {
            tool.clone().run(
                json!({}),
                request.clone(),
                project.clone(),
                action_log.clone(),
                model.clone(),
                None,
                cx,
            )
        });
        cx.run_until_parked();

        let response = tool_run.output.await.unwrap();
        match &response.content {
            ToolResultContent::Text(text) => text.clone(),
            _ => panic!("Expected text response"),
        }
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path("test/code.rs", cx))
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path.clone(), cx))
            .await
            .unwrap();

        // Start tracking the buffer
        action_log.update(cx, |log, cx| {
            log.buffer_read(buffer.clone(), cx);
        });
        cx.run_until_parked();

        let tool = Arc::new(ProjectNotificationsTool);
        let provider = Arc::new(FakeLanguageModelProvider::default());
        let model: Arc<dyn LanguageModel> = Arc::new(provider.test_model());
        let request = Arc::new(LanguageModelRequest::default());

        // Run the tool before any changes
        let response_text =
            run_tool_for_text(&tool, &request, &project, &action_log, &model, cx).await;
        assert_eq!(
            response_text.as_str(),
            "No new notifications",
            "Tool should return 'No new notifications' when no stale buffers"
        );

        // Modify the buffer (makes it stale)
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(1..1, "\nChange!\n")], None, cx);
        });
        cx.run_until_parked();

        // This time the buffer is stale, so the tool should return a notification
        let response_text =
            run_tool_for_text(&tool, &request, &project, &action_log, &model, cx).await;
        assert!(
            response_text.contains("These files have changed"),
            "Tool should return the stale buffer notification"
        );
        assert!(
            response_text.contains("test/code.rs"),
            "Tool should return the stale buffer notification"
        );

        // Run the tool once more without any changes - should get no new notifications
        let response_text =
            run_tool_for_text(&tool, &request, &project, &action_log, &model, cx).await;
        assert_eq!(
            response_text.as_str(),
            "No new notifications",
            "Tool should return 'No new notifications' when running again without changes"
        );
    }

    #[test]
    fn test_patch_compression() {
        // Given a patch that doesn't fit into the size budget
        let patch = indoc! {"
            --- a/dir/test.txt
            +++ b/dir/test.txt
            @@ -1,3 +1,3 @@
             line 1
            -line 2
            +CHANGED
             line 3
            @@ -10,2 +10,2 @@
             line 10
            -line 11
            +line eleven


            --- a/dir/another.txt
            +++ b/dir/another.txt
            @@ -100,1 +1,1 @@
            -before
            +after
            "};

        // When the size deficit can be compensated by dropping the body,
        // then the body should be trimmed for larger files first
        let bodies_trimmed = fit_patch_to_size(patch, patch.len() - 10);
        let expected_trimmed = indoc! {"
            --- a/dir/test.txt
            +++ b/dir/test.txt
            @@ -1,3 +1,3 @@
            [...skipped...]
            @@ -10,2 +10,2 @@
            [...skipped...]


            --- a/dir/another.txt
            +++ b/dir/another.txt
            @@ -100,1 +1,1 @@
            -before
            +after"};
        assert_eq!(bodies_trimmed, expected_trimmed);

        // When the size deficit is too large, then only file paths
        // should be returned
        let paths_only = fit_patch_to_size(patch, 10);
        let expected_paths = indoc! {"
            - dir/another.txt
            - dir/test.txt
            "};
        assert_eq!(paths_only, expected_paths);
    }

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            assistant_tool::init(cx);
        });
    }
}