1use crate::{
2 AssistantContext, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
3 SavedContextMetadata,
4};
5use anyhow::{anyhow, Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
8use clock::ReplicaId;
9use collections::HashMap;
10use context_server::manager::ContextServerManager;
11use context_server::ContextServerFactoryRegistry;
12use fs::Fs;
13use futures::StreamExt;
14use fuzzy::StringMatchCandidate;
15use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
16use language::LanguageRegistry;
17use paths::contexts_dir;
18use project::Project;
19use prompt_library::PromptBuilder;
20use regex::Regex;
21use rpc::AnyProtoClient;
22use std::sync::LazyLock;
23use std::{
24 cmp::Reverse,
25 ffi::OsStr,
26 mem,
27 path::{Path, PathBuf},
28 sync::Arc,
29 time::Duration,
30};
31use util::{ResultExt, TryFutureExt};
32
33pub(crate) fn init(client: &AnyProtoClient) {
34 client.add_model_message_handler(ContextStore::handle_advertise_contexts);
35 client.add_model_request_handler(ContextStore::handle_open_context);
36 client.add_model_request_handler(ContextStore::handle_create_context);
37 client.add_model_message_handler(ContextStore::handle_update_context);
38 client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
39}
40
/// Summary information about a context that the host advertised to this
/// client (built from `proto::ContextMetadata` in `handle_advertise_contexts`).
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Identifier of the context on the host.
    pub id: ContextId,
    /// The context's summary text, if the host had one when it advertised.
    pub summary: Option<String>,
}
46
/// Tracks the assistant contexts of a project: contexts loaded in memory,
/// metadata of contexts saved on disk, and the collaboration state needed to
/// share contexts with (or receive them from) other project participants.
pub struct ContextStore {
    // Loaded contexts. Held strongly while the project is shared and weakly
    // otherwise (see `register_context` / `handle_project_changed`).
    contexts: Vec<ContextHandle>,
    // Saved contexts found in the contexts directory, newest first
    // (refreshed by `reload`).
    contexts_metadata: Vec<SavedContextMetadata>,
    context_server_manager: Entity<ContextServerManager>,
    // Slash commands registered for each context server's prompts, keyed by
    // server id, so they can be removed when that server stops.
    context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
    // Contexts advertised by the host; cleared on disconnect.
    host_contexts: Vec<RemoteContextMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    telemetry: Arc<Telemetry>,
    // Background task that reloads `contexts_metadata` whenever the contexts
    // directory changes.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Entity<Project>,
    // Last observed value of `Project::is_shared`.
    project_is_shared: bool,
    // Routes this project's RPC messages to the store while the project is shared.
    client_subscription: Option<client::Subscription>,
    // Keep-alive handles for the project observe/subscribe callbacks.
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
65
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// A context with the given id was created. In this file it is emitted when
    /// a collaborator asks the host to create a context (`handle_create_context`).
    ContextCreated(ContextId),
}

