edit_tool.rs

  1use acp_thread::AcpThread;
  2use anyhow::Result;
  3use context_server::{
  4    listener::{McpServerTool, ToolResponse},
  5    types::{ToolAnnotations, ToolResponseContent},
  6};
  7use gpui::{AsyncApp, WeakEntity};
  8use language::unified_diff;
  9use util::markdown::MarkdownCodeBlock;
 10
 11use crate::tools::EditToolParams;
 12
 13#[derive(Clone)]
 14pub struct EditTool {
 15    thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 16}
 17
 18impl EditTool {
 19    pub fn new(thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
 20        Self { thread_rx }
 21    }
 22}
 23
 24impl McpServerTool for EditTool {
 25    type Input = EditToolParams;
 26    type Output = ();
 27
 28    const NAME: &'static str = "Edit";
 29
 30    fn annotations(&self) -> ToolAnnotations {
 31        ToolAnnotations {
 32            title: Some("Edit file".to_string()),
 33            read_only_hint: Some(false),
 34            destructive_hint: Some(false),
 35            open_world_hint: Some(false),
 36            idempotent_hint: Some(false),
 37        }
 38    }
 39
 40    async fn run(
 41        &self,
 42        input: Self::Input,
 43        cx: &mut AsyncApp,
 44    ) -> Result<ToolResponse<Self::Output>> {
 45        let mut thread_rx = self.thread_rx.clone();
 46        let Some(thread) = thread_rx.recv().await?.upgrade() else {
 47            anyhow::bail!("Thread closed");
 48        };
 49
 50        let content = thread
 51            .update(cx, |thread, cx| {
 52                thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
 53            })?
 54            .await?;
 55
 56        let (new_content, diff) = cx
 57            .background_executor()
 58            .spawn(async move {
 59                let new_content = content.replace(&input.old_text, &input.new_text);
 60                if new_content == content {
 61                    return Err(anyhow::anyhow!("Failed to find `old_text`",));
 62                }
 63                let diff = unified_diff(&content, &new_content);
 64
 65                Ok((new_content, diff))
 66            })
 67            .await?;
 68
 69        thread
 70            .update(cx, |thread, cx| {
 71                thread.write_text_file(input.abs_path, new_content, cx)
 72            })?
 73            .await?;
 74
 75        Ok(ToolResponse {
 76            content: vec![ToolResponseContent::Text {
 77                text: MarkdownCodeBlock {
 78                    tag: "diff",
 79                    text: diff.as_str().trim_end_matches('\n'),
 80                }
 81                .to_string(),
 82            }],
 83            structured_content: (),
 84        })
 85    }
 86}
 87
 88#[cfg(test)]
 89mod tests {
 90    use std::rc::Rc;
 91
 92    use acp_thread::{AgentConnection, StubAgentConnection};
 93    use gpui::{Entity, TestAppContext};
 94    use indoc::indoc;
 95    use project::{FakeFs, Project};
 96    use serde_json::json;
 97    use settings::SettingsStore;
 98    use util::path;
 99
100    use super::*;
101
102    #[gpui::test]
103    async fn old_text_not_found(cx: &mut TestAppContext) {
104        let (_thread, tool) = init_test(cx).await;
105
106        let result = tool
107            .run(
108                EditToolParams {
109                    abs_path: path!("/root/file.txt").into(),
110                    old_text: "hi".into(),
111                    new_text: "bye".into(),
112                },
113                &mut cx.to_async(),
114            )
115            .await;
116
117        assert_eq!(result.unwrap_err().to_string(), "Failed to find `old_text`");
118    }
119
120    #[gpui::test]
121    async fn found_and_replaced(cx: &mut TestAppContext) {
122        let (_thread, tool) = init_test(cx).await;
123
124        let result = tool
125            .run(
126                EditToolParams {
127                    abs_path: path!("/root/file.txt").into(),
128                    old_text: "hello".into(),
129                    new_text: "hi".into(),
130                },
131                &mut cx.to_async(),
132            )
133            .await;
134
135        assert_eq!(
136            result.unwrap().content[0].text().unwrap(),
137            indoc! {
138                r"
139                ```diff
140                @@ -1,1 +1,1 @@
141                -hello
142                +hi
143                ```
144                "
145            }
146        );
147    }
148
149    async fn init_test(cx: &mut TestAppContext) -> (Entity<AcpThread>, EditTool) {
150        cx.update(|cx| {
151            let settings_store = SettingsStore::test(cx);
152            cx.set_global(settings_store);
153            language::init(cx);
154            Project::init_settings(cx);
155        });
156
157        let connection = Rc::new(StubAgentConnection::new());
158        let fs = FakeFs::new(cx.executor());
159        fs.insert_tree(
160            path!("/root"),
161            json!({
162                "file.txt": "hello"
163            }),
164        )
165        .await;
166        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
167        let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
168
169        let thread = cx
170            .update(|cx| connection.new_thread(project, path!("/test").as_ref(), cx))
171            .await
172            .unwrap();
173
174        thread_tx.send(thread.downgrade()).unwrap();
175
176        (thread, EditTool::new(thread_rx))
177    }
178}