@@ -514,11 +514,14 @@ impl ThreadStore {
}
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
- cx.subscribe(
- &self.project.read(cx).context_server_store(),
- Self::handle_context_server_event,
- )
- .detach();
+ let context_server_store = self.project.read(cx).context_server_store();
+ cx.subscribe(&context_server_store, Self::handle_context_server_event)
+ .detach();
+
+ // Check for any servers that were already running before the handler was registered
+ for server in context_server_store.read(cx).running_servers() {
+ self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
+ }
}
fn handle_context_server_event(
@@ -533,55 +536,7 @@ impl ThreadStore {
match status {
ContextServerStatus::Starting => {}
ContextServerStatus::Running => {
- if let Some(server) =
- context_server_store.read(cx).get_running_server(server_id)
- {
- let context_server_manager = context_server_store.clone();
- cx.spawn({
- let server = server.clone();
- let server_id = server_id.clone();
- async move |this, cx| {
- let Some(protocol) = server.client() else {
- return;
- };
-
- if protocol.capable(context_server::protocol::ServerCapability::Tools) {
- if let Some(tools) = protocol.list_tools().await.log_err() {
- let tool_ids = tool_working_set
- .update(cx, |tool_working_set, _| {
- tools
- .tools
- .into_iter()
- .map(|tool| {
- log::info!(
- "registering context server tool: {:?}",
- tool.name
- );
- tool_working_set.insert(Arc::new(
- ContextServerTool::new(
- context_server_manager.clone(),
- server.id(),
- tool,
- ),
- ))
- })
- .collect::<Vec<_>>()
- })
- .log_err();
-
- if let Some(tool_ids) = tool_ids {
- this.update(cx, |this, _| {
- this.context_server_tool_ids
- .insert(server_id, tool_ids);
- })
- .log_err();
- }
- }
- }
- }
- })
- .detach();
- }
+ self.load_context_server_tools(server_id.clone(), context_server_store, cx);
}
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
@@ -594,6 +549,52 @@ impl ThreadStore {
}
}
}
+
+ fn load_context_server_tools(
+ &self,
+ server_id: ContextServerId,
+ context_server_store: Entity<ContextServerStore>,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
+ return;
+ };
+ let tool_working_set = self.tools.clone();
+ cx.spawn(async move |this, cx| {
+ let Some(protocol) = server.client() else {
+ return;
+ };
+
+ if protocol.capable(context_server::protocol::ServerCapability::Tools) {
+ if let Some(tools) = protocol.list_tools().await.log_err() {
+ let tool_ids = tool_working_set
+ .update(cx, |tool_working_set, _| {
+ tools
+ .tools
+ .into_iter()
+ .map(|tool| {
+ log::info!("registering context server tool: {:?}", tool.name);
+ tool_working_set.insert(Arc::new(ContextServerTool::new(
+ context_server_store.clone(),
+ server.id(),
+ tool,
+ )))
+ })
+ .collect::<Vec<_>>()
+ })
+ .log_err();
+
+ if let Some(tool_ids) = tool_ids {
+ this.update(cx, |this, _| {
+ this.context_server_tool_ids.insert(server_id, tool_ids);
+ })
+ .log_err();
+ }
+ }
+ }
+ })
+ .detach();
+ }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -809,74 +809,37 @@ impl ContextStore {
}
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
- cx.subscribe(
- &self.project.read(cx).context_server_store(),
- Self::handle_context_server_event,
- )
- .detach();
+ let context_server_store = self.project.read(cx).context_server_store();
+ cx.subscribe(&context_server_store, Self::handle_context_server_event)
+ .detach();
+
+ // Check for any servers that were already running before the handler was registered
+ for server in context_server_store.read(cx).running_servers() {
+ self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
+ }
}
fn handle_context_server_event(
&mut self,
- context_server_manager: Entity<ContextServerStore>,
+ context_server_store: Entity<ContextServerStore>,
event: &project::context_server_store::Event,
cx: &mut Context<Self>,
) {
- let slash_command_working_set = self.slash_commands.clone();
match event {
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status {
ContextServerStatus::Running => {
- if let Some(server) = context_server_manager
- .read(cx)
- .get_running_server(server_id)
- {
- let context_server_manager = context_server_manager.clone();
- cx.spawn({
- let server = server.clone();
- let server_id = server_id.clone();
- async move |this, cx| {
- let Some(protocol) = server.client() else {
- return;
- };
-
- if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
- if let Some(prompts) = protocol.list_prompts().await.log_err() {
- let slash_command_ids = prompts
- .into_iter()
- .filter(assistant_slash_commands::acceptable_prompt)
- .map(|prompt| {
- log::info!(
- "registering context server command: {:?}",
- prompt.name
- );
- slash_command_working_set.insert(Arc::new(
- assistant_slash_commands::ContextServerSlashCommand::new(
- context_server_manager.clone(),
- server.id(),
- prompt,
- ),
- ))
- })
- .collect::<Vec<_>>();
-
- this.update( cx, |this, _cx| {
- this.context_server_slash_command_ids
- .insert(server_id.clone(), slash_command_ids);
- })
- .log_err();
- }
- }
- }
- })
- .detach();
- }
+ self.load_context_server_slash_commands(
+ server_id.clone(),
+ context_server_store.clone(),
+ cx,
+ );
}
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
- slash_command_working_set.remove(&slash_command_ids);
+ self.slash_commands.remove(&slash_command_ids);
}
}
_ => {}
@@ -884,4 +847,47 @@ impl ContextStore {
}
}
}
+
+ fn load_context_server_slash_commands(
+ &self,
+ server_id: ContextServerId,
+ context_server_store: Entity<ContextServerStore>,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
+ return;
+ };
+ let slash_command_working_set = self.slash_commands.clone();
+ cx.spawn(async move |this, cx| {
+ let Some(protocol) = server.client() else {
+ return;
+ };
+
+ if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
+ if let Some(prompts) = protocol.list_prompts().await.log_err() {
+ let slash_command_ids = prompts
+ .into_iter()
+ .filter(assistant_slash_commands::acceptable_prompt)
+ .map(|prompt| {
+ log::info!("registering context server command: {:?}", prompt.name);
+ slash_command_working_set.insert(Arc::new(
+ assistant_slash_commands::ContextServerSlashCommand::new(
+ context_server_store.clone(),
+ server.id(),
+ prompt,
+ ),
+ ))
+ })
+ .collect::<Vec<_>>();
+
+ this.update(cx, |this, _cx| {
+ this.context_server_slash_command_ids
+ .insert(server_id.clone(), slash_command_ids);
+ })
+ .log_err();
+ }
+ }
+ })
+ .detach();
+ }
}