1use crate::{
2 AssistantContext, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
3 SavedContextMetadata,
4};
5use anyhow::{anyhow, Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
8use clock::ReplicaId;
9use collections::HashMap;
10use context_server::manager::ContextServerManager;
11use context_server::ContextServerFactoryRegistry;
12use fs::Fs;
13use futures::StreamExt;
14use fuzzy::StringMatchCandidate;
15use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
16use language::LanguageRegistry;
17use paths::contexts_dir;
18use project::Project;
19use prompt_library::PromptBuilder;
20use regex::Regex;
21use rpc::AnyProtoClient;
22use std::sync::LazyLock;
23use std::{
24 cmp::Reverse,
25 ffi::OsStr,
26 mem,
27 path::{Path, PathBuf},
28 sync::Arc,
29 time::Duration,
30};
31use util::{ResultExt, TryFutureExt};
32
33pub(crate) fn init(client: &AnyProtoClient) {
34 client.add_entity_message_handler(ContextStore::handle_advertise_contexts);
35 client.add_entity_request_handler(ContextStore::handle_open_context);
36 client.add_entity_request_handler(ContextStore::handle_create_context);
37 client.add_entity_message_handler(ContextStore::handle_update_context);
38 client.add_entity_request_handler(ContextStore::handle_synchronize_contexts);
39}
40
/// Metadata for a context that the host of a shared project has advertised
/// (populated from `proto::AdvertiseContexts` in `ContextStore::handle_advertise_contexts`).
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Id of the context on the host.
    pub id: ContextId,
    /// The context's summary text as sent by the host, if it had one.
    pub summary: Option<String>,
}
46
/// Tracks the assistant contexts of a project: contexts loaded in memory, metadata for
/// contexts saved under `contexts_dir()`, and — when collaborating — the contexts that
/// the host has advertised.
pub struct ContextStore {
    /// Loaded contexts. New handles are strong while the project is shared and weak
    /// otherwise (see `register_context`).
    contexts: Vec<ContextHandle>,
    /// Metadata for saved contexts found on disk; replaced on every `reload`.
    contexts_metadata: Vec<SavedContextMetadata>,
    context_server_manager: Entity<ContextServerManager>,
    /// Slash commands registered for each started context server, keyed by server id,
    /// so they can be removed from the working set when that server stops.
    context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
    /// Contexts most recently advertised by the host; cleared on disconnect.
    host_contexts: Vec<RemoteContextMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    telemetry: Arc<Telemetry>,
    /// Background task that calls `reload` whenever the contexts-directory watcher yields.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Entity<Project>,
    /// Sharing state observed the last time `handle_project_changed` ran.
    project_is_shared: bool,
    /// Routes entity messages for the shared project to this store; `None` when unshared.
    client_subscription: Option<client::Subscription>,
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
65
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// A context with this id was created; in this file it is emitted when a remote
    /// peer asks the host to create one (`handle_create_context`).
    ContextCreated(ContextId),
}

