Take a weak thread in EditFileTool to avoid cycle

Antonio Scandurra created

Change summary

crates/agent2/src/agent.rs                | 77 +++++++++++------------
crates/agent2/src/thread.rs               |  4 
crates/agent2/src/tools/edit_file_tool.rs | 83 +++++++++++++++---------
3 files changed, 89 insertions(+), 75 deletions(-)

Detailed changes

crates/agent2/src/agent.rs 🔗

@@ -574,7 +574,7 @@ impl NativeAgentConnection {
         thread.add_tool(CreateDirectoryTool::new(project.clone()));
         thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
         thread.add_tool(DiagnosticsTool::new(project.clone()));
-        thread.add_tool(EditFileTool::new(cx.entity()));
+        thread.add_tool(EditFileTool::new(cx.weak_entity()));
         thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
         thread.add_tool(FindPathTool::new(project.clone()));
         thread.add_tool(GrepTool::new(project.clone()));
@@ -801,7 +801,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
     fn load_thread(
         self: Rc<Self>,
         project: Entity<Project>,
-        cwd: &Path,
+        _cwd: &Path,
         session_id: acp::SessionId,
         cx: &mut App,
     ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
@@ -828,46 +828,43 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
             let agent = self.0.clone();
 
             // Create Thread
-            let thread = agent.update(
-                cx,
-                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
-                    let configured_model = LanguageModelRegistry::global(cx)
-                        .update(cx, |registry, cx| {
-                            db_thread
-                                .model
-                                .and_then(|model| {
-                                    let model = SelectedModel {
-                                        provider: model.provider.clone().into(),
-                                        model: model.model.clone().into(),
-                                    };
-                                    registry.select_model(&model, cx)
-                                })
-                                .or_else(|| registry.default_model())
-                        })
-                        .context("no default model configured")?;
+            let thread = agent.update(cx, |agent, cx| {
+                let configured_model = LanguageModelRegistry::global(cx)
+                    .update(cx, |registry, cx| {
+                        db_thread
+                            .model
+                            .and_then(|model| {
+                                let model = SelectedModel {
+                                    provider: model.provider.clone().into(),
+                                    model: model.model.clone().into(),
+                                };
+                                registry.select_model(&model, cx)
+                            })
+                            .or_else(|| registry.default_model())
+                    })
+                    .context("no default model configured")?;
 
-                    let model = agent
-                        .models
-                        .model_from_id(&LanguageModels::model_id(&configured_model.model))
-                        .context("no model by id")?;
+                let model = agent
+                    .models
+                    .model_from_id(&LanguageModels::model_id(&configured_model.model))
+                    .context("no model by id")?;
 
-                    let thread = cx.new(|cx| {
-                        let mut thread = Thread::new(
-                            project.clone(),
-                            agent.project_context.clone(),
-                            agent.context_server_registry.clone(),
-                            action_log.clone(),
-                            agent.templates.clone(),
-                            model,
-                            cx,
-                        );
-                        Self::register_tools(&mut thread, project, action_log, cx);
-                        thread
-                    });
+                let thread = cx.new(|cx| {
+                    let mut thread = Thread::new(
+                        project.clone(),
+                        agent.project_context.clone(),
+                        agent.context_server_registry.clone(),
+                        action_log.clone(),
+                        agent.templates.clone(),
+                        model,
+                        cx,
+                    );
+                    Self::register_tools(&mut thread, project, action_log, cx);
+                    thread
+                });
 
-                    Ok(thread)
-                },
-            )??;
+                anyhow::Ok(thread)
+            })??;
 
             // Store the session
             agent.update(cx, |agent, cx| {
@@ -884,7 +881,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
             })?;
 
             // we need to actually deserialize the DbThread.
-            todo!()
+            // todo!()
 
             Ok(acp_thread)
         })

crates/agent2/src/thread.rs 🔗