impl EventEmitter<ContextStoreEvent> for ContextStore {}
71
72enum ContextHandle {
73 Weak(WeakEntity<AssistantContext>),
74 Strong(Entity<AssistantContext>),
75}
76
77impl ContextHandle {
78 fn upgrade(&self) -> Option<Entity<AssistantContext>> {
79 match self {
80 ContextHandle::Weak(weak) => weak.upgrade(),
81 ContextHandle::Strong(strong) => Some(strong.clone()),
82 }
83 }
84
85 fn downgrade(&self) -> WeakEntity<AssistantContext> {
86 match self {
87 ContextHandle::Weak(weak) => weak.clone(),
88 ContextHandle::Strong(strong) => strong.downgrade(),
89 }
90 }
91}
92
93impl ContextStore {
94 pub fn new(
95 project: Entity<Project>,
96 prompt_builder: Arc<PromptBuilder>,
97 slash_commands: Arc<SlashCommandWorkingSet>,
98 cx: &mut App,
99 ) -> Task<Result<Entity<Self>>> {
100 let fs = project.read(cx).fs().clone();
101 let languages = project.read(cx).languages().clone();
102 let telemetry = project.read(cx).client().telemetry().clone();
103 cx.spawn(|mut cx| async move {
104 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
105 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
106
107 let this = cx.new(|cx: &mut Context<Self>| {
108 let context_server_factory_registry =
109 ContextServerFactoryRegistry::default_global(cx);
110 let context_server_manager = cx.new(|cx| {
111 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
112 });
113 let mut this = Self {
114 contexts: Vec::new(),
115 contexts_metadata: Vec::new(),
116 context_server_manager,
117 context_server_slash_command_ids: HashMap::default(),
118 host_contexts: Vec::new(),
119 fs,
120 languages,
121 slash_commands,
122 telemetry,
123 _watch_updates: cx.spawn(|this, mut cx| {
124 async move {
125 while events.next().await.is_some() {
126 this.update(&mut cx, |this, cx| this.reload(cx))?
127 .await
128 .log_err();
129 }
130 anyhow::Ok(())
131 }
132 .log_err()
133 }),
134 client_subscription: None,
135 _project_subscriptions: vec![
136 cx.observe(&project, Self::handle_project_changed),
137 cx.subscribe(&project, Self::handle_project_event),
138 ],
139 project_is_shared: false,
140 client: project.read(cx).client(),
141 project: project.clone(),
142 prompt_builder,
143 };
144 this.handle_project_changed(project.clone(), cx);
145 this.synchronize_contexts(cx);
146 this.register_context_server_handlers(cx);
147 this.reload(cx).detach_and_log_err(cx);
148 this
149 })?;
150
151 Ok(this)
152 })
153 }
154
155 async fn handle_advertise_contexts(
156 this: Entity<Self>,
157 envelope: TypedEnvelope<proto::AdvertiseContexts>,
158 mut cx: AsyncApp,
159 ) -> Result<()> {
160 this.update(&mut cx, |this, cx| {
161 this.host_contexts = envelope
162 .payload
163 .contexts
164 .into_iter()
165 .map(|context| RemoteContextMetadata {
166 id: ContextId::from_proto(context.context_id),
167 summary: context.summary,
168 })
169 .collect();
170 cx.notify();
171 })
172 }
173
174 async fn handle_open_context(
175 this: Entity<Self>,
176 envelope: TypedEnvelope<proto::OpenContext>,
177 mut cx: AsyncApp,
178 ) -> Result<proto::OpenContextResponse> {
179 let context_id = ContextId::from_proto(envelope.payload.context_id);
180 let operations = this.update(&mut cx, |this, cx| {
181 if this.project.read(cx).is_via_collab() {
182 return Err(anyhow!("only the host contexts can be opened"));
183 }
184
185 let context = this
186 .loaded_context_for_id(&context_id, cx)
187 .context("context not found")?;
188 if context.read(cx).replica_id() != ReplicaId::default() {
189 return Err(anyhow!("context must be opened via the host"));
190 }
191
192 anyhow::Ok(
193 context
194 .read(cx)
195 .serialize_ops(&ContextVersion::default(), cx),
196 )
197 })??;
198 let operations = operations.await;
199 Ok(proto::OpenContextResponse {
200 context: Some(proto::Context { operations }),
201 })
202 }
203
204 async fn handle_create_context(
205 this: Entity<Self>,
206 _: TypedEnvelope<proto::CreateContext>,
207 mut cx: AsyncApp,
208 ) -> Result<proto::CreateContextResponse> {
209 let (context_id, operations) = this.update(&mut cx, |this, cx| {
210 if this.project.read(cx).is_via_collab() {
211 return Err(anyhow!("can only create contexts as the host"));
212 }
213
214 let context = this.create(cx);
215 let context_id = context.read(cx).id().clone();
216 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
217
218 anyhow::Ok((
219 context_id,
220 context
221 .read(cx)
222 .serialize_ops(&ContextVersion::default(), cx),
223 ))
224 })??;
225 let operations = operations.await;
226 Ok(proto::CreateContextResponse {
227 context_id: context_id.to_proto(),
228 context: Some(proto::Context { operations }),
229 })
230 }
231
232 async fn handle_update_context(
233 this: Entity<Self>,
234 envelope: TypedEnvelope<proto::UpdateContext>,
235 mut cx: AsyncApp,
236 ) -> Result<()> {
237 this.update(&mut cx, |this, cx| {
238 let context_id = ContextId::from_proto(envelope.payload.context_id);
239 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
240 let operation_proto = envelope.payload.operation.context("invalid operation")?;
241 let operation = ContextOperation::from_proto(operation_proto)?;
242 context.update(cx, |context, cx| context.apply_ops([operation], cx));
243 }
244 Ok(())
245 })?
246 }
247
248 async fn handle_synchronize_contexts(
249 this: Entity<Self>,
250 envelope: TypedEnvelope<proto::SynchronizeContexts>,
251 mut cx: AsyncApp,
252 ) -> Result<proto::SynchronizeContextsResponse> {
253 this.update(&mut cx, |this, cx| {
254 if this.project.read(cx).is_via_collab() {
255 return Err(anyhow!("only the host can synchronize contexts"));
256 }
257
258 let mut local_versions = Vec::new();
259 for remote_version_proto in envelope.payload.contexts {
260 let remote_version = ContextVersion::from_proto(&remote_version_proto);
261 let context_id = ContextId::from_proto(remote_version_proto.context_id);
262 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
263 let context = context.read(cx);
264 let operations = context.serialize_ops(&remote_version, cx);
265 local_versions.push(context.version(cx).to_proto(context_id.clone()));
266 let client = this.client.clone();
267 let project_id = envelope.payload.project_id;
268 cx.background_executor()
269 .spawn(async move {
270 let operations = operations.await;
271 for operation in operations {
272 client.send(proto::UpdateContext {
273 project_id,
274 context_id: context_id.to_proto(),
275 operation: Some(operation),
276 })?;
277 }
278 anyhow::Ok(())
279 })
280 .detach_and_log_err(cx);
281 }
282 }
283
284 this.advertise_contexts(cx);
285
286 anyhow::Ok(proto::SynchronizeContextsResponse {
287 contexts: local_versions,
288 })
289 })?
290 }
291
292 fn handle_project_changed(&mut self, _: Entity<Project>, cx: &mut Context<Self>) {
293 let is_shared = self.project.read(cx).is_shared();
294 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
295 if is_shared == was_shared {
296 return;
297 }
298
299 if is_shared {
300 self.contexts.retain_mut(|context| {
301 if let Some(strong_context) = context.upgrade() {
302 *context = ContextHandle::Strong(strong_context);
303 true
304 } else {
305 false
306 }
307 });
308 let remote_id = self.project.read(cx).remote_id().unwrap();
309 self.client_subscription = self
310 .client
311 .subscribe_to_entity(remote_id)
312 .log_err()
313 .map(|subscription| subscription.set_model(&cx.entity(), &mut cx.to_async()));
314 self.advertise_contexts(cx);
315 } else {
316 self.client_subscription = None;
317 }
318 }
319
320 fn handle_project_event(
321 &mut self,
322 _: Entity<Project>,
323 event: &project::Event,
324 cx: &mut Context<Self>,
325 ) {
326 match event {
327 project::Event::Reshared => {
328 self.advertise_contexts(cx);
329 }
330 project::Event::HostReshared | project::Event::Rejoined => {
331 self.synchronize_contexts(cx);
332 }
333 project::Event::DisconnectedFromHost => {
334 self.contexts.retain_mut(|context| {
335 if let Some(strong_context) = context.upgrade() {
336 *context = ContextHandle::Weak(context.downgrade());
337 strong_context.update(cx, |context, cx| {
338 if context.replica_id() != ReplicaId::default() {
339 context.set_capability(language::Capability::ReadOnly, cx);
340 }
341 });
342 true
343 } else {
344 false
345 }
346 });
347 self.host_contexts.clear();
348 cx.notify();
349 }
350 _ => {}
351 }
352 }
353
354 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<AssistantContext> {
355 let context = cx.new(|cx| {
356 AssistantContext::local(
357 self.languages.clone(),
358 Some(self.project.clone()),
359 Some(self.telemetry.clone()),
360 self.prompt_builder.clone(),
361 self.slash_commands.clone(),
362 cx,
363 )
364 });
365 self.register_context(&context, cx);
366 context
367 }
368
369 pub fn create_remote_context(
370 &mut self,
371 cx: &mut Context<Self>,
372 ) -> Task<Result<Entity<AssistantContext>>> {
373 let project = self.project.read(cx);
374 let Some(project_id) = project.remote_id() else {
375 return Task::ready(Err(anyhow!("project was not remote")));
376 };
377
378 let replica_id = project.replica_id();
379 let capability = project.capability();
380 let language_registry = self.languages.clone();
381 let project = self.project.clone();
382 let telemetry = self.telemetry.clone();
383 let prompt_builder = self.prompt_builder.clone();
384 let slash_commands = self.slash_commands.clone();
385 let request = self.client.request(proto::CreateContext { project_id });
386 cx.spawn(|this, mut cx| async move {
387 let response = request.await?;
388 let context_id = ContextId::from_proto(response.context_id);
389 let context_proto = response.context.context("invalid context")?;
390 let context = cx.new(|cx| {
391 AssistantContext::new(
392 context_id.clone(),
393 replica_id,
394 capability,
395 language_registry,
396 prompt_builder,
397 slash_commands,
398 Some(project),
399 Some(telemetry),
400 cx,
401 )
402 })?;
403 let operations = cx
404 .background_executor()
405 .spawn(async move {
406 context_proto
407 .operations
408 .into_iter()
409 .map(ContextOperation::from_proto)
410 .collect::<Result<Vec<_>>>()
411 })
412 .await?;
413 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
414 this.update(&mut cx, |this, cx| {
415 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
416 existing_context
417 } else {
418 this.register_context(&context, cx);
419 this.synchronize_contexts(cx);
420 context
421 }
422 })
423 })
424 }
425
426 pub fn open_local_context(
427 &mut self,
428 path: PathBuf,
429 cx: &Context<Self>,
430 ) -> Task<Result<Entity<AssistantContext>>> {
431 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
432 return Task::ready(Ok(existing_context));
433 }
434
435 let fs = self.fs.clone();
436 let languages = self.languages.clone();
437 let project = self.project.clone();
438 let telemetry = self.telemetry.clone();
439 let load = cx.background_executor().spawn({
440 let path = path.clone();
441 async move {
442 let saved_context = fs.load(&path).await?;
443 SavedContext::from_json(&saved_context)
444 }
445 });
446 let prompt_builder = self.prompt_builder.clone();
447 let slash_commands = self.slash_commands.clone();
448
449 cx.spawn(|this, mut cx| async move {
450 let saved_context = load.await?;
451 let context = cx.new(|cx| {
452 AssistantContext::deserialize(
453 saved_context,
454 path.clone(),
455 languages,
456 prompt_builder,
457 slash_commands,
458 Some(project),
459 Some(telemetry),
460 cx,
461 )
462 })?;
463 this.update(&mut cx, |this, cx| {
464 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
465 existing_context
466 } else {
467 this.register_context(&context, cx);
468 context
469 }
470 })
471 })
472 }
473
474 fn loaded_context_for_path(&self, path: &Path, cx: &App) -> Option<Entity<AssistantContext>> {
475 self.contexts.iter().find_map(|context| {
476 let context = context.upgrade()?;
477 if context.read(cx).path() == Some(path) {
478 Some(context)
479 } else {
480 None
481 }
482 })
483 }
484
485 pub fn loaded_context_for_id(
486 &self,
487 id: &ContextId,
488 cx: &App,
489 ) -> Option<Entity<AssistantContext>> {
490 self.contexts.iter().find_map(|context| {
491 let context = context.upgrade()?;
492 if context.read(cx).id() == id {
493 Some(context)
494 } else {
495 None
496 }
497 })
498 }
499
500 pub fn open_remote_context(
501 &mut self,
502 context_id: ContextId,
503 cx: &mut Context<Self>,
504 ) -> Task<Result<Entity<AssistantContext>>> {
505 let project = self.project.read(cx);
506 let Some(project_id) = project.remote_id() else {
507 return Task::ready(Err(anyhow!("project was not remote")));
508 };
509
510 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
511 return Task::ready(Ok(context));
512 }
513
514 let replica_id = project.replica_id();
515 let capability = project.capability();
516 let language_registry = self.languages.clone();
517 let project = self.project.clone();
518 let telemetry = self.telemetry.clone();
519 let request = self.client.request(proto::OpenContext {
520 project_id,
521 context_id: context_id.to_proto(),
522 });
523 let prompt_builder = self.prompt_builder.clone();
524 let slash_commands = self.slash_commands.clone();
525 cx.spawn(|this, mut cx| async move {
526 let response = request.await?;
527 let context_proto = response.context.context("invalid context")?;
528 let context = cx.new(|cx| {
529 AssistantContext::new(
530 context_id.clone(),
531 replica_id,
532 capability,
533 language_registry,
534 prompt_builder,
535 slash_commands,
536 Some(project),
537 Some(telemetry),
538 cx,
539 )
540 })?;
541 let operations = cx
542 .background_executor()
543 .spawn(async move {
544 context_proto
545 .operations
546 .into_iter()
547 .map(ContextOperation::from_proto)
548 .collect::<Result<Vec<_>>>()
549 })
550 .await?;
551 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
552 this.update(&mut cx, |this, cx| {
553 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
554 existing_context
555 } else {
556 this.register_context(&context, cx);
557 this.synchronize_contexts(cx);
558 context
559 }
560 })
561 })
562 }
563
564 fn register_context(&mut self, context: &Entity<AssistantContext>, cx: &mut Context<Self>) {
565 let handle = if self.project_is_shared {
566 ContextHandle::Strong(context.clone())
567 } else {
568 ContextHandle::Weak(context.downgrade())
569 };
570 self.contexts.push(handle);
571 self.advertise_contexts(cx);
572 cx.subscribe(context, Self::handle_context_event).detach();
573 }
574
575 fn handle_context_event(
576 &mut self,
577 context: Entity<AssistantContext>,
578 event: &ContextEvent,
579 cx: &mut Context<Self>,
580 ) {
581 let Some(project_id) = self.project.read(cx).remote_id() else {
582 return;
583 };
584
585 match event {
586 ContextEvent::SummaryChanged => {
587 self.advertise_contexts(cx);
588 }
589 ContextEvent::Operation(operation) => {
590 let context_id = context.read(cx).id().to_proto();
591 let operation = operation.to_proto();
592 self.client
593 .send(proto::UpdateContext {
594 project_id,
595 context_id,
596 operation: Some(operation),
597 })
598 .log_err();
599 }
600 _ => {}
601 }
602 }
603
604 fn advertise_contexts(&self, cx: &App) {
605 let Some(project_id) = self.project.read(cx).remote_id() else {
606 return;
607 };
608
609 // For now, only the host can advertise their open contexts.
610 if self.project.read(cx).is_via_collab() {
611 return;
612 }
613
614 let contexts = self
615 .contexts
616 .iter()
617 .rev()
618 .filter_map(|context| {
619 let context = context.upgrade()?.read(cx);
620 if context.replica_id() == ReplicaId::default() {
621 Some(proto::ContextMetadata {
622 context_id: context.id().to_proto(),
623 summary: context.summary().map(|summary| summary.text.clone()),
624 })
625 } else {
626 None
627 }
628 })
629 .collect();
630 self.client
631 .send(proto::AdvertiseContexts {
632 project_id,
633 contexts,
634 })
635 .ok();
636 }
637
638 fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
639 let Some(project_id) = self.project.read(cx).remote_id() else {
640 return;
641 };
642
643 let contexts = self
644 .contexts
645 .iter()
646 .filter_map(|context| {
647 let context = context.upgrade()?.read(cx);
648 if context.replica_id() != ReplicaId::default() {
649 Some(context.version(cx).to_proto(context.id().clone()))
650 } else {
651 None
652 }
653 })
654 .collect();
655
656 let client = self.client.clone();
657 let request = self.client.request(proto::SynchronizeContexts {
658 project_id,
659 contexts,
660 });
661 cx.spawn(|this, cx| async move {
662 let response = request.await?;
663
664 let mut context_ids = Vec::new();
665 let mut operations = Vec::new();
666 this.read_with(&cx, |this, cx| {
667 for context_version_proto in response.contexts {
668 let context_version = ContextVersion::from_proto(&context_version_proto);
669 let context_id = ContextId::from_proto(context_version_proto.context_id);
670 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
671 context_ids.push(context_id);
672 operations.push(context.read(cx).serialize_ops(&context_version, cx));
673 }
674 }
675 })?;
676
677 let operations = futures::future::join_all(operations).await;
678 for (context_id, operations) in context_ids.into_iter().zip(operations) {
679 for operation in operations {
680 client.send(proto::UpdateContext {
681 project_id,
682 context_id: context_id.to_proto(),
683 operation: Some(operation),
684 })?;
685 }
686 }
687
688 anyhow::Ok(())
689 })
690 .detach_and_log_err(cx);
691 }
692
693 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedContextMetadata>> {
694 let metadata = self.contexts_metadata.clone();
695 let executor = cx.background_executor().clone();
696 cx.background_executor().spawn(async move {
697 if query.is_empty() {
698 metadata
699 } else {
700 let candidates = metadata
701 .iter()
702 .enumerate()
703 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
704 .collect::<Vec<_>>();
705 let matches = fuzzy::match_strings(
706 &candidates,
707 &query,
708 false,
709 100,
710 &Default::default(),
711 executor,
712 )
713 .await;
714
715 matches
716 .into_iter()
717 .map(|mat| metadata[mat.candidate_id].clone())
718 .collect()
719 }
720 })
721 }
722
723 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
724 &self.host_contexts
725 }
726
727 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
728 let fs = self.fs.clone();
729 cx.spawn(|this, mut cx| async move {
730 fs.create_dir(contexts_dir()).await?;
731
732 let mut paths = fs.read_dir(contexts_dir()).await?;
733 let mut contexts = Vec::<SavedContextMetadata>::new();
734 while let Some(path) = paths.next().await {
735 let path = path?;
736 if path.extension() != Some(OsStr::new("json")) {
737 continue;
738 }
739
740 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
741 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
742
743 let metadata = fs.metadata(&path).await?;
744 if let Some((file_name, metadata)) = path
745 .file_name()
746 .and_then(|name| name.to_str())
747 .zip(metadata)
748 {
749 // This is used to filter out contexts saved by the new assistant.
750 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
751 continue;
752 }
753
754 if let Some(title) = ASSISTANT_CONTEXT_REGEX
755 .replace(file_name, "")
756 .lines()
757 .next()
758 {
759 contexts.push(SavedContextMetadata {
760 title: title.to_string(),
761 path,
762 mtime: metadata.mtime.timestamp_for_user().into(),
763 });
764 }
765 }
766 }
767 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
768
769 this.update(&mut cx, |this, cx| {
770 this.contexts_metadata = contexts;
771 cx.notify();
772 })
773 })
774 }
775
776 pub fn restart_context_servers(&mut self, cx: &mut Context<Self>) {
777 cx.update_entity(
778 &self.context_server_manager,
779 |context_server_manager, cx| {
780 for server in context_server_manager.servers() {
781 context_server_manager
782 .restart_server(&server.id(), cx)
783 .detach_and_log_err(cx);
784 }
785 },
786 );
787 }
788
789 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
790 cx.subscribe(
791 &self.context_server_manager.clone(),
792 Self::handle_context_server_event,
793 )
794 .detach();
795 }
796
797 fn handle_context_server_event(
798 &mut self,
799 context_server_manager: Entity<ContextServerManager>,
800 event: &context_server::manager::Event,
801 cx: &mut Context<Self>,
802 ) {
803 let slash_command_working_set = self.slash_commands.clone();
804 match event {
805 context_server::manager::Event::ServerStarted { server_id } => {
806 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
807 let context_server_manager = context_server_manager.clone();
808 cx.spawn({
809 let server = server.clone();
810 let server_id = server_id.clone();
811 |this, mut cx| async move {
812 let Some(protocol) = server.client() else {
813 return;
814 };
815
816 if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
817 if let Some(prompts) = protocol.list_prompts().await.log_err() {
818 let slash_command_ids = prompts
819 .into_iter()
820 .filter(assistant_slash_commands::acceptable_prompt)
821 .map(|prompt| {
822 log::info!(
823 "registering context server command: {:?}",
824 prompt.name
825 );
826 slash_command_working_set.insert(Arc::new(
827 assistant_slash_commands::ContextServerSlashCommand::new(
828 context_server_manager.clone(),
829 &server,
830 prompt,
831 ),
832 ))
833 })
834 .collect::<Vec<_>>();
835
836 this.update(&mut cx, |this, _cx| {
837 this.context_server_slash_command_ids
838 .insert(server_id.clone(), slash_command_ids);
839 })
840 .log_err();
841 }
842 }
843 }
844 })
845 .detach();
846 }
847 }
848 context_server::manager::Event::ServerStopped { server_id } => {
849 if let Some(slash_command_ids) =
850 self.context_server_slash_command_ids.remove(server_id)
851 {
852 slash_command_working_set.remove(&slash_command_ids);
853 }
854 }
855 }
856 }
857}