impl EventEmitter<ContextStoreEvent> for ContextStore {}
71
/// A reference to a loaded context, held either weakly or strongly.
///
/// The store holds strong handles while the project is shared, so contexts stay
/// available to peers, and weak handles otherwise.
enum ContextHandle {
    Weak(WeakEntity<AssistantContext>),
    Strong(Entity<AssistantContext>),
}
76
77impl ContextHandle {
78 fn upgrade(&self) -> Option<Entity<AssistantContext>> {
79 match self {
80 ContextHandle::Weak(weak) => weak.upgrade(),
81 ContextHandle::Strong(strong) => Some(strong.clone()),
82 }
83 }
84
85 fn downgrade(&self) -> WeakEntity<AssistantContext> {
86 match self {
87 ContextHandle::Weak(weak) => weak.clone(),
88 ContextHandle::Strong(strong) => strong.downgrade(),
89 }
90 }
91}
92
93impl ContextStore {
94 pub fn new(
95 project: Entity<Project>,
96 prompt_builder: Arc<PromptBuilder>,
97 slash_commands: Arc<SlashCommandWorkingSet>,
98 cx: &mut App,
99 ) -> Task<Result<Entity<Self>>> {
100 let fs = project.read(cx).fs().clone();
101 let languages = project.read(cx).languages().clone();
102 let telemetry = project.read(cx).client().telemetry().clone();
103 cx.spawn(|mut cx| async move {
104 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
105 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
106
107 let this = cx.new(|cx: &mut Context<Self>| {
108 let context_server_factory_registry =
109 ContextServerFactoryRegistry::default_global(cx);
110 let context_server_manager = cx.new(|cx| {
111 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
112 });
113 let mut this = Self {
114 contexts: Vec::new(),
115 contexts_metadata: Vec::new(),
116 context_server_manager,
117 context_server_slash_command_ids: HashMap::default(),
118 host_contexts: Vec::new(),
119 fs,
120 languages,
121 slash_commands,
122 telemetry,
123 _watch_updates: cx.spawn(|this, mut cx| {
124 async move {
125 while events.next().await.is_some() {
126 this.update(&mut cx, |this, cx| this.reload(cx))?
127 .await
128 .log_err();
129 }
130 anyhow::Ok(())
131 }
132 .log_err()
133 }),
134 client_subscription: None,
135 _project_subscriptions: vec![
136 cx.observe(&project, Self::handle_project_changed),
137 cx.subscribe(&project, Self::handle_project_event),
138 ],
139 project_is_shared: false,
140 client: project.read(cx).client(),
141 project: project.clone(),
142 prompt_builder,
143 };
144 this.handle_project_changed(project.clone(), cx);
145 this.synchronize_contexts(cx);
146 this.register_context_server_handlers(cx);
147 this.reload(cx).detach_and_log_err(cx);
148 this
149 })?;
150
151 Ok(this)
152 })
153 }
154
155 async fn handle_advertise_contexts(
156 this: Entity<Self>,
157 envelope: TypedEnvelope<proto::AdvertiseContexts>,
158 mut cx: AsyncApp,
159 ) -> Result<()> {
160 this.update(&mut cx, |this, cx| {
161 this.host_contexts = envelope
162 .payload
163 .contexts
164 .into_iter()
165 .map(|context| RemoteContextMetadata {
166 id: ContextId::from_proto(context.context_id),
167 summary: context.summary,
168 })
169 .collect();
170 cx.notify();
171 })
172 }
173
174 async fn handle_open_context(
175 this: Entity<Self>,
176 envelope: TypedEnvelope<proto::OpenContext>,
177 mut cx: AsyncApp,
178 ) -> Result<proto::OpenContextResponse> {
179 let context_id = ContextId::from_proto(envelope.payload.context_id);
180 let operations = this.update(&mut cx, |this, cx| {
181 if this.project.read(cx).is_via_collab() {
182 return Err(anyhow!("only the host contexts can be opened"));
183 }
184
185 let context = this
186 .loaded_context_for_id(&context_id, cx)
187 .context("context not found")?;
188 if context.read(cx).replica_id() != ReplicaId::default() {
189 return Err(anyhow!("context must be opened via the host"));
190 }
191
192 anyhow::Ok(
193 context
194 .read(cx)
195 .serialize_ops(&ContextVersion::default(), cx),
196 )
197 })??;
198 let operations = operations.await;
199 Ok(proto::OpenContextResponse {
200 context: Some(proto::Context { operations }),
201 })
202 }
203
204 async fn handle_create_context(
205 this: Entity<Self>,
206 _: TypedEnvelope<proto::CreateContext>,
207 mut cx: AsyncApp,
208 ) -> Result<proto::CreateContextResponse> {
209 let (context_id, operations) = this.update(&mut cx, |this, cx| {
210 if this.project.read(cx).is_via_collab() {
211 return Err(anyhow!("can only create contexts as the host"));
212 }
213
214 let context = this.create(cx);
215 let context_id = context.read(cx).id().clone();
216 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
217
218 anyhow::Ok((
219 context_id,
220 context
221 .read(cx)
222 .serialize_ops(&ContextVersion::default(), cx),
223 ))
224 })??;
225 let operations = operations.await;
226 Ok(proto::CreateContextResponse {
227 context_id: context_id.to_proto(),
228 context: Some(proto::Context { operations }),
229 })
230 }
231
232 async fn handle_update_context(
233 this: Entity<Self>,
234 envelope: TypedEnvelope<proto::UpdateContext>,
235 mut cx: AsyncApp,
236 ) -> Result<()> {
237 this.update(&mut cx, |this, cx| {
238 let context_id = ContextId::from_proto(envelope.payload.context_id);
239 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
240 let operation_proto = envelope.payload.operation.context("invalid operation")?;
241 let operation = ContextOperation::from_proto(operation_proto)?;
242 context.update(cx, |context, cx| context.apply_ops([operation], cx));
243 }
244 Ok(())
245 })?
246 }
247
248 async fn handle_synchronize_contexts(
249 this: Entity<Self>,
250 envelope: TypedEnvelope<proto::SynchronizeContexts>,
251 mut cx: AsyncApp,
252 ) -> Result<proto::SynchronizeContextsResponse> {
253 this.update(&mut cx, |this, cx| {
254 if this.project.read(cx).is_via_collab() {
255 return Err(anyhow!("only the host can synchronize contexts"));
256 }
257
258 let mut local_versions = Vec::new();
259 for remote_version_proto in envelope.payload.contexts {
260 let remote_version = ContextVersion::from_proto(&remote_version_proto);
261 let context_id = ContextId::from_proto(remote_version_proto.context_id);
262 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
263 let context = context.read(cx);
264 let operations = context.serialize_ops(&remote_version, cx);
265 local_versions.push(context.version(cx).to_proto(context_id.clone()));
266 let client = this.client.clone();
267 let project_id = envelope.payload.project_id;
268 cx.background_spawn(async move {
269 let operations = operations.await;
270 for operation in operations {
271 client.send(proto::UpdateContext {
272 project_id,
273 context_id: context_id.to_proto(),
274 operation: Some(operation),
275 })?;
276 }
277 anyhow::Ok(())
278 })
279 .detach_and_log_err(cx);
280 }
281 }
282
283 this.advertise_contexts(cx);
284
285 anyhow::Ok(proto::SynchronizeContextsResponse {
286 contexts: local_versions,
287 })
288 })?
289 }
290
291 fn handle_project_changed(&mut self, _: Entity<Project>, cx: &mut Context<Self>) {
292 let is_shared = self.project.read(cx).is_shared();
293 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
294 if is_shared == was_shared {
295 return;
296 }
297
298 if is_shared {
299 self.contexts.retain_mut(|context| {
300 if let Some(strong_context) = context.upgrade() {
301 *context = ContextHandle::Strong(strong_context);
302 true
303 } else {
304 false
305 }
306 });
307 let remote_id = self.project.read(cx).remote_id().unwrap();
308 self.client_subscription = self
309 .client
310 .subscribe_to_entity(remote_id)
311 .log_err()
312 .map(|subscription| subscription.set_entity(&cx.entity(), &mut cx.to_async()));
313 self.advertise_contexts(cx);
314 } else {
315 self.client_subscription = None;
316 }
317 }
318
319 fn handle_project_event(
320 &mut self,
321 _: Entity<Project>,
322 event: &project::Event,
323 cx: &mut Context<Self>,
324 ) {
325 match event {
326 project::Event::Reshared => {
327 self.advertise_contexts(cx);
328 }
329 project::Event::HostReshared | project::Event::Rejoined => {
330 self.synchronize_contexts(cx);
331 }
332 project::Event::DisconnectedFromHost => {
333 self.contexts.retain_mut(|context| {
334 if let Some(strong_context) = context.upgrade() {
335 *context = ContextHandle::Weak(context.downgrade());
336 strong_context.update(cx, |context, cx| {
337 if context.replica_id() != ReplicaId::default() {
338 context.set_capability(language::Capability::ReadOnly, cx);
339 }
340 });
341 true
342 } else {
343 false
344 }
345 });
346 self.host_contexts.clear();
347 cx.notify();
348 }
349 _ => {}
350 }
351 }
352
353 pub fn contexts(&self) -> Vec<SavedContextMetadata> {
354 let mut contexts = self.contexts_metadata.iter().cloned().collect::<Vec<_>>();
355 contexts.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.mtime));
356 contexts
357 }
358
359 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<AssistantContext> {
360 let context = cx.new(|cx| {
361 AssistantContext::local(
362 self.languages.clone(),
363 Some(self.project.clone()),
364 Some(self.telemetry.clone()),
365 self.prompt_builder.clone(),
366 self.slash_commands.clone(),
367 cx,
368 )
369 });
370 self.register_context(&context, cx);
371 context
372 }
373
374 pub fn create_remote_context(
375 &mut self,
376 cx: &mut Context<Self>,
377 ) -> Task<Result<Entity<AssistantContext>>> {
378 let project = self.project.read(cx);
379 let Some(project_id) = project.remote_id() else {
380 return Task::ready(Err(anyhow!("project was not remote")));
381 };
382
383 let replica_id = project.replica_id();
384 let capability = project.capability();
385 let language_registry = self.languages.clone();
386 let project = self.project.clone();
387 let telemetry = self.telemetry.clone();
388 let prompt_builder = self.prompt_builder.clone();
389 let slash_commands = self.slash_commands.clone();
390 let request = self.client.request(proto::CreateContext { project_id });
391 cx.spawn(|this, mut cx| async move {
392 let response = request.await?;
393 let context_id = ContextId::from_proto(response.context_id);
394 let context_proto = response.context.context("invalid context")?;
395 let context = cx.new(|cx| {
396 AssistantContext::new(
397 context_id.clone(),
398 replica_id,
399 capability,
400 language_registry,
401 prompt_builder,
402 slash_commands,
403 Some(project),
404 Some(telemetry),
405 cx,
406 )
407 })?;
408 let operations = cx
409 .background_spawn(async move {
410 context_proto
411 .operations
412 .into_iter()
413 .map(ContextOperation::from_proto)
414 .collect::<Result<Vec<_>>>()
415 })
416 .await?;
417 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
418 this.update(&mut cx, |this, cx| {
419 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
420 existing_context
421 } else {
422 this.register_context(&context, cx);
423 this.synchronize_contexts(cx);
424 context
425 }
426 })
427 })
428 }
429
430 pub fn open_local_context(
431 &mut self,
432 path: PathBuf,
433 cx: &Context<Self>,
434 ) -> Task<Result<Entity<AssistantContext>>> {
435 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
436 return Task::ready(Ok(existing_context));
437 }
438
439 let fs = self.fs.clone();
440 let languages = self.languages.clone();
441 let project = self.project.clone();
442 let telemetry = self.telemetry.clone();
443 let load = cx.background_spawn({
444 let path = path.clone();
445 async move {
446 let saved_context = fs.load(&path).await?;
447 SavedContext::from_json(&saved_context)
448 }
449 });
450 let prompt_builder = self.prompt_builder.clone();
451 let slash_commands = self.slash_commands.clone();
452
453 cx.spawn(|this, mut cx| async move {
454 let saved_context = load.await?;
455 let context = cx.new(|cx| {
456 AssistantContext::deserialize(
457 saved_context,
458 path.clone(),
459 languages,
460 prompt_builder,
461 slash_commands,
462 Some(project),
463 Some(telemetry),
464 cx,
465 )
466 })?;
467 this.update(&mut cx, |this, cx| {
468 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
469 existing_context
470 } else {
471 this.register_context(&context, cx);
472 context
473 }
474 })
475 })
476 }
477
478 fn loaded_context_for_path(&self, path: &Path, cx: &App) -> Option<Entity<AssistantContext>> {
479 self.contexts.iter().find_map(|context| {
480 let context = context.upgrade()?;
481 if context.read(cx).path() == Some(path) {
482 Some(context)
483 } else {
484 None
485 }
486 })
487 }
488
489 pub fn loaded_context_for_id(
490 &self,
491 id: &ContextId,
492 cx: &App,
493 ) -> Option<Entity<AssistantContext>> {
494 self.contexts.iter().find_map(|context| {
495 let context = context.upgrade()?;
496 if context.read(cx).id() == id {
497 Some(context)
498 } else {
499 None
500 }
501 })
502 }
503
504 pub fn open_remote_context(
505 &mut self,
506 context_id: ContextId,
507 cx: &mut Context<Self>,
508 ) -> Task<Result<Entity<AssistantContext>>> {
509 let project = self.project.read(cx);
510 let Some(project_id) = project.remote_id() else {
511 return Task::ready(Err(anyhow!("project was not remote")));
512 };
513
514 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
515 return Task::ready(Ok(context));
516 }
517
518 let replica_id = project.replica_id();
519 let capability = project.capability();
520 let language_registry = self.languages.clone();
521 let project = self.project.clone();
522 let telemetry = self.telemetry.clone();
523 let request = self.client.request(proto::OpenContext {
524 project_id,
525 context_id: context_id.to_proto(),
526 });
527 let prompt_builder = self.prompt_builder.clone();
528 let slash_commands = self.slash_commands.clone();
529 cx.spawn(|this, mut cx| async move {
530 let response = request.await?;
531 let context_proto = response.context.context("invalid context")?;
532 let context = cx.new(|cx| {
533 AssistantContext::new(
534 context_id.clone(),
535 replica_id,
536 capability,
537 language_registry,
538 prompt_builder,
539 slash_commands,
540 Some(project),
541 Some(telemetry),
542 cx,
543 )
544 })?;
545 let operations = cx
546 .background_spawn(async move {
547 context_proto
548 .operations
549 .into_iter()
550 .map(ContextOperation::from_proto)
551 .collect::<Result<Vec<_>>>()
552 })
553 .await?;
554 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
555 this.update(&mut cx, |this, cx| {
556 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
557 existing_context
558 } else {
559 this.register_context(&context, cx);
560 this.synchronize_contexts(cx);
561 context
562 }
563 })
564 })
565 }
566
567 fn register_context(&mut self, context: &Entity<AssistantContext>, cx: &mut Context<Self>) {
568 let handle = if self.project_is_shared {
569 ContextHandle::Strong(context.clone())
570 } else {
571 ContextHandle::Weak(context.downgrade())
572 };
573 self.contexts.push(handle);
574 self.advertise_contexts(cx);
575 cx.subscribe(context, Self::handle_context_event).detach();
576 }
577
578 fn handle_context_event(
579 &mut self,
580 context: Entity<AssistantContext>,
581 event: &ContextEvent,
582 cx: &mut Context<Self>,
583 ) {
584 let Some(project_id) = self.project.read(cx).remote_id() else {
585 return;
586 };
587
588 match event {
589 ContextEvent::SummaryChanged => {
590 self.advertise_contexts(cx);
591 }
592 ContextEvent::Operation(operation) => {
593 let context_id = context.read(cx).id().to_proto();
594 let operation = operation.to_proto();
595 self.client
596 .send(proto::UpdateContext {
597 project_id,
598 context_id,
599 operation: Some(operation),
600 })
601 .log_err();
602 }
603 _ => {}
604 }
605 }
606
607 fn advertise_contexts(&self, cx: &App) {
608 let Some(project_id) = self.project.read(cx).remote_id() else {
609 return;
610 };
611
612 // For now, only the host can advertise their open contexts.
613 if self.project.read(cx).is_via_collab() {
614 return;
615 }
616
617 let contexts = self
618 .contexts
619 .iter()
620 .rev()
621 .filter_map(|context| {
622 let context = context.upgrade()?.read(cx);
623 if context.replica_id() == ReplicaId::default() {
624 Some(proto::ContextMetadata {
625 context_id: context.id().to_proto(),
626 summary: context.summary().map(|summary| summary.text.clone()),
627 })
628 } else {
629 None
630 }
631 })
632 .collect();
633 self.client
634 .send(proto::AdvertiseContexts {
635 project_id,
636 contexts,
637 })
638 .ok();
639 }
640
641 fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
642 let Some(project_id) = self.project.read(cx).remote_id() else {
643 return;
644 };
645
646 let contexts = self
647 .contexts
648 .iter()
649 .filter_map(|context| {
650 let context = context.upgrade()?.read(cx);
651 if context.replica_id() != ReplicaId::default() {
652 Some(context.version(cx).to_proto(context.id().clone()))
653 } else {
654 None
655 }
656 })
657 .collect();
658
659 let client = self.client.clone();
660 let request = self.client.request(proto::SynchronizeContexts {
661 project_id,
662 contexts,
663 });
664 cx.spawn(|this, cx| async move {
665 let response = request.await?;
666
667 let mut context_ids = Vec::new();
668 let mut operations = Vec::new();
669 this.read_with(&cx, |this, cx| {
670 for context_version_proto in response.contexts {
671 let context_version = ContextVersion::from_proto(&context_version_proto);
672 let context_id = ContextId::from_proto(context_version_proto.context_id);
673 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
674 context_ids.push(context_id);
675 operations.push(context.read(cx).serialize_ops(&context_version, cx));
676 }
677 }
678 })?;
679
680 let operations = futures::future::join_all(operations).await;
681 for (context_id, operations) in context_ids.into_iter().zip(operations) {
682 for operation in operations {
683 client.send(proto::UpdateContext {
684 project_id,
685 context_id: context_id.to_proto(),
686 operation: Some(operation),
687 })?;
688 }
689 }
690
691 anyhow::Ok(())
692 })
693 .detach_and_log_err(cx);
694 }
695
696 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedContextMetadata>> {
697 let metadata = self.contexts_metadata.clone();
698 let executor = cx.background_executor().clone();
699 cx.background_spawn(async move {
700 if query.is_empty() {
701 metadata
702 } else {
703 let candidates = metadata
704 .iter()
705 .enumerate()
706 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
707 .collect::<Vec<_>>();
708 let matches = fuzzy::match_strings(
709 &candidates,
710 &query,
711 false,
712 100,
713 &Default::default(),
714 executor,
715 )
716 .await;
717
718 matches
719 .into_iter()
720 .map(|mat| metadata[mat.candidate_id].clone())
721 .collect()
722 }
723 })
724 }
725
726 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
727 &self.host_contexts
728 }
729
730 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
731 let fs = self.fs.clone();
732 cx.spawn(|this, mut cx| async move {
733 fs.create_dir(contexts_dir()).await?;
734
735 let mut paths = fs.read_dir(contexts_dir()).await?;
736 let mut contexts = Vec::<SavedContextMetadata>::new();
737 while let Some(path) = paths.next().await {
738 let path = path?;
739 if path.extension() != Some(OsStr::new("json")) {
740 continue;
741 }
742
743 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
744 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
745
746 let metadata = fs.metadata(&path).await?;
747 if let Some((file_name, metadata)) = path
748 .file_name()
749 .and_then(|name| name.to_str())
750 .zip(metadata)
751 {
752 // This is used to filter out contexts saved by the new assistant.
753 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
754 continue;
755 }
756
757 if let Some(title) = ASSISTANT_CONTEXT_REGEX
758 .replace(file_name, "")
759 .lines()
760 .next()
761 {
762 contexts.push(SavedContextMetadata {
763 title: title.to_string(),
764 path,
765 mtime: metadata.mtime.timestamp_for_user().into(),
766 });
767 }
768 }
769 }
770 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
771
772 this.update(&mut cx, |this, cx| {
773 this.contexts_metadata = contexts;
774 cx.notify();
775 })
776 })
777 }
778
779 pub fn restart_context_servers(&mut self, cx: &mut Context<Self>) {
780 cx.update_entity(
781 &self.context_server_manager,
782 |context_server_manager, cx| {
783 for server in context_server_manager.servers() {
784 context_server_manager
785 .restart_server(&server.id(), cx)
786 .detach_and_log_err(cx);
787 }
788 },
789 );
790 }
791
792 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
793 cx.subscribe(
794 &self.context_server_manager.clone(),
795 Self::handle_context_server_event,
796 )
797 .detach();
798 }
799
800 fn handle_context_server_event(
801 &mut self,
802 context_server_manager: Entity<ContextServerManager>,
803 event: &context_server::manager::Event,
804 cx: &mut Context<Self>,
805 ) {
806 let slash_command_working_set = self.slash_commands.clone();
807 match event {
808 context_server::manager::Event::ServerStarted { server_id } => {
809 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
810 let context_server_manager = context_server_manager.clone();
811 cx.spawn({
812 let server = server.clone();
813 let server_id = server_id.clone();
814 |this, mut cx| async move {
815 let Some(protocol) = server.client() else {
816 return;
817 };
818
819 if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
820 if let Some(prompts) = protocol.list_prompts().await.log_err() {
821 let slash_command_ids = prompts
822 .into_iter()
823 .filter(assistant_slash_commands::acceptable_prompt)
824 .map(|prompt| {
825 log::info!(
826 "registering context server command: {:?}",
827 prompt.name
828 );
829 slash_command_working_set.insert(Arc::new(
830 assistant_slash_commands::ContextServerSlashCommand::new(
831 context_server_manager.clone(),
832 &server,
833 prompt,
834 ),
835 ))
836 })
837 .collect::<Vec<_>>();
838
839 this.update(&mut cx, |this, _cx| {
840 this.context_server_slash_command_ids
841 .insert(server_id.clone(), slash_command_ids);
842 })
843 .log_err();
844 }
845 }
846 }
847 })
848 .detach();
849 }
850 }
851 context_server::manager::Event::ServerStopped { server_id } => {
852 if let Some(slash_command_ids) =
853 self.context_server_slash_command_ids.remove(server_id)
854 {
855 slash_command_working_set.remove(&slash_command_ids);
856 }
857 }
858 }
859 }
860}