@@ -441,7 +441,7 @@ impl Thread {
         cx: &mut Context<Self>,
     ) -> Self {
         let profile_id = AgentSettings::get_global(cx).default_profile.clone();
-        let this = Self {
+        Self {
             messages: Vec::new(),
             completion_mode: CompletionMode::Normal,
             running_turn: None,
@@ -455,7 +455,7 @@ impl Thread {
             model,
             project,
             action_log,
-        };
+        }
     }
 
     pub fn project(&self) -> &Entity<Project> {

crates/agent2/src/tools/edit_file_tool.rs 🔗

@@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow};
 use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
 use cloud_llm_client::CompletionIntent;
 use collections::HashSet;
-use gpui::{App, AppContext, AsyncApp, Entity, Task};
+use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
 use indoc::formatdoc;
 use language::ToPoint;
 use language::language_settings::{self, FormatOnSave};
@@ -122,11 +122,11 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
 }
 
 pub struct EditFileTool {
-    thread: Entity<Thread>,
+    thread: WeakEntity<Thread>,
 }
 
 impl EditFileTool {
-    pub fn new(thread: Entity<Thread>) -> Self {
+    pub fn new(thread: WeakEntity<Thread>) -> Self {
         Self { thread }
     }
 
@@ -167,8 +167,11 @@ impl EditFileTool {
 
         // Check if path is inside the global config directory
         // First check if it's already inside project - if not, try to canonicalize
-        let thread = self.thread.read(cx);
-        let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
+        let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
+            thread.project().read(cx).find_project_path(&input.path, cx)
+        }) else {
+            return Task::ready(Err(anyhow!("thread was dropped")));
+        };
 
         // If the path is inside the project, and it's not one of the above edge cases,
         // then no confirmation is necessary. Otherwise, confirmation is necessary.
@@ -221,7 +224,12 @@ impl AgentTool for EditFileTool {
         event_stream: ToolCallEventStream,
         cx: &mut App,
     ) -> Task<Result<Self::Output>> {
-        let project = self.thread.read(cx).project().clone();
+        let Ok(project) = self
+            .thread
+            .read_with(cx, |thread, _cx| thread.project().clone())
+        else {
+            return Task::ready(Err(anyhow!("thread was dropped")));
+        };
         let project_path = match resolve_path(&input, project.clone(), cx) {
             Ok(path) => path,
             Err(err) => return Task::ready(Err(anyhow!(err))),
@@ -237,17 +245,15 @@ impl AgentTool for EditFileTool {
             });
         }
 
-        let request = self.thread.update(cx, |thread, cx| {
-            thread.build_completion_request(CompletionIntent::ToolResults, cx)
-        });
-        let thread = self.thread.read(cx);
-        let model = thread.model().clone();
-        let action_log = thread.action_log().clone();
-
         let authorize = self.authorize(&input, &event_stream, cx);
         cx.spawn(async move |cx: &mut AsyncApp| {
             authorize.await?;
 
+            let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
+                let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
+                (request, thread.model().clone(), thread.action_log().clone())
+            })?;
+
             let edit_format = EditFormat::from_model(model.clone())?;
             let edit_agent = EditAgent::new(
                 model,
@@ -531,7 +537,11 @@ mod tests {
                     path: "root/nonexistent_file.txt".into(),
                     mode: EditFileMode::Edit,
                 };
-                Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
+                Arc::new(EditFileTool::new(thread.downgrade())).run(
+                    input,
+                    ToolCallEventStream::test().0,
+                    cx,
+                )
             })
             .await;
         assert_eq!(
@@ -744,10 +754,11 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool {
-                    thread: thread.clone(),
-                })
-                .run(input, ToolCallEventStream::test().0, cx)
+                Arc::new(EditFileTool::new(thread.downgrade())).run(
+                    input,
+                    ToolCallEventStream::test().0,
+                    cx,
+                )
             });
 
             // Stream the unformatted content
@@ -800,7 +811,11 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
+                Arc::new(EditFileTool::new(thread.downgrade())).run(
+                    input,
+                    ToolCallEventStream::test().0,
+                    cx,
+                )
             });
 
             // Stream the unformatted content
@@ -881,10 +896,11 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool {
-                    thread: thread.clone(),
-                })
-                .run(input, ToolCallEventStream::test().0, cx)
+                Arc::new(EditFileTool::new(thread.downgrade())).run(
+                    input,
+                    ToolCallEventStream::test().0,
+                    cx,
+                )
             });
 
             // Stream the content with trailing whitespace
@@ -932,10 +948,11 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool {
-                    thread: thread.clone(),
-                })
-                .run(input, ToolCallEventStream::test().0, cx)
+                Arc::new(EditFileTool::new(thread.downgrade())).run(
+                    input,
+                    ToolCallEventStream::test().0,
+                    cx,
+                )
             });
 
             // Stream the content with trailing whitespace
@@ -983,7 +1000,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
         fs.insert_tree("/root", json!({})).await;
 
         // Test 1: Path with .zed component should require confirmation
@@ -1114,7 +1131,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
 
         // Test global config paths - these should require confirmation if they exist and are outside the project
         let test_cases = vec![
@@ -1224,7 +1241,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
 
         // Test files in different worktrees
         let test_cases = vec![
@@ -1305,7 +1322,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
 
         // Test edge cases
         let test_cases = vec![
@@ -1389,7 +1406,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
 
         // Test different EditFileMode values
         let modes = vec![
@@ -1470,7 +1487,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
 
         assert_eq!(
             tool.initial_title(Err(json!({