project_notifications_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::Result;
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{AnyWindowHandle, App, Entity, Task};
  5use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::sync::Arc;
 10use ui::IconName;
 11
 12#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 13pub struct ProjectUpdatesToolInput {}
 14
 15pub struct ProjectNotificationsTool;
 16
 17impl Tool for ProjectNotificationsTool {
 18    fn name(&self) -> String {
 19        "project_notifications".to_string()
 20    }
 21
 22    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 23        false
 24    }
 25    fn may_perform_edits(&self) -> bool {
 26        false
 27    }
 28    fn description(&self) -> String {
 29        include_str!("./project_notifications_tool/description.md").to_string()
 30    }
 31
 32    fn icon(&self) -> IconName {
 33        IconName::ToolNotification
 34    }
 35
 36    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 37        json_schema_for::<ProjectUpdatesToolInput>(format)
 38    }
 39
 40    fn ui_text(&self, _input: &serde_json::Value) -> String {
 41        "Check project notifications".into()
 42    }
 43
 44    fn run(
 45        self: Arc<Self>,
 46        _input: serde_json::Value,
 47        _request: Arc<LanguageModelRequest>,
 48        _project: Entity<Project>,
 49        action_log: Entity<ActionLog>,
 50        _model: Arc<dyn LanguageModel>,
 51        _window: Option<AnyWindowHandle>,
 52        cx: &mut App,
 53    ) -> ToolResult {
 54        let Some(user_edits_diff) =
 55            action_log.update(cx, |log, cx| log.flush_unnotified_user_edits(cx))
 56        else {
 57            return result("No new notifications");
 58        };
 59
 60        // NOTE: Changes to this prompt require a symmetric update in the LLM Worker
 61        const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
 62        result(&format!("{HEADER}\n\n```diff\n{user_edits_diff}\n```\n").replace("\r\n", "\n"))
 63    }
 64}
 65
 66fn result(response: &str) -> ToolResult {
 67    Task::ready(Ok(response.to_string().into())).into()
 68}
 69
 70#[cfg(test)]
 71mod tests {
 72    use super::*;
 73    use assistant_tool::ToolResultContent;
 74    use gpui::{AppContext, TestAppContext};
 75    use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider};
 76    use project::{FakeFs, Project};
 77    use serde_json::json;
 78    use settings::SettingsStore;
 79    use std::sync::Arc;
 80    use util::path;
 81
 82    #[gpui::test]
 83    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
 84        init_test(cx);
 85
 86        let fs = FakeFs::new(cx.executor());
 87        fs.insert_tree(
 88            path!("/test"),
 89            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
 90        )
 91        .await;
 92
 93        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
 94        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 95
 96        let buffer_path = project
 97            .read_with(cx, |project, cx| {
 98                project.find_project_path("test/code.rs", cx)
 99            })
100            .unwrap();
101
102        let buffer = project
103            .update(cx, |project, cx| {
104                project.open_buffer(buffer_path.clone(), cx)
105            })
106            .await
107            .unwrap();
108
109        // Start tracking the buffer
110        action_log.update(cx, |log, cx| {
111            log.buffer_read(buffer.clone(), cx);
112        });
113        cx.run_until_parked();
114
115        // Run the tool before any changes
116        let tool = Arc::new(ProjectNotificationsTool);
117        let provider = Arc::new(FakeLanguageModelProvider);
118        let model: Arc<dyn LanguageModel> = Arc::new(provider.test_model());
119        let request = Arc::new(LanguageModelRequest::default());
120        let tool_input = json!({});
121
122        let result = cx.update(|cx| {
123            tool.clone().run(
124                tool_input.clone(),
125                request.clone(),
126                project.clone(),
127                action_log.clone(),
128                model.clone(),
129                None,
130                cx,
131            )
132        });
133        cx.run_until_parked();
134
135        let response = result.output.await.unwrap();
136        let response_text = match &response.content {
137            ToolResultContent::Text(text) => text.clone(),
138            _ => panic!("Expected text response"),
139        };
140        assert_eq!(
141            response_text.as_str(),
142            "No new notifications",
143            "Tool should return 'No new notifications' when no stale buffers"
144        );
145
146        // Modify the buffer (makes it stale)
147        buffer.update(cx, |buffer, cx| {
148            buffer.edit([(1..1, "\nChange!\n")], None, cx);
149        });
150        cx.run_until_parked();
151
152        // Run the tool again
153        let result = cx.update(|cx| {
154            tool.clone().run(
155                tool_input.clone(),
156                request.clone(),
157                project.clone(),
158                action_log.clone(),
159                model.clone(),
160                None,
161                cx,
162            )
163        });
164        cx.run_until_parked();
165
166        // This time the buffer is stale, so the tool should return a notification
167        let response = result.output.await.unwrap();
168        let response_text = match &response.content {
169            ToolResultContent::Text(text) => text.clone(),
170            _ => panic!("Expected text response"),
171        };
172
173        assert!(
174            response_text.contains("These files have changed"),
175            "Tool should return the stale buffer notification"
176        );
177        assert!(
178            response_text.contains("test/code.rs"),
179            "Tool should return the stale buffer notification"
180        );
181
182        // Run the tool once more without any changes - should get no new notifications
183        let result = cx.update(|cx| {
184            tool.run(
185                tool_input.clone(),
186                request.clone(),
187                project.clone(),
188                action_log,
189                model.clone(),
190                None,
191                cx,
192            )
193        });
194        cx.run_until_parked();
195
196        let response = result.output.await.unwrap();
197        let response_text = match &response.content {
198            ToolResultContent::Text(text) => text.clone(),
199            _ => panic!("Expected text response"),
200        };
201
202        assert_eq!(
203            response_text.as_str(),
204            "No new notifications",
205            "Tool should return 'No new notifications' when running again without changes"
206        );
207    }
208
209    fn init_test(cx: &mut TestAppContext) {
210        cx.update(|cx| {
211            let settings_store = SettingsStore::test(cx);
212            cx.set_global(settings_store);
213            language::init(cx);
214            Project::init_settings(cx);
215            assistant_tool::init(cx);
216        });
217    }
218}