project_notifications_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::Result;
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{AnyWindowHandle, App, Entity, Task};
  5use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::fmt::Write as _;
 10use std::sync::Arc;
 11use ui::IconName;
 12
 13#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 14pub struct ProjectUpdatesToolInput {}
 15
 16pub struct ProjectNotificationsTool;
 17
 18impl Tool for ProjectNotificationsTool {
 19    fn name(&self) -> String {
 20        "project_notifications".to_string()
 21    }
 22
 23    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 24        false
 25    }
 26    fn may_perform_edits(&self) -> bool {
 27        false
 28    }
 29    fn description(&self) -> String {
 30        include_str!("./project_notifications_tool/description.md").to_string()
 31    }
 32
 33    fn icon(&self) -> IconName {
 34        IconName::Envelope
 35    }
 36
 37    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 38        json_schema_for::<ProjectUpdatesToolInput>(format)
 39    }
 40
 41    fn ui_text(&self, _input: &serde_json::Value) -> String {
 42        "Check project notifications".into()
 43    }
 44
 45    fn run(
 46        self: Arc<Self>,
 47        _input: serde_json::Value,
 48        _request: Arc<LanguageModelRequest>,
 49        _project: Entity<Project>,
 50        action_log: Entity<ActionLog>,
 51        _model: Arc<dyn LanguageModel>,
 52        _window: Option<AnyWindowHandle>,
 53        cx: &mut App,
 54    ) -> ToolResult {
 55        let mut stale_files = String::new();
 56
 57        let action_log = action_log.read(cx);
 58
 59        for stale_file in action_log.stale_buffers(cx) {
 60            if let Some(file) = stale_file.read(cx).file() {
 61                writeln!(&mut stale_files, "- {}", file.path().display()).ok();
 62            }
 63        }
 64
 65        let response = if stale_files.is_empty() {
 66            "No new notifications".to_string()
 67        } else {
 68            // NOTE: Changes to this prompt require a symmetric update in the LLM Worker
 69            const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
 70            format!("{HEADER}{stale_files}").replace("\r\n", "\n")
 71        };
 72
 73        Task::ready(Ok(response.into())).into()
 74    }
 75}
 76
 77#[cfg(test)]
 78mod tests {
 79    use super::*;
 80    use assistant_tool::ToolResultContent;
 81    use gpui::{AppContext, TestAppContext};
 82    use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider};
 83    use project::{FakeFs, Project};
 84    use serde_json::json;
 85    use settings::SettingsStore;
 86    use std::sync::Arc;
 87    use util::path;
 88
 89    #[gpui::test]
 90    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
 91        init_test(cx);
 92
 93        let fs = FakeFs::new(cx.executor());
 94        fs.insert_tree(
 95            path!("/test"),
 96            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
 97        )
 98        .await;
 99
100        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
101        let action_log = cx.new(|_| ActionLog::new(project.clone()));
102
103        let buffer_path = project
104            .read_with(cx, |project, cx| {
105                project.find_project_path("test/code.rs", cx)
106            })
107            .unwrap();
108
109        let buffer = project
110            .update(cx, |project, cx| {
111                project.open_buffer(buffer_path.clone(), cx)
112            })
113            .await
114            .unwrap();
115
116        // Start tracking the buffer
117        action_log.update(cx, |log, cx| {
118            log.buffer_read(buffer.clone(), cx);
119        });
120
121        // Run the tool before any changes
122        let tool = Arc::new(ProjectNotificationsTool);
123        let provider = Arc::new(FakeLanguageModelProvider);
124        let model: Arc<dyn LanguageModel> = Arc::new(provider.test_model());
125        let request = Arc::new(LanguageModelRequest::default());
126        let tool_input = json!({});
127
128        let result = cx.update(|cx| {
129            tool.clone().run(
130                tool_input.clone(),
131                request.clone(),
132                project.clone(),
133                action_log.clone(),
134                model.clone(),
135                None,
136                cx,
137            )
138        });
139
140        let response = result.output.await.unwrap();
141        let response_text = match &response.content {
142            ToolResultContent::Text(text) => text.clone(),
143            _ => panic!("Expected text response"),
144        };
145        assert_eq!(
146            response_text.as_str(),
147            "No new notifications",
148            "Tool should return 'No new notifications' when no stale buffers"
149        );
150
151        // Modify the buffer (makes it stale)
152        buffer.update(cx, |buffer, cx| {
153            buffer.edit([(1..1, "\nChange!\n")], None, cx);
154        });
155
156        // Run the tool again
157        let result = cx.update(|cx| {
158            tool.run(
159                tool_input.clone(),
160                request.clone(),
161                project.clone(),
162                action_log,
163                model.clone(),
164                None,
165                cx,
166            )
167        });
168
169        // This time the buffer is stale, so the tool should return a notification
170        let response = result.output.await.unwrap();
171        let response_text = match &response.content {
172            ToolResultContent::Text(text) => text.clone(),
173            _ => panic!("Expected text response"),
174        };
175
176        let expected_content = "[The following is an auto-generated notification; do not reply]\n\nThese files have changed since the last read:\n- code.rs\n";
177        assert_eq!(
178            response_text.as_str(),
179            expected_content,
180            "Tool should return the stale buffer notification"
181        );
182    }
183
184    fn init_test(cx: &mut TestAppContext) {
185        cx.update(|cx| {
186            let settings_store = SettingsStore::test(cx);
187            cx.set_global(settings_store);
188            language::init(cx);
189            Project::init_settings(cx);
190            assistant_tool::init(cx);
191        });
192    }
193}