1use crate::{
2 AssistantContext, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
3 SavedContextMetadata,
4};
5use anyhow::{anyhow, Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
8use clock::ReplicaId;
9use collections::HashMap;
10use context_server::manager::ContextServerManager;
11use context_server::ContextServerFactoryRegistry;
12use fs::Fs;
13use futures::StreamExt;
14use fuzzy::StringMatchCandidate;
15use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
16use language::LanguageRegistry;
17use paths::contexts_dir;
18use project::Project;
19use prompt_library::PromptBuilder;
20use regex::Regex;
21use rpc::AnyProtoClient;
22use std::sync::LazyLock;
23use std::{
24 cmp::Reverse,
25 ffi::OsStr,
26 mem,
27 path::{Path, PathBuf},
28 sync::Arc,
29 time::Duration,
30};
31use util::{ResultExt, TryFutureExt};
32
/// Registers the [`ContextStore`] RPC handlers on `client`.
///
/// `AdvertiseContexts` and `UpdateContext` are registered as message handlers,
/// while `OpenContext`, `CreateContext` and `SynchronizeContexts` are registered
/// as request handlers (their handler functions produce a response message).
pub(crate) fn init(client: &AnyProtoClient) {
    client.add_model_message_handler(ContextStore::handle_advertise_contexts);
    client.add_model_request_handler(ContextStore::handle_open_context);
    client.add_model_request_handler(ContextStore::handle_create_context);
    client.add_model_message_handler(ContextStore::handle_update_context);
    client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
}
40
/// Metadata about a context advertised by a remote peer.
///
/// Instances are built in `ContextStore::handle_advertise_contexts` from the
/// entries of an incoming `proto::AdvertiseContexts` message.
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Identifier of the advertised context.
    pub id: ContextId,
    /// Summary text carried by the advertisement, if one was provided.
    pub summary: Option<String>,
}
46
/// Keeps track of the assistant contexts that belong to a project: the ones
/// loaded in memory, the ones saved on disk, and the ones advertised by a
/// remote host.
pub struct ContextStore {
    /// Handles to every context registered through `register_context`.
    /// Handles are strong while the project is shared and weak otherwise.
    contexts: Vec<ContextHandle>,
    /// Metadata of saved contexts found in `contexts_dir()`; replaced wholesale
    /// by `reload`.
    contexts_metadata: Vec<SavedContextMetadata>,
    /// Manager whose `ServerStarted`/`ServerStopped` events are handled by
    /// `handle_context_server_event`.
    context_server_manager: Entity<ContextServerManager>,
    /// Slash commands registered for each context server, keyed by server id,
    /// so they can be removed again when that server stops.
    context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
    /// Contexts most recently advertised via `proto::AdvertiseContexts`;
    /// cleared on `DisconnectedFromHost`.
    host_contexts: Vec<RemoteContextMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    telemetry: Arc<Telemetry>,
    /// Background task that calls `reload` for every file-system event
    /// reported for `contexts_dir()`.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Entity<Project>,
    /// Cached value of `project.is_shared()` as of the last
    /// `handle_project_changed` call.
    project_is_shared: bool,
    /// Client subscription for the project's remote id; `Some` only while the
    /// project is shared.
    client_subscription: Option<client::Subscription>,
    /// Keeps the project observation and event subscription alive.
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
65
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// A new context with the given id was created while handling a
    /// `proto::CreateContext` request.
    ContextCreated(ContextId),
}
69
// Allows `ContextStore` entities to emit `ContextStoreEvent`s via `cx.emit`.
impl EventEmitter<ContextStoreEvent> for ContextStore {}
71
/// A reference to a registered [`AssistantContext`] that is either weak or
/// strong. `ContextStore::register_context` stores a strong handle while the
/// project is shared and a weak handle otherwise.
enum ContextHandle {
    /// Does not keep the context alive; upgrading may fail.
    Weak(WeakEntity<AssistantContext>),
    /// Keeps the context alive for as long as the handle exists.
    Strong(Entity<AssistantContext>),
}
76
77impl ContextHandle {
78 fn upgrade(&self) -> Option<Entity<AssistantContext>> {
79 match self {
80 ContextHandle::Weak(weak) => weak.upgrade(),
81 ContextHandle::Strong(strong) => Some(strong.clone()),
82 }
83 }
84
85 fn downgrade(&self) -> WeakEntity<AssistantContext> {
86 match self {
87 ContextHandle::Weak(weak) => weak.clone(),
88 ContextHandle::Strong(strong) => strong.downgrade(),
89 }
90 }
91}
92
93impl ContextStore {
/// Creates a [`ContextStore`] for `project` on a spawned task.
///
/// The task starts watching [`contexts_dir`], builds the store entity
/// (subscribing to the project and to the context server manager), and then
/// performs an initial [`Self::reload`]. A failure of that initial reload is
/// logged, not returned; errors from creating or updating the entity are
/// propagated through the returned task.
pub fn new(
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    cx: &mut App,
) -> Task<Result<Entity<Self>>> {
    let fs = project.read(cx).fs().clone();
    let languages = project.read(cx).languages().clone();
    let telemetry = project.read(cx).client().telemetry().clone();
    cx.spawn(|mut cx| async move {
        const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
        let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;

        let this = cx.new(|cx: &mut Context<Self>| {
            let context_server_factory_registry =
                ContextServerFactoryRegistry::default_global(cx);
            let context_server_manager = cx.new(|cx| {
                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
            });
            let mut this = Self {
                contexts: Vec::new(),
                contexts_metadata: Vec::new(),
                context_server_manager,
                context_server_slash_command_ids: HashMap::default(),
                host_contexts: Vec::new(),
                fs,
                languages,
                slash_commands,
                telemetry,
                // Reload the saved-context list for every watcher event. The loop
                // ends when the event stream ends or the store can no longer be
                // updated; a failed reload is only logged.
                _watch_updates: cx.spawn(|this, mut cx| {
                    async move {
                        while events.next().await.is_some() {
                            this.update(&mut cx, |this, cx| this.reload(cx))?
                                .await
                                .log_err();
                        }
                        anyhow::Ok(())
                    }
                    .log_err()
                }),
                client_subscription: None,
                _project_subscriptions: vec![
                    cx.observe(&project, Self::handle_project_changed),
                    cx.subscribe(&project, Self::handle_project_event),
                ],
                // Starts out `false`; `handle_project_changed` below picks up the
                // project's actual sharing state.
                project_is_shared: false,
                client: project.read(cx).client(),
                project: project.clone(),
                prompt_builder,
            };
            this.handle_project_changed(project.clone(), cx);
            this.synchronize_contexts(cx);
            this.register_context_server_handlers(cx);
            this
        })?;
        // Populate `contexts_metadata` once up front, without waiting for the
        // first watcher event.
        this.update(&mut cx, |this, cx| this.reload(cx))?
            .await
            .log_err();

        Ok(this)
    })
}
156
157 async fn handle_advertise_contexts(
158 this: Entity<Self>,
159 envelope: TypedEnvelope<proto::AdvertiseContexts>,
160 mut cx: AsyncApp,
161 ) -> Result<()> {
162 this.update(&mut cx, |this, cx| {
163 this.host_contexts = envelope
164 .payload
165 .contexts
166 .into_iter()
167 .map(|context| RemoteContextMetadata {
168 id: ContextId::from_proto(context.context_id),
169 summary: context.summary,
170 })
171 .collect();
172 cx.notify();
173 })
174 }
175
176 async fn handle_open_context(
177 this: Entity<Self>,
178 envelope: TypedEnvelope<proto::OpenContext>,
179 mut cx: AsyncApp,
180 ) -> Result<proto::OpenContextResponse> {
181 let context_id = ContextId::from_proto(envelope.payload.context_id);
182 let operations = this.update(&mut cx, |this, cx| {
183 if this.project.read(cx).is_via_collab() {
184 return Err(anyhow!("only the host contexts can be opened"));
185 }
186
187 let context = this
188 .loaded_context_for_id(&context_id, cx)
189 .context("context not found")?;
190 if context.read(cx).replica_id() != ReplicaId::default() {
191 return Err(anyhow!("context must be opened via the host"));
192 }
193
194 anyhow::Ok(
195 context
196 .read(cx)
197 .serialize_ops(&ContextVersion::default(), cx),
198 )
199 })??;
200 let operations = operations.await;
201 Ok(proto::OpenContextResponse {
202 context: Some(proto::Context { operations }),
203 })
204 }
205
206 async fn handle_create_context(
207 this: Entity<Self>,
208 _: TypedEnvelope<proto::CreateContext>,
209 mut cx: AsyncApp,
210 ) -> Result<proto::CreateContextResponse> {
211 let (context_id, operations) = this.update(&mut cx, |this, cx| {
212 if this.project.read(cx).is_via_collab() {
213 return Err(anyhow!("can only create contexts as the host"));
214 }
215
216 let context = this.create(cx);
217 let context_id = context.read(cx).id().clone();
218 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
219
220 anyhow::Ok((
221 context_id,
222 context
223 .read(cx)
224 .serialize_ops(&ContextVersion::default(), cx),
225 ))
226 })??;
227 let operations = operations.await;
228 Ok(proto::CreateContextResponse {
229 context_id: context_id.to_proto(),
230 context: Some(proto::Context { operations }),
231 })
232 }
233
234 async fn handle_update_context(
235 this: Entity<Self>,
236 envelope: TypedEnvelope<proto::UpdateContext>,
237 mut cx: AsyncApp,
238 ) -> Result<()> {
239 this.update(&mut cx, |this, cx| {
240 let context_id = ContextId::from_proto(envelope.payload.context_id);
241 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
242 let operation_proto = envelope.payload.operation.context("invalid operation")?;
243 let operation = ContextOperation::from_proto(operation_proto)?;
244 context.update(cx, |context, cx| context.apply_ops([operation], cx));
245 }
246 Ok(())
247 })?
248 }
249
250 async fn handle_synchronize_contexts(
251 this: Entity<Self>,
252 envelope: TypedEnvelope<proto::SynchronizeContexts>,
253 mut cx: AsyncApp,
254 ) -> Result<proto::SynchronizeContextsResponse> {
255 this.update(&mut cx, |this, cx| {
256 if this.project.read(cx).is_via_collab() {
257 return Err(anyhow!("only the host can synchronize contexts"));
258 }
259
260 let mut local_versions = Vec::new();
261 for remote_version_proto in envelope.payload.contexts {
262 let remote_version = ContextVersion::from_proto(&remote_version_proto);
263 let context_id = ContextId::from_proto(remote_version_proto.context_id);
264 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
265 let context = context.read(cx);
266 let operations = context.serialize_ops(&remote_version, cx);
267 local_versions.push(context.version(cx).to_proto(context_id.clone()));
268 let client = this.client.clone();
269 let project_id = envelope.payload.project_id;
270 cx.background_executor()
271 .spawn(async move {
272 let operations = operations.await;
273 for operation in operations {
274 client.send(proto::UpdateContext {
275 project_id,
276 context_id: context_id.to_proto(),
277 operation: Some(operation),
278 })?;
279 }
280 anyhow::Ok(())
281 })
282 .detach_and_log_err(cx);
283 }
284 }
285
286 this.advertise_contexts(cx);
287
288 anyhow::Ok(proto::SynchronizeContextsResponse {
289 contexts: local_versions,
290 })
291 })?
292 }
293
294 fn handle_project_changed(&mut self, _: Entity<Project>, cx: &mut Context<Self>) {
295 let is_shared = self.project.read(cx).is_shared();
296 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
297 if is_shared == was_shared {
298 return;
299 }
300
301 if is_shared {
302 self.contexts.retain_mut(|context| {
303 if let Some(strong_context) = context.upgrade() {
304 *context = ContextHandle::Strong(strong_context);
305 true
306 } else {
307 false
308 }
309 });
310 let remote_id = self.project.read(cx).remote_id().unwrap();
311 self.client_subscription = self
312 .client
313 .subscribe_to_entity(remote_id)
314 .log_err()
315 .map(|subscription| subscription.set_model(&cx.entity(), &mut cx.to_async()));
316 self.advertise_contexts(cx);
317 } else {
318 self.client_subscription = None;
319 }
320 }
321
322 fn handle_project_event(
323 &mut self,
324 _: Entity<Project>,
325 event: &project::Event,
326 cx: &mut Context<Self>,
327 ) {
328 match event {
329 project::Event::Reshared => {
330 self.advertise_contexts(cx);
331 }
332 project::Event::HostReshared | project::Event::Rejoined => {
333 self.synchronize_contexts(cx);
334 }
335 project::Event::DisconnectedFromHost => {
336 self.contexts.retain_mut(|context| {
337 if let Some(strong_context) = context.upgrade() {
338 *context = ContextHandle::Weak(context.downgrade());
339 strong_context.update(cx, |context, cx| {
340 if context.replica_id() != ReplicaId::default() {
341 context.set_capability(language::Capability::ReadOnly, cx);
342 }
343 });
344 true
345 } else {
346 false
347 }
348 });
349 self.host_contexts.clear();
350 cx.notify();
351 }
352 _ => {}
353 }
354 }
355
356 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<AssistantContext> {
357 let context = cx.new(|cx| {
358 AssistantContext::local(
359 self.languages.clone(),
360 Some(self.project.clone()),
361 Some(self.telemetry.clone()),
362 self.prompt_builder.clone(),
363 self.slash_commands.clone(),
364 cx,
365 )
366 });
367 self.register_context(&context, cx);
368 context
369 }
370
371 pub fn create_remote_context(
372 &mut self,
373 cx: &mut Context<Self>,
374 ) -> Task<Result<Entity<AssistantContext>>> {
375 let project = self.project.read(cx);
376 let Some(project_id) = project.remote_id() else {
377 return Task::ready(Err(anyhow!("project was not remote")));
378 };
379
380 let replica_id = project.replica_id();
381 let capability = project.capability();
382 let language_registry = self.languages.clone();
383 let project = self.project.clone();
384 let telemetry = self.telemetry.clone();
385 let prompt_builder = self.prompt_builder.clone();
386 let slash_commands = self.slash_commands.clone();
387 let request = self.client.request(proto::CreateContext { project_id });
388 cx.spawn(|this, mut cx| async move {
389 let response = request.await?;
390 let context_id = ContextId::from_proto(response.context_id);
391 let context_proto = response.context.context("invalid context")?;
392 let context = cx.new(|cx| {
393 AssistantContext::new(
394 context_id.clone(),
395 replica_id,
396 capability,
397 language_registry,
398 prompt_builder,
399 slash_commands,
400 Some(project),
401 Some(telemetry),
402 cx,
403 )
404 })?;
405 let operations = cx
406 .background_executor()
407 .spawn(async move {
408 context_proto
409 .operations
410 .into_iter()
411 .map(ContextOperation::from_proto)
412 .collect::<Result<Vec<_>>>()
413 })
414 .await?;
415 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
416 this.update(&mut cx, |this, cx| {
417 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
418 existing_context
419 } else {
420 this.register_context(&context, cx);
421 this.synchronize_contexts(cx);
422 context
423 }
424 })
425 })
426 }
427
428 pub fn open_local_context(
429 &mut self,
430 path: PathBuf,
431 cx: &Context<Self>,
432 ) -> Task<Result<Entity<AssistantContext>>> {
433 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
434 return Task::ready(Ok(existing_context));
435 }
436
437 let fs = self.fs.clone();
438 let languages = self.languages.clone();
439 let project = self.project.clone();
440 let telemetry = self.telemetry.clone();
441 let load = cx.background_executor().spawn({
442 let path = path.clone();
443 async move {
444 let saved_context = fs.load(&path).await?;
445 SavedContext::from_json(&saved_context)
446 }
447 });
448 let prompt_builder = self.prompt_builder.clone();
449 let slash_commands = self.slash_commands.clone();
450
451 cx.spawn(|this, mut cx| async move {
452 let saved_context = load.await?;
453 let context = cx.new(|cx| {
454 AssistantContext::deserialize(
455 saved_context,
456 path.clone(),
457 languages,
458 prompt_builder,
459 slash_commands,
460 Some(project),
461 Some(telemetry),
462 cx,
463 )
464 })?;
465 this.update(&mut cx, |this, cx| {
466 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
467 existing_context
468 } else {
469 this.register_context(&context, cx);
470 context
471 }
472 })
473 })
474 }
475
476 fn loaded_context_for_path(&self, path: &Path, cx: &App) -> Option<Entity<AssistantContext>> {
477 self.contexts.iter().find_map(|context| {
478 let context = context.upgrade()?;
479 if context.read(cx).path() == Some(path) {
480 Some(context)
481 } else {
482 None
483 }
484 })
485 }
486
487 pub fn loaded_context_for_id(
488 &self,
489 id: &ContextId,
490 cx: &App,
491 ) -> Option<Entity<AssistantContext>> {
492 self.contexts.iter().find_map(|context| {
493 let context = context.upgrade()?;
494 if context.read(cx).id() == id {
495 Some(context)
496 } else {
497 None
498 }
499 })
500 }
501
502 pub fn open_remote_context(
503 &mut self,
504 context_id: ContextId,
505 cx: &mut Context<Self>,
506 ) -> Task<Result<Entity<AssistantContext>>> {
507 let project = self.project.read(cx);
508 let Some(project_id) = project.remote_id() else {
509 return Task::ready(Err(anyhow!("project was not remote")));
510 };
511
512 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
513 return Task::ready(Ok(context));
514 }
515
516 let replica_id = project.replica_id();
517 let capability = project.capability();
518 let language_registry = self.languages.clone();
519 let project = self.project.clone();
520 let telemetry = self.telemetry.clone();
521 let request = self.client.request(proto::OpenContext {
522 project_id,
523 context_id: context_id.to_proto(),
524 });
525 let prompt_builder = self.prompt_builder.clone();
526 let slash_commands = self.slash_commands.clone();
527 cx.spawn(|this, mut cx| async move {
528 let response = request.await?;
529 let context_proto = response.context.context("invalid context")?;
530 let context = cx.new(|cx| {
531 AssistantContext::new(
532 context_id.clone(),
533 replica_id,
534 capability,
535 language_registry,
536 prompt_builder,
537 slash_commands,
538 Some(project),
539 Some(telemetry),
540 cx,
541 )
542 })?;
543 let operations = cx
544 .background_executor()
545 .spawn(async move {
546 context_proto
547 .operations
548 .into_iter()
549 .map(ContextOperation::from_proto)
550 .collect::<Result<Vec<_>>>()
551 })
552 .await?;
553 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
554 this.update(&mut cx, |this, cx| {
555 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
556 existing_context
557 } else {
558 this.register_context(&context, cx);
559 this.synchronize_contexts(cx);
560 context
561 }
562 })
563 })
564 }
565
566 fn register_context(&mut self, context: &Entity<AssistantContext>, cx: &mut Context<Self>) {
567 let handle = if self.project_is_shared {
568 ContextHandle::Strong(context.clone())
569 } else {
570 ContextHandle::Weak(context.downgrade())
571 };
572 self.contexts.push(handle);
573 self.advertise_contexts(cx);
574 cx.subscribe(context, Self::handle_context_event).detach();
575 }
576
577 fn handle_context_event(
578 &mut self,
579 context: Entity<AssistantContext>,
580 event: &ContextEvent,
581 cx: &mut Context<Self>,
582 ) {
583 let Some(project_id) = self.project.read(cx).remote_id() else {
584 return;
585 };
586
587 match event {
588 ContextEvent::SummaryChanged => {
589 self.advertise_contexts(cx);
590 }
591 ContextEvent::Operation(operation) => {
592 let context_id = context.read(cx).id().to_proto();
593 let operation = operation.to_proto();
594 self.client
595 .send(proto::UpdateContext {
596 project_id,
597 context_id,
598 operation: Some(operation),
599 })
600 .log_err();
601 }
602 _ => {}
603 }
604 }
605
606 fn advertise_contexts(&self, cx: &App) {
607 let Some(project_id) = self.project.read(cx).remote_id() else {
608 return;
609 };
610
611 // For now, only the host can advertise their open contexts.
612 if self.project.read(cx).is_via_collab() {
613 return;
614 }
615
616 let contexts = self
617 .contexts
618 .iter()
619 .rev()
620 .filter_map(|context| {
621 let context = context.upgrade()?.read(cx);
622 if context.replica_id() == ReplicaId::default() {
623 Some(proto::ContextMetadata {
624 context_id: context.id().to_proto(),
625 summary: context.summary().map(|summary| summary.text.clone()),
626 })
627 } else {
628 None
629 }
630 })
631 .collect();
632 self.client
633 .send(proto::AdvertiseContexts {
634 project_id,
635 contexts,
636 })
637 .ok();
638 }
639
640 fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
641 let Some(project_id) = self.project.read(cx).remote_id() else {
642 return;
643 };
644
645 let contexts = self
646 .contexts
647 .iter()
648 .filter_map(|context| {
649 let context = context.upgrade()?.read(cx);
650 if context.replica_id() != ReplicaId::default() {
651 Some(context.version(cx).to_proto(context.id().clone()))
652 } else {
653 None
654 }
655 })
656 .collect();
657
658 let client = self.client.clone();
659 let request = self.client.request(proto::SynchronizeContexts {
660 project_id,
661 contexts,
662 });
663 cx.spawn(|this, cx| async move {
664 let response = request.await?;
665
666 let mut context_ids = Vec::new();
667 let mut operations = Vec::new();
668 this.read_with(&cx, |this, cx| {
669 for context_version_proto in response.contexts {
670 let context_version = ContextVersion::from_proto(&context_version_proto);
671 let context_id = ContextId::from_proto(context_version_proto.context_id);
672 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
673 context_ids.push(context_id);
674 operations.push(context.read(cx).serialize_ops(&context_version, cx));
675 }
676 }
677 })?;
678
679 let operations = futures::future::join_all(operations).await;
680 for (context_id, operations) in context_ids.into_iter().zip(operations) {
681 for operation in operations {
682 client.send(proto::UpdateContext {
683 project_id,
684 context_id: context_id.to_proto(),
685 operation: Some(operation),
686 })?;
687 }
688 }
689
690 anyhow::Ok(())
691 })
692 .detach_and_log_err(cx);
693 }
694
695 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedContextMetadata>> {
696 let metadata = self.contexts_metadata.clone();
697 let executor = cx.background_executor().clone();
698 cx.background_executor().spawn(async move {
699 if query.is_empty() {
700 metadata
701 } else {
702 let candidates = metadata
703 .iter()
704 .enumerate()
705 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
706 .collect::<Vec<_>>();
707 let matches = fuzzy::match_strings(
708 &candidates,
709 &query,
710 false,
711 100,
712 &Default::default(),
713 executor,
714 )
715 .await;
716
717 matches
718 .into_iter()
719 .map(|mat| metadata[mat.candidate_id].clone())
720 .collect()
721 }
722 })
723 }
724
725 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
726 &self.host_contexts
727 }
728
729 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
730 let fs = self.fs.clone();
731 cx.spawn(|this, mut cx| async move {
732 fs.create_dir(contexts_dir()).await?;
733
734 let mut paths = fs.read_dir(contexts_dir()).await?;
735 let mut contexts = Vec::<SavedContextMetadata>::new();
736 while let Some(path) = paths.next().await {
737 let path = path?;
738 if path.extension() != Some(OsStr::new("json")) {
739 continue;
740 }
741
742 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
743 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
744
745 let metadata = fs.metadata(&path).await?;
746 if let Some((file_name, metadata)) = path
747 .file_name()
748 .and_then(|name| name.to_str())
749 .zip(metadata)
750 {
751 // This is used to filter out contexts saved by the new assistant.
752 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
753 continue;
754 }
755
756 if let Some(title) = ASSISTANT_CONTEXT_REGEX
757 .replace(file_name, "")
758 .lines()
759 .next()
760 {
761 contexts.push(SavedContextMetadata {
762 title: title.to_string(),
763 path,
764 mtime: metadata.mtime.timestamp_for_user().into(),
765 });
766 }
767 }
768 }
769 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
770
771 this.update(&mut cx, |this, cx| {
772 this.contexts_metadata = contexts;
773 cx.notify();
774 })
775 })
776 }
777
778 pub fn restart_context_servers(&mut self, cx: &mut Context<Self>) {
779 cx.update_entity(
780 &self.context_server_manager,
781 |context_server_manager, cx| {
782 for server in context_server_manager.servers() {
783 context_server_manager
784 .restart_server(&server.id(), cx)
785 .detach_and_log_err(cx);
786 }
787 },
788 );
789 }
790
791 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
792 cx.subscribe(
793 &self.context_server_manager.clone(),
794 Self::handle_context_server_event,
795 )
796 .detach();
797 }
798
799 fn handle_context_server_event(
800 &mut self,
801 context_server_manager: Entity<ContextServerManager>,
802 event: &context_server::manager::Event,
803 cx: &mut Context<Self>,
804 ) {
805 let slash_command_working_set = self.slash_commands.clone();
806 match event {
807 context_server::manager::Event::ServerStarted { server_id } => {
808 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
809 let context_server_manager = context_server_manager.clone();
810 cx.spawn({
811 let server = server.clone();
812 let server_id = server_id.clone();
813 |this, mut cx| async move {
814 let Some(protocol) = server.client() else {
815 return;
816 };
817
818 if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
819 if let Some(prompts) = protocol.list_prompts().await.log_err() {
820 let slash_command_ids = prompts
821 .into_iter()
822 .filter(assistant_slash_commands::acceptable_prompt)
823 .map(|prompt| {
824 log::info!(
825 "registering context server command: {:?}",
826 prompt.name
827 );
828 slash_command_working_set.insert(Arc::new(
829 assistant_slash_commands::ContextServerSlashCommand::new(
830 context_server_manager.clone(),
831 &server,
832 prompt,
833 ),
834 ))
835 })
836 .collect::<Vec<_>>();
837
838 this.update(&mut cx, |this, _cx| {
839 this.context_server_slash_command_ids
840 .insert(server_id.clone(), slash_command_ids);
841 })
842 .log_err();
843 }
844 }
845 }
846 })
847 .detach();
848 }
849 }
850 context_server::manager::Event::ServerStopped { server_id } => {
851 if let Some(slash_command_ids) =
852 self.context_server_slash_command_ids.remove(server_id)
853 {
854 slash_command_working_set.remove(&slash_command_ids);
855 }
856 }
857 }
858 }
859}