diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 65f6e38c56c4cb0295a196756bda210b20445ee6..074dffe1dc7db8083e01367ae17e9a1accc555f5 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -574,7 +574,7 @@ impl NativeAgentConnection { thread.add_tool(CreateDirectoryTool::new(project.clone())); thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); thread.add_tool(DiagnosticsTool::new(project.clone())); - thread.add_tool(EditFileTool::new(cx.entity())); + thread.add_tool(EditFileTool::new(cx.weak_entity())); thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); thread.add_tool(FindPathTool::new(project.clone())); thread.add_tool(GrepTool::new(project.clone())); @@ -801,7 +801,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { fn load_thread( self: Rc, project: Entity, - cwd: &Path, + _cwd: &Path, session_id: acp::SessionId, cx: &mut App, ) -> Task>> { @@ -828,46 +828,43 @@ impl acp_thread::AgentConnection for NativeAgentConnection { let agent = self.0.clone(); // Create Thread - let thread = agent.update( - cx, - |agent, cx: &mut gpui::Context| -> Result<_> { - let configured_model = LanguageModelRegistry::global(cx) - .update(cx, |registry, cx| { - db_thread - .model - .and_then(|model| { - let model = SelectedModel { - provider: model.provider.clone().into(), - model: model.model.clone().into(), - }; - registry.select_model(&model, cx) - }) - .or_else(|| registry.default_model()) - }) - .context("no default model configured")?; + let thread = agent.update(cx, |agent, cx| { + let configured_model = LanguageModelRegistry::global(cx) + .update(cx, |registry, cx| { + db_thread + .model + .and_then(|model| { + let model = SelectedModel { + provider: model.provider.clone().into(), + model: model.model.clone().into(), + }; + registry.select_model(&model, cx) + }) + .or_else(|| registry.default_model()) + }) + .context("no default model configured")?; - let model = agent - .models - .model_from_id(&LanguageModels::model_id(&configured_model.model)) - .context("no model by id")?; + let model = agent + .models + .model_from_id(&LanguageModels::model_id(&configured_model.model)) + .context("no model by id")?; - let thread = cx.new(|cx| { - let mut thread = Thread::new( - project.clone(), - agent.project_context.clone(), - agent.context_server_registry.clone(), - action_log.clone(), - agent.templates.clone(), - model, - cx, - ); - Self::register_tools(&mut thread, project, action_log, cx); - thread - }); + let thread = cx.new(|cx| { + let mut thread = Thread::new( + project.clone(), + agent.project_context.clone(), + agent.context_server_registry.clone(), + action_log.clone(), + agent.templates.clone(), + model, + cx, + ); + Self::register_tools(&mut thread, project, action_log, cx); + thread + }); - Ok(thread) - }, - )??; + anyhow::Ok(thread) + })??; // Store the session agent.update(cx, |agent, cx| { @@ -884,7 +881,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { })?; // we need to actually deserialize the DbThread. - todo!() + // todo!() Ok(acp_thread) }) diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 33f0208ab72fcbcce3417b06df8f0c65e32c9af4..784477d677a2f4af78b39757c059e471c453e2ec 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -441,7 +441,7 @@ impl Thread { cx: &mut Context, ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); - let this = Self { + Self { messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, @@ -455,7 +455,7 @@ impl Thread { model, project, action_log, - }; + } } pub fn project(&self) -> &Entity { diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index c77b9f6a69bededaa632333b40c85d73bf4e8a92..6462308918dfd299f64d3f676a6552d6c6cc9c11 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; use cloud_llm_client::CompletionIntent; use collections::HashSet; -use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use indoc::formatdoc; use language::ToPoint; use language::language_settings::{self, FormatOnSave}; @@ -122,11 +122,11 @@ impl From for LanguageModelToolResultContent { } pub struct EditFileTool { - thread: Entity, + thread: WeakEntity, } impl EditFileTool { - pub fn new(thread: Entity) -> Self { + pub fn new(thread: WeakEntity) -> Self { Self { thread } } @@ -167,8 +167,11 @@ impl EditFileTool { // Check if path is inside the global config directory // First check if it's already inside project - if not, try to canonicalize - let thread = self.thread.read(cx); - let project_path = thread.project().read(cx).find_project_path(&input.path, cx); + let Ok(project_path) = self.thread.read_with(cx, |thread, cx| { + thread.project().read(cx).find_project_path(&input.path, cx) + }) else { + return Task::ready(Err(anyhow!("thread was dropped"))); + }; // If the path is inside the project, and it's not one of the above edge cases, // then no confirmation is necessary. Otherwise, confirmation is necessary. @@ -221,7 +224,12 @@ impl AgentTool for EditFileTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let project = self.thread.read(cx).project().clone(); + let Ok(project) = self + .thread + .read_with(cx, |thread, _cx| thread.project().clone()) + else { + return Task::ready(Err(anyhow!("thread was dropped"))); + }; let project_path = match resolve_path(&input, project.clone(), cx) { Ok(path) => path, Err(err) => return Task::ready(Err(anyhow!(err))), @@ -237,17 +245,15 @@ impl AgentTool for EditFileTool { }); } - let request = self.thread.update(cx, |thread, cx| { - thread.build_completion_request(CompletionIntent::ToolResults, cx) - }); - let thread = self.thread.read(cx); - let model = thread.model().clone(); - let action_log = thread.action_log().clone(); - let authorize = self.authorize(&input, &event_stream, cx); cx.spawn(async move |cx: &mut AsyncApp| { authorize.await?; + let (request, model, action_log) = self.thread.update(cx, |thread, cx| { + let request = thread.build_completion_request(CompletionIntent::ToolResults, cx); + (request, thread.model().clone(), thread.action_log().clone()) + })?; + let edit_format = EditFormat::from_model(model.clone())?; let edit_agent = EditAgent::new( model, @@ -531,7 +537,11 @@ mod tests { path: "root/nonexistent_file.txt".into(), mode: EditFileMode::Edit, }; - Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade())).run( + input, + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!( @@ -744,10 +754,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) - .run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade())).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the unformatted content @@ -800,7 +811,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade())).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the unformatted content @@ -881,10 +896,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) - .run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade())).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the content with trailing whitespace @@ -932,10 +948,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) - .run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade())).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the content with trailing whitespace @@ -983,7 +1000,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); fs.insert_tree("/root", json!({})).await; // Test 1: Path with .zed component should require confirmation @@ -1114,7 +1131,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); // Test global config paths - these should require confirmation if they exist and are outside the project let test_cases = vec![ @@ -1224,7 +1241,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); // Test files in different worktrees let test_cases = vec![ @@ -1305,7 +1322,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); // Test edge cases let test_cases = vec![ @@ -1389,7 +1406,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); // Test different EditFileMode values let modes = vec![ @@ -1470,7 +1487,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); assert_eq!( tool.initial_title(Err(json!({