@@ -574,7 +574,7 @@ impl NativeAgentConnection {
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
- thread.add_tool(EditFileTool::new(cx.entity()));
+ thread.add_tool(EditFileTool::new(cx.weak_entity()));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(GrepTool::new(project.clone()));
@@ -801,7 +801,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn load_thread(
self: Rc<Self>,
project: Entity<Project>,
- cwd: &Path,
+ _cwd: &Path,
session_id: acp::SessionId,
cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
@@ -828,46 +828,43 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
let agent = self.0.clone();
// Create Thread
- let thread = agent.update(
- cx,
- |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
- let configured_model = LanguageModelRegistry::global(cx)
- .update(cx, |registry, cx| {
- db_thread
- .model
- .and_then(|model| {
- let model = SelectedModel {
- provider: model.provider.clone().into(),
- model: model.model.clone().into(),
- };
- registry.select_model(&model, cx)
- })
- .or_else(|| registry.default_model())
- })
- .context("no default model configured")?;
+ let thread = agent.update(cx, |agent, cx| {
+ let configured_model = LanguageModelRegistry::global(cx)
+ .update(cx, |registry, cx| {
+ db_thread
+ .model
+ .and_then(|model| {
+ let model = SelectedModel {
+ provider: model.provider.clone().into(),
+ model: model.model.clone().into(),
+ };
+ registry.select_model(&model, cx)
+ })
+ .or_else(|| registry.default_model())
+ })
+ .context("no default model configured")?;
- let model = agent
- .models
- .model_from_id(&LanguageModels::model_id(&configured_model.model))
- .context("no model by id")?;
+ let model = agent
+ .models
+ .model_from_id(&LanguageModels::model_id(&configured_model.model))
+ .context("no model by id")?;
- let thread = cx.new(|cx| {
- let mut thread = Thread::new(
- project.clone(),
- agent.project_context.clone(),
- agent.context_server_registry.clone(),
- action_log.clone(),
- agent.templates.clone(),
- model,
- cx,
- );
- Self::register_tools(&mut thread, project, action_log, cx);
- thread
- });
+ let thread = cx.new(|cx| {
+ let mut thread = Thread::new(
+ project.clone(),
+ agent.project_context.clone(),
+ agent.context_server_registry.clone(),
+ action_log.clone(),
+ agent.templates.clone(),
+ model,
+ cx,
+ );
+ Self::register_tools(&mut thread, project, action_log, cx);
+ thread
+ });
- Ok(thread)
- },
- )??;
+ anyhow::Ok(thread)
+ })??;
// Store the session
agent.update(cx, |agent, cx| {
@@ -884,7 +881,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
})?;
// we need to actually deserialize the DbThread.
- todo!()
+ // todo!()
Ok(acp_thread)
})
@@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
-use gpui::{App, AppContext, AsyncApp, Entity, Task};
+use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use indoc::formatdoc;
use language::ToPoint;
use language::language_settings::{self, FormatOnSave};
@@ -122,11 +122,11 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
}
pub struct EditFileTool {
- thread: Entity<Thread>,
+ thread: WeakEntity<Thread>,
}
impl EditFileTool {
- pub fn new(thread: Entity<Thread>) -> Self {
+ pub fn new(thread: WeakEntity<Thread>) -> Self {
Self { thread }
}
@@ -167,8 +167,11 @@ impl EditFileTool {
// Check if path is inside the global config directory
// First check if it's already inside project - if not, try to canonicalize
- let thread = self.thread.read(cx);
- let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
+ let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
+ thread.project().read(cx).find_project_path(&input.path, cx)
+ }) else {
+ return Task::ready(Err(anyhow!("thread was dropped")));
+ };
// If the path is inside the project, and it's not one of the above edge cases,
// then no confirmation is necessary. Otherwise, confirmation is necessary.
@@ -221,7 +224,12 @@ impl AgentTool for EditFileTool {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
- let project = self.thread.read(cx).project().clone();
+ let Ok(project) = self
+ .thread
+ .read_with(cx, |thread, _cx| thread.project().clone())
+ else {
+ return Task::ready(Err(anyhow!("thread was dropped")));
+ };
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))),
@@ -237,17 +245,15 @@ impl AgentTool for EditFileTool {
});
}
- let request = self.thread.update(cx, |thread, cx| {
- thread.build_completion_request(CompletionIntent::ToolResults, cx)
- });
- let thread = self.thread.read(cx);
- let model = thread.model().clone();
- let action_log = thread.action_log().clone();
-
let authorize = self.authorize(&input, &event_stream, cx);
cx.spawn(async move |cx: &mut AsyncApp| {
authorize.await?;
+ let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
+ let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
+ (request, thread.model().clone(), thread.action_log().clone())
+ })?;
+
let edit_format = EditFormat::from_model(model.clone())?;
let edit_agent = EditAgent::new(
model,
@@ -531,7 +537,11 @@ mod tests {
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
- Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
+ Arc::new(EditFileTool::new(thread.downgrade())).run(
+ input,
+ ToolCallEventStream::test().0,
+ cx,
+ )
})
.await;
assert_eq!(
@@ -744,10 +754,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool {
- thread: thread.clone(),
- })
- .run(input, ToolCallEventStream::test().0, cx)
+ Arc::new(EditFileTool::new(thread.downgrade())).run(
+ input,
+ ToolCallEventStream::test().0,
+ cx,
+ )
});
// Stream the unformatted content
@@ -800,7 +811,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
+ Arc::new(EditFileTool::new(thread.downgrade())).run(
+ input,
+ ToolCallEventStream::test().0,
+ cx,
+ )
});
// Stream the unformatted content
@@ -881,10 +896,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool {
- thread: thread.clone(),
- })
- .run(input, ToolCallEventStream::test().0, cx)
+ Arc::new(EditFileTool::new(thread.downgrade())).run(
+ input,
+ ToolCallEventStream::test().0,
+ cx,
+ )
});
// Stream the content with trailing whitespace
@@ -932,10 +948,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool {
- thread: thread.clone(),
- })
- .run(input, ToolCallEventStream::test().0, cx)
+ Arc::new(EditFileTool::new(thread.downgrade())).run(
+ input,
+ ToolCallEventStream::test().0,
+ cx,
+ )
});
// Stream the content with trailing whitespace
@@ -983,7 +1000,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade()));
fs.insert_tree("/root", json!({})).await;
// Test 1: Path with .zed component should require confirmation
@@ -1114,7 +1131,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade()));
// Test global config paths - these should require confirmation if they exist and are outside the project
let test_cases = vec![
@@ -1224,7 +1241,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade()));
// Test files in different worktrees
let test_cases = vec![
@@ -1305,7 +1322,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade()));
// Test edge cases
let test_cases = vec![
@@ -1389,7 +1406,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade()));
// Test different EditFileMode values
let modes = vec![
@@ -1470,7 +1487,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade()));
assert_eq!(
tool.initial_title(Err(json!({