@@ -37,7 +37,8 @@ use futures::channel::{mpsc, oneshot};
use futures::future::Shared;
use futures::{FutureExt as _, StreamExt as _, future};
use gpui::{
- App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
+ App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
+ WeakEntity,
};
use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
use project::{Project, ProjectItem, ProjectPath, Worktree};
@@ -65,12 +66,22 @@ pub struct RulesLoadingError {
pub message: SharedString,
}
+struct ProjectState {
+ project: Entity<Project>,
+ project_context: Entity<ProjectContext>,
+ project_context_needs_refresh: watch::Sender<()>,
+ _maintain_project_context: Task<Result<()>>,
+ context_server_registry: Entity<ContextServerRegistry>,
+ _subscriptions: Vec<Subscription>,
+}
+
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
/// The internal thread that processes messages
thread: Entity<Thread>,
/// The ACP thread that handles protocol communication
acp_thread: Entity<acp_thread::AcpThread>,
+ project_id: EntityId,
pending_save: Task<()>,
_subscriptions: Vec<Subscription>,
}
@@ -235,79 +246,47 @@ pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
thread_store: Entity<ThreadStore>,
- /// Shared project context for all threads
- project_context: Entity<ProjectContext>,
- project_context_needs_refresh: watch::Sender<()>,
- _maintain_project_context: Task<Result<()>>,
- context_server_registry: Entity<ContextServerRegistry>,
+ /// Project-specific state keyed by project EntityId
+ projects: HashMap<EntityId, ProjectState>,
/// Shared templates for all threads
templates: Arc<Templates>,
/// Cached model information
models: LanguageModels,
- project: Entity<Project>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
_subscriptions: Vec<Subscription>,
}
impl NativeAgent {
- pub async fn new(
- project: Entity<Project>,
+ pub fn new(
thread_store: Entity<ThreadStore>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
- cx: &mut AsyncApp,
- ) -> Result<Entity<NativeAgent>> {
+ cx: &mut App,
+ ) -> Entity<NativeAgent> {
log::debug!("Creating new NativeAgent");
- let project_context = cx
- .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
- .await;
-
- Ok(cx.new(|cx| {
- let context_server_store = project.read(cx).context_server_store();
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
-
- let mut subscriptions = vec![
- cx.subscribe(&project, Self::handle_project_event),
- cx.subscribe(
- &LanguageModelRegistry::global(cx),
- Self::handle_models_updated_event,
- ),
- cx.subscribe(
- &context_server_store,
- Self::handle_context_server_store_updated,
- ),
- cx.subscribe(
- &context_server_registry,
- Self::handle_context_server_registry_event,
- ),
- ];
+ cx.new(|cx| {
+ let mut subscriptions = vec![cx.subscribe(
+ &LanguageModelRegistry::global(cx),
+ Self::handle_models_updated_event,
+ )];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
}
- let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
- watch::channel(());
Self {
sessions: HashMap::default(),
thread_store,
- project_context: cx.new(|_| project_context),
- project_context_needs_refresh: project_context_needs_refresh_tx,
- _maintain_project_context: cx.spawn(async move |this, cx| {
- Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
- }),
- context_server_registry,
+ projects: HashMap::default(),
templates,
models: LanguageModels::new(cx),
- project,
prompt_store,
fs,
_subscriptions: subscriptions,
}
- }))
+ })
}
fn new_session(
@@ -315,10 +294,10 @@ impl NativeAgent {
project: Entity<Project>,
cx: &mut Context<Self>,
) -> Entity<AcpThread> {
- // Create Thread
- // Fetch default model from registry settings
+ let project_id = self.get_or_create_project_state(&project, cx);
+ let project_state = &self.projects[&project_id];
+
let registry = LanguageModelRegistry::read_global(cx);
- // Log available models for debugging
let available_count = registry.available_models(cx).count();
log::debug!("Total available models: {}", available_count);
@@ -328,21 +307,22 @@ impl NativeAgent {
});
let thread = cx.new(|cx| {
Thread::new(
- project.clone(),
- self.project_context.clone(),
- self.context_server_registry.clone(),
+ project,
+ project_state.project_context.clone(),
+ project_state.context_server_registry.clone(),
self.templates.clone(),
default_model,
cx,
)
});
- self.register_session(thread, cx)
+ self.register_session(thread, project_id, cx)
}
fn register_session(
&mut self,
thread_handle: Entity<Thread>,
+ project_id: EntityId,
cx: &mut Context<Self>,
) -> Entity<AcpThread> {
let connection = Rc::new(NativeAgentConnection(cx.entity()));
@@ -405,12 +385,13 @@ impl NativeAgent {
Session {
thread: thread_handle,
acp_thread: acp_thread.clone(),
+ project_id,
_subscriptions: subscriptions,
pending_save: Task::ready(()),
},
);
- self.update_available_commands(cx);
+ self.update_available_commands_for_project(project_id, cx);
acp_thread
}
@@ -419,19 +400,102 @@ impl NativeAgent {
&self.models
}
+ fn get_or_create_project_state(
+ &mut self,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> EntityId {
+ let project_id = project.entity_id();
+ if self.projects.contains_key(&project_id) {
+ return project_id;
+ }
+
+ let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
+ self.register_project_with_initial_context(project.clone(), project_context, cx);
+ if let Some(state) = self.projects.get_mut(&project_id) {
+ state.project_context_needs_refresh.send(()).ok();
+ }
+ project_id
+ }
+
+ fn register_project_with_initial_context(
+ &mut self,
+ project: Entity<Project>,
+ project_context: Entity<ProjectContext>,
+ cx: &mut Context<Self>,
+ ) {
+ let project_id = project.entity_id();
+
+ let context_server_store = project.read(cx).context_server_store();
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+
+ let subscriptions = vec![
+ cx.subscribe(&project, Self::handle_project_event),
+ cx.subscribe(
+ &context_server_store,
+ Self::handle_context_server_store_updated,
+ ),
+ cx.subscribe(
+ &context_server_registry,
+ Self::handle_context_server_registry_event,
+ ),
+ ];
+
+ let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
+ watch::channel(());
+
+ self.projects.insert(
+ project_id,
+ ProjectState {
+ project,
+ project_context,
+ project_context_needs_refresh: project_context_needs_refresh_tx,
+ _maintain_project_context: cx.spawn(async move |this, cx| {
+ Self::maintain_project_context(
+ this,
+ project_id,
+ project_context_needs_refresh_rx,
+ cx,
+ )
+ .await
+ }),
+ context_server_registry,
+ _subscriptions: subscriptions,
+ },
+ );
+ }
+
+ fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
+ self.sessions
+ .get(session_id)
+ .and_then(|session| self.projects.get(&session.project_id))
+ }
+
async fn maintain_project_context(
this: WeakEntity<Self>,
+ project_id: EntityId,
mut needs_refresh: watch::Receiver<()>,
cx: &mut AsyncApp,
) -> Result<()> {
while needs_refresh.changed().await.is_ok() {
let project_context = this
.update(cx, |this, cx| {
- Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
- })?
+ let state = this
+ .projects
+ .get(&project_id)
+ .context("project state not found")?;
+ anyhow::Ok(Self::build_project_context(
+ &state.project,
+ this.prompt_store.as_ref(),
+ cx,
+ ))
+ })??
.await;
this.update(cx, |this, cx| {
- this.project_context = cx.new(|_| project_context);
+ if let Some(state) = this.projects.get_mut(&project_id) {
+ state.project_context = cx.new(|_| project_context);
+ }
})?;
}
@@ -620,13 +684,17 @@ impl NativeAgent {
fn handle_project_event(
&mut self,
- _project: Entity<Project>,
+ project: Entity<Project>,
event: &project::Event,
_cx: &mut Context<Self>,
) {
+ let project_id = project.entity_id();
+ let Some(state) = self.projects.get_mut(&project_id) else {
+ return;
+ };
match event {
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
- self.project_context_needs_refresh.send(()).ok();
+ state.project_context_needs_refresh.send(()).ok();
}
project::Event::WorktreeUpdatedEntries(_, items) => {
if items.iter().any(|(path, _, _)| {
@@ -634,7 +702,7 @@ impl NativeAgent {
.iter()
.any(|name| path.as_ref() == RelPath::unix(name).unwrap())
}) {
- self.project_context_needs_refresh.send(()).ok();
+ state.project_context_needs_refresh.send(()).ok();
}
}
_ => {}
@@ -647,7 +715,9 @@ impl NativeAgent {
_event: &prompt_store::PromptsUpdatedEvent,
_cx: &mut Context<Self>,
) {
- self.project_context_needs_refresh.send(()).ok();
+ for state in self.projects.values_mut() {
+ state.project_context_needs_refresh.send(()).ok();
+ }
}
fn handle_models_updated_event(
@@ -677,30 +747,52 @@ impl NativeAgent {
fn handle_context_server_store_updated(
&mut self,
- _store: Entity<project::context_server_store::ContextServerStore>,
+ store: Entity<project::context_server_store::ContextServerStore>,
_event: &project::context_server_store::ServerStatusChangedEvent,
cx: &mut Context<Self>,
) {
- self.update_available_commands(cx);
+ let project_id = self.projects.iter().find_map(|(id, state)| {
+ if *state.context_server_registry.read(cx).server_store() == store {
+ Some(*id)
+ } else {
+ None
+ }
+ });
+ if let Some(project_id) = project_id {
+ self.update_available_commands_for_project(project_id, cx);
+ }
}
fn handle_context_server_registry_event(
&mut self,
- _registry: Entity<ContextServerRegistry>,
+ registry: Entity<ContextServerRegistry>,
event: &ContextServerRegistryEvent,
cx: &mut Context<Self>,
) {
match event {
ContextServerRegistryEvent::ToolsChanged => {}
ContextServerRegistryEvent::PromptsChanged => {
- self.update_available_commands(cx);
+ let project_id = self.projects.iter().find_map(|(id, state)| {
+ if state.context_server_registry == registry {
+ Some(*id)
+ } else {
+ None
+ }
+ });
+ if let Some(project_id) = project_id {
+ self.update_available_commands_for_project(project_id, cx);
+ }
}
}
}
- fn update_available_commands(&self, cx: &mut Context<Self>) {
- let available_commands = self.build_available_commands(cx);
+ fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
+ let available_commands =
+ Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
for session in self.sessions.values() {
+ if session.project_id != project_id {
+ continue;
+ }
session.acp_thread.update(cx, |thread, cx| {
thread
.handle_session_update(
@@ -714,8 +806,14 @@ impl NativeAgent {
}
}
- fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
- let registry = self.context_server_registry.read(cx);
+ fn build_available_commands_for_project(
+ project_state: Option<&ProjectState>,
+ cx: &App,
+ ) -> Vec<acp::AvailableCommand> {
+ let Some(state) = project_state else {
+ return vec![];
+ };
+ let registry = state.context_server_registry.read(cx);
let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
for context_server_prompt in registry.prompts() {
@@ -769,8 +867,10 @@ impl NativeAgent {
pub fn load_thread(
&mut self,
id: acp::SessionId,
+ project: Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Result<Entity<Thread>>> {
+ let project_id = self.get_or_create_project_state(&project, cx);
let database_future = ThreadsDatabase::connect(cx);
cx.spawn(async move |this, cx| {
let database = database_future.await.map_err(|err| anyhow!(err))?;
@@ -780,41 +880,48 @@ impl NativeAgent {
.with_context(|| format!("no thread found with ID: {id:?}"))?;
this.update(cx, |this, cx| {
+ let project_state = this
+ .projects
+ .get(&project_id)
+ .context("project state not found")?;
let summarization_model = LanguageModelRegistry::read_global(cx)
.thread_summary_model()
.map(|c| c.model);
- cx.new(|cx| {
+ Ok(cx.new(|cx| {
let mut thread = Thread::from_db(
id.clone(),
db_thread,
- this.project.clone(),
- this.project_context.clone(),
- this.context_server_registry.clone(),
+ project_state.project.clone(),
+ project_state.project_context.clone(),
+ project_state.context_server_registry.clone(),
this.templates.clone(),
cx,
);
thread.set_summarization_model(summarization_model, cx);
thread
- })
- })
+ }))
+ })?
})
}
pub fn open_thread(
&mut self,
id: acp::SessionId,
+ project: Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Result<Entity<AcpThread>>> {
if let Some(session) = self.sessions.get(&id) {
return Task::ready(Ok(session.acp_thread.clone()));
}
- let task = self.load_thread(id, cx);
+ let project_id = self.get_or_create_project_state(&project, cx);
+ let task = self.load_thread(id, project, cx);
cx.spawn(async move |this, cx| {
let thread = task.await?;
- let acp_thread =
- this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
+ let acp_thread = this.update(cx, |this, cx| {
+ this.register_session(thread.clone(), project_id, cx)
+ })?;
let events = thread.update(cx, |thread, cx| thread.replay(cx));
cx.update(|cx| {
NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
@@ -827,9 +934,10 @@ impl NativeAgent {
pub fn thread_summary(
&mut self,
id: acp::SessionId,
+ project: Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Result<SharedString>> {
- let thread = self.open_thread(id.clone(), cx);
+ let thread = self.open_thread(id.clone(), project, cx);
cx.spawn(async move |this, cx| {
let acp_thread = thread.await?;
let result = this
@@ -857,8 +965,13 @@ impl NativeAgent {
return;
};
+ let project_id = session.project_id;
+ let Some(state) = self.projects.get(&project_id) else {
+ return;
+ };
+
let folder_paths = PathList::new(
- &self
+ &state
.project
.read(cx)
.visible_worktrees(cx)
@@ -889,15 +1002,22 @@ impl NativeAgent {
fn send_mcp_prompt(
&self,
message_id: UserMessageId,
- session_id: agent_client_protocol::SessionId,
+ session_id: acp::SessionId,
prompt_name: String,
server_id: ContextServerId,
arguments: HashMap<String, String>,
original_content: Vec<acp::ContentBlock>,
cx: &mut Context<Self>,
) -> Task<Result<acp::PromptResponse>> {
- let server_store = self.context_server_registry.read(cx).server_store().clone();
- let path_style = self.project.read(cx).path_style(cx);
+ let Some(state) = self.session_project_state(&session_id) else {
+ return Task::ready(Err(anyhow!("Project state not found for session")));
+ };
+ let server_store = state
+ .context_server_registry
+ .read(cx)
+ .server_store()
+ .clone();
+ let path_style = state.project.read(cx).path_style(cx);
cx.spawn(async move |this, cx| {
let prompt =
@@ -996,8 +1116,14 @@ impl NativeAgentConnection {
.map(|session| session.thread.clone())
}
- pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
- self.0.update(cx, |this, cx| this.load_thread(id, cx))
+ pub fn load_thread(
+ &self,
+ id: acp::SessionId,
+ project: Entity<Project>,
+ cx: &mut App,
+ ) -> Task<Result<Entity<Thread>>> {
+ self.0
+ .update(cx, |this, cx| this.load_thread(id, project, cx))
}
fn run_turn(
@@ -1279,13 +1405,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn load_session(
self: Rc<Self>,
session_id: acp::SessionId,
- _project: Entity<Project>,
+ project: Entity<Project>,
_cwd: &Path,
_title: Option<SharedString>,
cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
self.0
- .update(cx, |agent, cx| agent.open_thread(session_id, cx))
+ .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
}
fn supports_close_session(&self) -> bool {
@@ -1294,7 +1420,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
self.0.update(cx, |agent, _cx| {
+ let project_id = agent.sessions.get(session_id).map(|s| s.project_id);
agent.sessions.remove(session_id);
+
+ if let Some(project_id) = project_id {
+ let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
+ if !has_remaining {
+ agent.projects.remove(&project_id);
+ }
+ }
});
Task::ready(Ok(()))
}
@@ -1325,8 +1459,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::info!("Received prompt request for session: {}", session_id);
log::debug!("Prompt blocks count: {}", params.prompt.len());
+ let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
+ return Task::ready(Err(anyhow::anyhow!("Session not found")));
+ };
+
if let Some(parsed_command) = Command::parse(¶ms.prompt) {
- let registry = self.0.read(cx).context_server_registry.read(cx);
+ let registry = project_state.context_server_registry.read(cx);
let explicit_server_id = parsed_command
.explicit_server_id
@@ -1362,10 +1500,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx,
)
});
- };
+ }
};
- let path_style = self.0.read(cx).project.read(cx).path_style(cx);
+ let path_style = project_state.project.read(cx).path_style(cx);
self.run_turn(session_id, cx, move |thread, cx| {
let content: Vec<UserMessageContent> = params
@@ -1406,7 +1544,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn truncate(
&self,
- session_id: &agent_client_protocol::SessionId,
+ session_id: &acp::SessionId,
cx: &App,
) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
self.0.read_with(cx, |agent, _cx| {
@@ -1611,6 +1749,7 @@ impl NativeThreadEnvironment {
};
let parent_thread = parent_thread_entity.read(cx);
let current_depth = parent_thread.depth();
+ let parent_session_id = parent_thread.id().clone();
if current_depth >= MAX_SUBAGENT_DEPTH {
return Err(anyhow!(
@@ -1627,9 +1766,16 @@ impl NativeThreadEnvironment {
let session_id = subagent_thread.read(cx).id().clone();
- let acp_thread = self.agent.update(cx, |agent, cx| {
- agent.register_session(subagent_thread.clone(), cx)
- })?;
+ let acp_thread = self
+ .agent
+ .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
+ let project_id = agent
+ .sessions
+ .get(&parent_session_id)
+ .map(|s| s.project_id)
+ .context("parent session not found")?;
+ Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
+ })??;
let depth = current_depth + 1;
@@ -1955,18 +2101,21 @@ mod internal_tests {
.await;
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store,
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent =
+ cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
+
+ // Creating a session registers the project and triggers context building.
+ let connection = NativeAgentConnection(agent.clone());
+ let _acp_thread = cx
+ .update(|cx| Rc::new(connection).new_session(project.clone(), Path::new("/"), cx))
+ .await
+ .unwrap();
+ cx.run_until_parked();
+
agent.read_with(cx, |agent, cx| {
- assert_eq!(agent.project_context.read(cx).worktrees, vec![])
+ let project_id = project.entity_id();
+ let state = agent.projects.get(&project_id).unwrap();
+ assert_eq!(state.project_context.read(cx).worktrees, vec![])
});
let worktree = project
@@ -1975,8 +2124,10 @@ mod internal_tests {
.unwrap();
cx.run_until_parked();
agent.read_with(cx, |agent, cx| {
+ let project_id = project.entity_id();
+ let state = agent.projects.get(&project_id).unwrap();
assert_eq!(
- agent.project_context.read(cx).worktrees,
+ state.project_context.read(cx).worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
@@ -1989,12 +2140,14 @@ mod internal_tests {
fs.insert_file("/a/.rules", Vec::new()).await;
cx.run_until_parked();
agent.read_with(cx, |agent, cx| {
+ let project_id = project.entity_id();
+ let state = agent.projects.get(&project_id).unwrap();
let rules_entry = worktree
.read(cx)
.entry_for_path(rel_path(".rules"))
.unwrap();
assert_eq!(
- agent.project_context.read(cx).worktrees,
+ state.project_context.read(cx).worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
@@ -2015,18 +2168,10 @@ mod internal_tests {
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let connection = NativeAgentConnection(
- NativeAgent::new(
- project.clone(),
- thread_store,
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap(),
- );
+ let connection =
+ NativeAgentConnection(cx.update(|cx| {
+ NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
+ }));
// Create a thread/session
let acp_thread = cx
@@ -2095,16 +2240,8 @@ mod internal_tests {
let thread_store = cx.new(|cx| ThreadStore::new(cx));
// Create the agent and connection
- let agent = NativeAgent::new(
- project.clone(),
- thread_store,
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent =
+ cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
let connection = NativeAgentConnection(agent.clone());
// Create a thread/session
@@ -2196,16 +2333,8 @@ mod internal_tests {
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store,
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent =
+ cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
let connection = NativeAgentConnection(agent.clone());
let acp_thread = cx
@@ -2288,16 +2417,9 @@ mod internal_tests {
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
// Register a thinking model.
@@ -2371,7 +2493,9 @@ mod internal_tests {
// Reload the thread and verify thinking_enabled is still true.
let reloaded_acp_thread = agent
- .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
+ .update(cx, |agent, cx| {
+ agent.open_thread(session_id.clone(), project.clone(), cx)
+ })
.await
.unwrap();
let reloaded_thread = agent.read_with(cx, |agent, _| {
@@ -2394,16 +2518,9 @@ mod internal_tests {
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
// Register a model where id() != name(), like real Anthropic models
@@ -2478,7 +2595,9 @@ mod internal_tests {
// Reload the thread and verify the model was preserved.
let reloaded_acp_thread = agent
- .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
+ .update(cx, |agent, cx| {
+ agent.open_thread(session_id.clone(), project.clone(), cx)
+ })
.await
.unwrap();
let reloaded_thread = agent.read_with(cx, |agent, _| {
@@ -2513,16 +2632,9 @@ mod internal_tests {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@@ -2642,7 +2754,9 @@ mod internal_tests {
)]
);
let acp_thread = agent
- .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
+ .update(cx, |agent, cx| {
+ agent.open_thread(session_id.clone(), project.clone(), cx)
+ })
.await
.unwrap();
acp_thread.read_with(cx, |thread, cx| {
@@ -3181,16 +3181,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let thread_store = cx.new(|cx| ThreadStore::new(cx));
// Create agent and connection
- let agent = NativeAgent::new(
- project.clone(),
- thread_store,
- templates.clone(),
- None,
- fake_fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx
+ .update(|cx| NativeAgent::new(thread_store, templates.clone(), None, fake_fs.clone(), cx));
let connection = NativeAgentConnection(agent.clone());
// Create a thread using new_thread
@@ -4388,16 +4380,9 @@ async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@@ -4530,16 +4515,9 @@ async fn test_subagent_tool_output_does_not_include_thinking(cx: &mut TestAppCon
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@@ -4685,16 +4663,9 @@ async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAp
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@@ -4822,16 +4793,9 @@ async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@@ -5201,16 +5165,9 @@ async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@@ -5334,16 +5291,9 @@ async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mu
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@@ -5515,16 +5465,9 @@ async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx