@@ -1,6 +1,7 @@
use std::rc::Rc;
use acp_thread::{AgentConnection, LoadError};
+use agent::ThreadStore;
use agent_servers::{AgentServer, AgentServerDelegate};
use anyhow::Result;
use collections::HashMap;
@@ -53,16 +54,22 @@ impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
pub struct AgentConnectionStore {
project: Entity<Project>,
+ thread_store: Entity<ThreadStore>,
entries: HashMap<ExternalAgent, Entity<AgentConnectionEntry>>,
_subscriptions: Vec<Subscription>,
}
impl AgentConnectionStore {
- pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
+ pub fn new(
+ project: Entity<Project>,
+ thread_store: Entity<ThreadStore>,
+ cx: &mut Context<Self>,
+ ) -> Self {
let agent_server_store = project.read(cx).agent_server_store().clone();
let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
Self {
project,
+ thread_store,
entries: HashMap::default(),
_subscriptions: vec![subscription],
}
@@ -74,11 +81,15 @@ impl AgentConnectionStore {
pub fn request_connection(
&mut self,
- key: ExternalAgent,
- server: Rc<dyn AgentServer>,
+ agent: ExternalAgent,
cx: &mut Context<Self>,
) -> Entity<AgentConnectionEntry> {
- self.entries.get(&key).cloned().unwrap_or_else(|| {
+ self.entries.get(&agent).cloned().unwrap_or_else(|| {
+ let server = agent.server(
+ self.project.read(cx).fs().clone(),
+ self.thread_store.clone(),
+ );
+
let (mut new_version_rx, connect_task) = self.start_connection(server.clone(), cx);
let connect_task = connect_task.shared();
@@ -86,10 +97,10 @@ impl AgentConnectionStore {
connect_task: connect_task.clone(),
});
- self.entries.insert(key.clone(), entry.clone());
+ self.entries.insert(agent.clone(), entry.clone());
cx.spawn({
- let key = key.clone();
+ let key = agent.clone();
let entry = entry.clone();
async move |this, cx| match connect_task.await {
Ok(connected_state) => {
@@ -123,7 +134,8 @@ impl AgentConnectionStore {
version.clone().into(),
));
});
- this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
+ this.update(cx, |this, _cx| this.entries.remove(&agent))
+ .ok();
}
}
}
@@ -1188,7 +1188,8 @@ impl AgentPanel {
language_registry,
text_thread_store,
prompt_store,
- connection_store: cx.new(|cx| AgentConnectionStore::new(project.clone(), cx)),
+ connection_store: cx
+ .new(|cx| AgentConnectionStore::new(project.clone(), thread_store.clone(), cx)),
configuration: None,
configuration_subscription: None,
focus_handle: cx.focus_handle(),
@@ -1337,12 +1338,11 @@ impl AgentPanel {
) {
let agent = ExternalAgent::NativeAgent;
- let server = agent.server(self.fs.clone(), self.thread_store.clone());
let session_id = action.from_session_id.clone();
- let entry = self.connection_store.update(cx, |store, cx| {
- store.request_connection(agent.clone(), server, cx)
- });
+ let entry = self
+ .connection_store
+ .update(cx, |store, cx| store.request_connection(agent.clone(), cx));
let connect_task = entry.read(cx).wait_for_connection();
cx.spawn_in(window, async move |this, cx| {
@@ -654,9 +654,8 @@ impl ConnectionView {
.or_else(|| worktree_roots.first().cloned())
.unwrap_or_else(|| paths::home_dir().as_path().into());
- let connection_entry = connection_store.update(cx, |store, cx| {
- store.request_connection(connection_key, agent.clone(), cx)
- });
+ let connection_entry =
+ connection_store.update(cx, |store, cx| store.request_connection(connection_key, cx));
let connection_entry_subscription =
cx.subscribe(&connection_entry, |this, _entry, event, cx| match event {
@@ -2910,8 +2909,9 @@ pub(crate) mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
- let connection_store =
- cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project.clone(), cx)));
+ let connection_store = cx.update(|_window, cx| {
+ cx.new(|cx| AgentConnectionStore::new(project.clone(), thread_store.clone(), cx))
+ });
let thread_view = cx.update(|window, cx| {
cx.new(|cx| {
@@ -3022,8 +3022,9 @@ pub(crate) mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
- let connection_store =
- cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project.clone(), cx)));
+ let connection_store = cx.update(|_window, cx| {
+ cx.new(|cx| AgentConnectionStore::new(project.clone(), thread_store.clone(), cx))
+ });
let thread_view = cx.update(|window, cx| {
cx.new(|cx| {
@@ -3079,8 +3080,9 @@ pub(crate) mod tests {
let captured_cwd = connection.captured_cwd.clone();
let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
- let connection_store =
- cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project.clone(), cx)));
+ let connection_store = cx.update(|_window, cx| {
+ cx.new(|cx| AgentConnectionStore::new(project.clone(), thread_store.clone(), cx))
+ });
let _thread_view = cx.update(|window, cx| {
cx.new(|cx| {
@@ -3134,8 +3136,9 @@ pub(crate) mod tests {
let captured_cwd = connection.captured_cwd.clone();
let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
- let connection_store =
- cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project.clone(), cx)));
+ let connection_store = cx.update(|_window, cx| {
+ cx.new(|cx| AgentConnectionStore::new(project.clone(), thread_store.clone(), cx))
+ });
let _thread_view = cx.update(|window, cx| {
cx.new(|cx| {
@@ -3189,8 +3192,9 @@ pub(crate) mod tests {
let captured_cwd = connection.captured_cwd.clone();
let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
- let connection_store =
- cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project.clone(), cx)));
+ let connection_store = cx.update(|_window, cx| {
+ cx.new(|cx| AgentConnectionStore::new(project.clone(), thread_store.clone(), cx))
+ });
let _thread_view = cx.update(|window, cx| {
cx.new(|cx| {
@@ -3505,8 +3509,9 @@ pub(crate) mod tests {
// Set up thread view in workspace 1
let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
- let connection_store =
- cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project1.clone(), cx)));
+ let connection_store = cx.update(|_window, cx| {
+ cx.new(|cx| AgentConnectionStore::new(project1.clone(), thread_store.clone(), cx))
+ });
let agent = StubAgentServer::default_response();
let thread_view = cx.update(|window, cx| {
@@ -3726,8 +3731,9 @@ pub(crate) mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
- let connection_store =
- cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project.clone(), cx)));
+ let connection_store = cx.update(|_window, cx| {
+ cx.new(|cx| AgentConnectionStore::new(project.clone(), thread_store.clone(), cx))
+ });
let agent_key = ExternalAgent::Custom {
name: "Test".into(),
@@ -4470,8 +4476,9 @@ pub(crate) mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
- let connection_store =
- cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project.clone(), cx)));
+ let connection_store = cx.update(|_window, cx| {
+ cx.new(|cx| AgentConnectionStore::new(project.clone(), thread_store.clone(), cx))
+ });
let connection = Rc::new(StubAgentConnection::new());
let thread_view = cx.update(|window, cx| {