1use acp_thread::AcpThread;
2use anyhow::Result;
3use context_server::{
4 listener::{McpServerTool, ToolResponse},
5 types::{ToolAnnotations, ToolResponseContent},
6};
7use gpui::{AsyncApp, WeakEntity};
8use language::unified_diff;
9use util::markdown::MarkdownCodeBlock;
10
11use crate::tools::EditToolParams;
12
13#[derive(Clone)]
14pub struct EditTool {
15 thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
16}
17
18impl EditTool {
19 pub fn new(thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
20 Self { thread_rx }
21 }
22}
23
24impl McpServerTool for EditTool {
25 type Input = EditToolParams;
26 type Output = ();
27
28 const NAME: &'static str = "Edit";
29
30 fn annotations(&self) -> ToolAnnotations {
31 ToolAnnotations {
32 title: Some("Edit file".to_string()),
33 read_only_hint: Some(false),
34 destructive_hint: Some(false),
35 open_world_hint: Some(false),
36 idempotent_hint: Some(false),
37 }
38 }
39
40 async fn run(
41 &self,
42 input: Self::Input,
43 cx: &mut AsyncApp,
44 ) -> Result<ToolResponse<Self::Output>> {
45 let mut thread_rx = self.thread_rx.clone();
46 let Some(thread) = thread_rx.recv().await?.upgrade() else {
47 anyhow::bail!("Thread closed");
48 };
49
50 let content = thread
51 .update(cx, |thread, cx| {
52 thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
53 })?
54 .await?;
55
56 let (new_content, diff) = cx
57 .background_executor()
58 .spawn(async move {
59 let new_content = content.replace(&input.old_text, &input.new_text);
60 if new_content == content {
61 return Err(anyhow::anyhow!("Failed to find `old_text`",));
62 }
63 let diff = unified_diff(&content, &new_content);
64
65 Ok((new_content, diff))
66 })
67 .await?;
68
69 thread
70 .update(cx, |thread, cx| {
71 thread.write_text_file(input.abs_path, new_content, cx)
72 })?
73 .await?;
74
75 Ok(ToolResponse {
76 content: vec![ToolResponseContent::Text {
77 text: MarkdownCodeBlock {
78 tag: "diff",
79 text: diff.as_str().trim_end_matches('\n'),
80 }
81 .to_string(),
82 }],
83 structured_content: (),
84 })
85 }
86}
87
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use acp_thread::{AgentConnection, StubAgentConnection};
    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use super::*;

    /// Builds edit parameters aimed at the fixture file that `init_test` creates.
    fn edit_params(old_text: &str, new_text: &str) -> EditToolParams {
        EditToolParams {
            abs_path: path!("/root/file.txt").into(),
            old_text: old_text.into(),
            new_text: new_text.into(),
        }
    }

    #[gpui::test]
    async fn old_text_not_found(cx: &mut TestAppContext) {
        let (_thread, tool) = init_test(cx).await;

        let error = tool
            .run(edit_params("hi", "bye"), &mut cx.to_async())
            .await
            .unwrap_err();

        assert_eq!(error.to_string(), "Failed to find `old_text`");
    }

    #[gpui::test]
    async fn found_and_replaced(cx: &mut TestAppContext) {
        let (_thread, tool) = init_test(cx).await;

        let response = tool
            .run(edit_params("hello", "hi"), &mut cx.to_async())
            .await
            .unwrap();

        let expected = indoc! {
            r"
            ```diff
            @@ -1,1 +1,1 @@
            -hello
            +hi
            ```
            "
        };
        assert_eq!(response.content[0].text().unwrap(), expected);
    }

    /// Prepares a fake project whose `/root/file.txt` contains `hello`, opens a stub thread on
    /// it, and returns that thread together with an `EditTool` already pointed at it.
    async fn init_test(cx: &mut TestAppContext) -> (Entity<AcpThread>, EditTool) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/root"), json!({ "file.txt": "hello" }))
            .await;
        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;

        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, path!("/test").as_ref(), cx))
            .await
            .unwrap();

        let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
        thread_tx.send(thread.downgrade()).unwrap();

        (thread, EditTool::new(thread_rx))
    }
}