1use crate::{
2 AssistantContext, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
3 SavedContextMetadata,
4};
5use anyhow::{anyhow, Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
8use clock::ReplicaId;
9use collections::HashMap;
10use context_server::manager::ContextServerManager;
11use context_server::ContextServerFactoryRegistry;
12use fs::{Fs, RemoveOptions};
13use futures::StreamExt;
14use fuzzy::StringMatchCandidate;
15use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
16use language::LanguageRegistry;
17use paths::contexts_dir;
18use project::Project;
19use prompt_store::PromptBuilder;
20use regex::Regex;
21use rpc::AnyProtoClient;
22use std::sync::LazyLock;
23use std::{
24 cmp::Reverse,
25 ffi::OsStr,
26 mem,
27 path::{Path, PathBuf},
28 sync::Arc,
29 time::Duration,
30};
31use util::{ResultExt, TryFutureExt};
32
33pub(crate) fn init(client: &AnyProtoClient) {
34 client.add_entity_message_handler(ContextStore::handle_advertise_contexts);
35 client.add_entity_request_handler(ContextStore::handle_open_context);
36 client.add_entity_request_handler(ContextStore::handle_create_context);
37 client.add_entity_message_handler(ContextStore::handle_update_context);
38 client.add_entity_request_handler(ContextStore::handle_synchronize_contexts);
39}
40
41#[derive(Clone)]
42pub struct RemoteContextMetadata {
43 pub id: ContextId,
44 pub summary: Option<String>,
45}
46
47pub struct ContextStore {
48 contexts: Vec<ContextHandle>,
49 contexts_metadata: Vec<SavedContextMetadata>,
50 context_server_manager: Entity<ContextServerManager>,
51 context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
52 host_contexts: Vec<RemoteContextMetadata>,
53 fs: Arc<dyn Fs>,
54 languages: Arc<LanguageRegistry>,
55 slash_commands: Arc<SlashCommandWorkingSet>,
56 telemetry: Arc<Telemetry>,
57 _watch_updates: Task<Option<()>>,
58 client: Arc<Client>,
59 project: Entity<Project>,
60 project_is_shared: bool,
61 client_subscription: Option<client::Subscription>,
62 _project_subscriptions: Vec<gpui::Subscription>,
63 prompt_builder: Arc<PromptBuilder>,
64}
65
66pub enum ContextStoreEvent {
67 ContextCreated(ContextId),
68}
69
70impl EventEmitter<ContextStoreEvent> for ContextStore {}
71
72enum ContextHandle {
73 Weak(WeakEntity<AssistantContext>),
74 Strong(Entity<AssistantContext>),
75}
76
77impl ContextHandle {
78 fn upgrade(&self) -> Option<Entity<AssistantContext>> {
79 match self {
80 ContextHandle::Weak(weak) => weak.upgrade(),
81 ContextHandle::Strong(strong) => Some(strong.clone()),
82 }
83 }
84
85 fn downgrade(&self) -> WeakEntity<AssistantContext> {
86 match self {
87 ContextHandle::Weak(weak) => weak.clone(),
88 ContextHandle::Strong(strong) => strong.downgrade(),
89 }
90 }
91}
92
93impl ContextStore {
94 pub fn new(
95 project: Entity<Project>,
96 prompt_builder: Arc<PromptBuilder>,
97 slash_commands: Arc<SlashCommandWorkingSet>,
98 cx: &mut App,
99 ) -> Task<Result<Entity<Self>>> {
100 let fs = project.read(cx).fs().clone();
101 let languages = project.read(cx).languages().clone();
102 let telemetry = project.read(cx).client().telemetry().clone();
103 cx.spawn(async move |cx| {
104 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
105 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
106
107 let this =
108 cx.new(|cx: &mut Context<Self>| {
109 let context_server_factory_registry =
110 ContextServerFactoryRegistry::default_global(cx);
111 let context_server_manager = cx.new(|cx| {
112 ContextServerManager::new(
113 context_server_factory_registry,
114 project.clone(),
115 cx,
116 )
117 });
118 let mut this = Self {
119 contexts: Vec::new(),
120 contexts_metadata: Vec::new(),
121 context_server_manager,
122 context_server_slash_command_ids: HashMap::default(),
123 host_contexts: Vec::new(),
124 fs,
125 languages,
126 slash_commands,
127 telemetry,
128 _watch_updates: cx.spawn(async move |this, cx| {
129 async move {
130 while events.next().await.is_some() {
131 this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
132 }
133 anyhow::Ok(())
134 }
135 .log_err()
136 .await
137 }),
138 client_subscription: None,
139 _project_subscriptions: vec![
140 cx.subscribe(&project, Self::handle_project_event)
141 ],
142 project_is_shared: false,
143 client: project.read(cx).client(),
144 project: project.clone(),
145 prompt_builder,
146 };
147 this.handle_project_shared(project.clone(), cx);
148 this.synchronize_contexts(cx);
149 this.register_context_server_handlers(cx);
150 this.reload(cx).detach_and_log_err(cx);
151 this
152 })?;
153
154 Ok(this)
155 })
156 }
157
158 async fn handle_advertise_contexts(
159 this: Entity<Self>,
160 envelope: TypedEnvelope<proto::AdvertiseContexts>,
161 mut cx: AsyncApp,
162 ) -> Result<()> {
163 this.update(&mut cx, |this, cx| {
164 this.host_contexts = envelope
165 .payload
166 .contexts
167 .into_iter()
168 .map(|context| RemoteContextMetadata {
169 id: ContextId::from_proto(context.context_id),
170 summary: context.summary,
171 })
172 .collect();
173 cx.notify();
174 })
175 }
176
177 async fn handle_open_context(
178 this: Entity<Self>,
179 envelope: TypedEnvelope<proto::OpenContext>,
180 mut cx: AsyncApp,
181 ) -> Result<proto::OpenContextResponse> {
182 let context_id = ContextId::from_proto(envelope.payload.context_id);
183 let operations = this.update(&mut cx, |this, cx| {
184 if this.project.read(cx).is_via_collab() {
185 return Err(anyhow!("only the host contexts can be opened"));
186 }
187
188 let context = this
189 .loaded_context_for_id(&context_id, cx)
190 .context("context not found")?;
191 if context.read(cx).replica_id() != ReplicaId::default() {
192 return Err(anyhow!("context must be opened via the host"));
193 }
194
195 anyhow::Ok(
196 context
197 .read(cx)
198 .serialize_ops(&ContextVersion::default(), cx),
199 )
200 })??;
201 let operations = operations.await;
202 Ok(proto::OpenContextResponse {
203 context: Some(proto::Context { operations }),
204 })
205 }
206
207 async fn handle_create_context(
208 this: Entity<Self>,
209 _: TypedEnvelope<proto::CreateContext>,
210 mut cx: AsyncApp,
211 ) -> Result<proto::CreateContextResponse> {
212 let (context_id, operations) = this.update(&mut cx, |this, cx| {
213 if this.project.read(cx).is_via_collab() {
214 return Err(anyhow!("can only create contexts as the host"));
215 }
216
217 let context = this.create(cx);
218 let context_id = context.read(cx).id().clone();
219 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
220
221 anyhow::Ok((
222 context_id,
223 context
224 .read(cx)
225 .serialize_ops(&ContextVersion::default(), cx),
226 ))
227 })??;
228 let operations = operations.await;
229 Ok(proto::CreateContextResponse {
230 context_id: context_id.to_proto(),
231 context: Some(proto::Context { operations }),
232 })
233 }
234
235 async fn handle_update_context(
236 this: Entity<Self>,
237 envelope: TypedEnvelope<proto::UpdateContext>,
238 mut cx: AsyncApp,
239 ) -> Result<()> {
240 this.update(&mut cx, |this, cx| {
241 let context_id = ContextId::from_proto(envelope.payload.context_id);
242 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
243 let operation_proto = envelope.payload.operation.context("invalid operation")?;
244 let operation = ContextOperation::from_proto(operation_proto)?;
245 context.update(cx, |context, cx| context.apply_ops([operation], cx));
246 }
247 Ok(())
248 })?
249 }
250
251 async fn handle_synchronize_contexts(
252 this: Entity<Self>,
253 envelope: TypedEnvelope<proto::SynchronizeContexts>,
254 mut cx: AsyncApp,
255 ) -> Result<proto::SynchronizeContextsResponse> {
256 this.update(&mut cx, |this, cx| {
257 if this.project.read(cx).is_via_collab() {
258 return Err(anyhow!("only the host can synchronize contexts"));
259 }
260
261 let mut local_versions = Vec::new();
262 for remote_version_proto in envelope.payload.contexts {
263 let remote_version = ContextVersion::from_proto(&remote_version_proto);
264 let context_id = ContextId::from_proto(remote_version_proto.context_id);
265 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
266 let context = context.read(cx);
267 let operations = context.serialize_ops(&remote_version, cx);
268 local_versions.push(context.version(cx).to_proto(context_id.clone()));
269 let client = this.client.clone();
270 let project_id = envelope.payload.project_id;
271 cx.background_spawn(async move {
272 let operations = operations.await;
273 for operation in operations {
274 client.send(proto::UpdateContext {
275 project_id,
276 context_id: context_id.to_proto(),
277 operation: Some(operation),
278 })?;
279 }
280 anyhow::Ok(())
281 })
282 .detach_and_log_err(cx);
283 }
284 }
285
286 this.advertise_contexts(cx);
287
288 anyhow::Ok(proto::SynchronizeContextsResponse {
289 contexts: local_versions,
290 })
291 })?
292 }
293
294 fn handle_project_shared(&mut self, _: Entity<Project>, cx: &mut Context<Self>) {
295 let is_shared = self.project.read(cx).is_shared();
296 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
297 if is_shared == was_shared {
298 return;
299 }
300
301 if is_shared {
302 self.contexts.retain_mut(|context| {
303 if let Some(strong_context) = context.upgrade() {
304 *context = ContextHandle::Strong(strong_context);
305 true
306 } else {
307 false
308 }
309 });
310 let remote_id = self.project.read(cx).remote_id().unwrap();
311 self.client_subscription = self
312 .client
313 .subscribe_to_entity(remote_id)
314 .log_err()
315 .map(|subscription| subscription.set_entity(&cx.entity(), &mut cx.to_async()));
316 self.advertise_contexts(cx);
317 } else {
318 self.client_subscription = None;
319 }
320 }
321
322 fn handle_project_event(
323 &mut self,
324 project: Entity<Project>,
325 event: &project::Event,
326 cx: &mut Context<Self>,
327 ) {
328 match event {
329 project::Event::RemoteIdChanged(_) => {
330 self.handle_project_shared(project, cx);
331 }
332 project::Event::Reshared => {
333 self.advertise_contexts(cx);
334 }
335 project::Event::HostReshared | project::Event::Rejoined => {
336 self.synchronize_contexts(cx);
337 }
338 project::Event::DisconnectedFromHost => {
339 self.contexts.retain_mut(|context| {
340 if let Some(strong_context) = context.upgrade() {
341 *context = ContextHandle::Weak(context.downgrade());
342 strong_context.update(cx, |context, cx| {
343 if context.replica_id() != ReplicaId::default() {
344 context.set_capability(language::Capability::ReadOnly, cx);
345 }
346 });
347 true
348 } else {
349 false
350 }
351 });
352 self.host_contexts.clear();
353 cx.notify();
354 }
355 _ => {}
356 }
357 }
358
359 pub fn contexts(&self) -> Vec<SavedContextMetadata> {
360 let mut contexts = self.contexts_metadata.iter().cloned().collect::<Vec<_>>();
361 contexts.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.mtime));
362 contexts
363 }
364
365 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<AssistantContext> {
366 let context = cx.new(|cx| {
367 AssistantContext::local(
368 self.languages.clone(),
369 Some(self.project.clone()),
370 Some(self.telemetry.clone()),
371 self.prompt_builder.clone(),
372 self.slash_commands.clone(),
373 cx,
374 )
375 });
376 self.register_context(&context, cx);
377 context
378 }
379
380 pub fn create_remote_context(
381 &mut self,
382 cx: &mut Context<Self>,
383 ) -> Task<Result<Entity<AssistantContext>>> {
384 let project = self.project.read(cx);
385 let Some(project_id) = project.remote_id() else {
386 return Task::ready(Err(anyhow!("project was not remote")));
387 };
388
389 let replica_id = project.replica_id();
390 let capability = project.capability();
391 let language_registry = self.languages.clone();
392 let project = self.project.clone();
393 let telemetry = self.telemetry.clone();
394 let prompt_builder = self.prompt_builder.clone();
395 let slash_commands = self.slash_commands.clone();
396 let request = self.client.request(proto::CreateContext { project_id });
397 cx.spawn(async move |this, cx| {
398 let response = request.await?;
399 let context_id = ContextId::from_proto(response.context_id);
400 let context_proto = response.context.context("invalid context")?;
401 let context = cx.new(|cx| {
402 AssistantContext::new(
403 context_id.clone(),
404 replica_id,
405 capability,
406 language_registry,
407 prompt_builder,
408 slash_commands,
409 Some(project),
410 Some(telemetry),
411 cx,
412 )
413 })?;
414 let operations = cx
415 .background_spawn(async move {
416 context_proto
417 .operations
418 .into_iter()
419 .map(ContextOperation::from_proto)
420 .collect::<Result<Vec<_>>>()
421 })
422 .await?;
423 context.update(cx, |context, cx| context.apply_ops(operations, cx))?;
424 this.update(cx, |this, cx| {
425 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
426 existing_context
427 } else {
428 this.register_context(&context, cx);
429 this.synchronize_contexts(cx);
430 context
431 }
432 })
433 })
434 }
435
436 pub fn open_local_context(
437 &mut self,
438 path: PathBuf,
439 cx: &Context<Self>,
440 ) -> Task<Result<Entity<AssistantContext>>> {
441 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
442 return Task::ready(Ok(existing_context));
443 }
444
445 let fs = self.fs.clone();
446 let languages = self.languages.clone();
447 let project = self.project.clone();
448 let telemetry = self.telemetry.clone();
449 let load = cx.background_spawn({
450 let path = path.clone();
451 async move {
452 let saved_context = fs.load(&path).await?;
453 SavedContext::from_json(&saved_context)
454 }
455 });
456 let prompt_builder = self.prompt_builder.clone();
457 let slash_commands = self.slash_commands.clone();
458
459 cx.spawn(async move |this, cx| {
460 let saved_context = load.await?;
461 let context = cx.new(|cx| {
462 AssistantContext::deserialize(
463 saved_context,
464 path.clone(),
465 languages,
466 prompt_builder,
467 slash_commands,
468 Some(project),
469 Some(telemetry),
470 cx,
471 )
472 })?;
473 this.update(cx, |this, cx| {
474 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
475 existing_context
476 } else {
477 this.register_context(&context, cx);
478 context
479 }
480 })
481 })
482 }
483
484 pub fn delete_local_context(
485 &mut self,
486 path: PathBuf,
487 cx: &mut Context<Self>,
488 ) -> Task<Result<()>> {
489 let fs = self.fs.clone();
490
491 cx.spawn(async move |this, cx| {
492 fs.remove_file(
493 &path,
494 RemoveOptions {
495 recursive: false,
496 ignore_if_not_exists: true,
497 },
498 )
499 .await?;
500
501 this.update(cx, |this, cx| {
502 this.contexts.retain(|context| {
503 context
504 .upgrade()
505 .and_then(|context| context.read(cx).path())
506 != Some(&path)
507 });
508 this.contexts_metadata
509 .retain(|context| context.path != path);
510 })?;
511
512 Ok(())
513 })
514 }
515
516 fn loaded_context_for_path(&self, path: &Path, cx: &App) -> Option<Entity<AssistantContext>> {
517 self.contexts.iter().find_map(|context| {
518 let context = context.upgrade()?;
519 if context.read(cx).path() == Some(path) {
520 Some(context)
521 } else {
522 None
523 }
524 })
525 }
526
527 pub fn loaded_context_for_id(
528 &self,
529 id: &ContextId,
530 cx: &App,
531 ) -> Option<Entity<AssistantContext>> {
532 self.contexts.iter().find_map(|context| {
533 let context = context.upgrade()?;
534 if context.read(cx).id() == id {
535 Some(context)
536 } else {
537 None
538 }
539 })
540 }
541
542 pub fn open_remote_context(
543 &mut self,
544 context_id: ContextId,
545 cx: &mut Context<Self>,
546 ) -> Task<Result<Entity<AssistantContext>>> {
547 let project = self.project.read(cx);
548 let Some(project_id) = project.remote_id() else {
549 return Task::ready(Err(anyhow!("project was not remote")));
550 };
551
552 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
553 return Task::ready(Ok(context));
554 }
555
556 let replica_id = project.replica_id();
557 let capability = project.capability();
558 let language_registry = self.languages.clone();
559 let project = self.project.clone();
560 let telemetry = self.telemetry.clone();
561 let request = self.client.request(proto::OpenContext {
562 project_id,
563 context_id: context_id.to_proto(),
564 });
565 let prompt_builder = self.prompt_builder.clone();
566 let slash_commands = self.slash_commands.clone();
567 cx.spawn(async move |this, cx| {
568 let response = request.await?;
569 let context_proto = response.context.context("invalid context")?;
570 let context = cx.new(|cx| {
571 AssistantContext::new(
572 context_id.clone(),
573 replica_id,
574 capability,
575 language_registry,
576 prompt_builder,
577 slash_commands,
578 Some(project),
579 Some(telemetry),
580 cx,
581 )
582 })?;
583 let operations = cx
584 .background_spawn(async move {
585 context_proto
586 .operations
587 .into_iter()
588 .map(ContextOperation::from_proto)
589 .collect::<Result<Vec<_>>>()
590 })
591 .await?;
592 context.update(cx, |context, cx| context.apply_ops(operations, cx))?;
593 this.update(cx, |this, cx| {
594 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
595 existing_context
596 } else {
597 this.register_context(&context, cx);
598 this.synchronize_contexts(cx);
599 context
600 }
601 })
602 })
603 }
604
605 fn register_context(&mut self, context: &Entity<AssistantContext>, cx: &mut Context<Self>) {
606 let handle = if self.project_is_shared {
607 ContextHandle::Strong(context.clone())
608 } else {
609 ContextHandle::Weak(context.downgrade())
610 };
611 self.contexts.push(handle);
612 self.advertise_contexts(cx);
613 cx.subscribe(context, Self::handle_context_event).detach();
614 }
615
616 fn handle_context_event(
617 &mut self,
618 context: Entity<AssistantContext>,
619 event: &ContextEvent,
620 cx: &mut Context<Self>,
621 ) {
622 let Some(project_id) = self.project.read(cx).remote_id() else {
623 return;
624 };
625
626 match event {
627 ContextEvent::SummaryChanged => {
628 self.advertise_contexts(cx);
629 }
630 ContextEvent::Operation(operation) => {
631 let context_id = context.read(cx).id().to_proto();
632 let operation = operation.to_proto();
633 self.client
634 .send(proto::UpdateContext {
635 project_id,
636 context_id,
637 operation: Some(operation),
638 })
639 .log_err();
640 }
641 _ => {}
642 }
643 }
644
645 fn advertise_contexts(&self, cx: &App) {
646 let Some(project_id) = self.project.read(cx).remote_id() else {
647 return;
648 };
649
650 // For now, only the host can advertise their open contexts.
651 if self.project.read(cx).is_via_collab() {
652 return;
653 }
654
655 let contexts = self
656 .contexts
657 .iter()
658 .rev()
659 .filter_map(|context| {
660 let context = context.upgrade()?.read(cx);
661 if context.replica_id() == ReplicaId::default() {
662 Some(proto::ContextMetadata {
663 context_id: context.id().to_proto(),
664 summary: context.summary().map(|summary| summary.text.clone()),
665 })
666 } else {
667 None
668 }
669 })
670 .collect();
671 self.client
672 .send(proto::AdvertiseContexts {
673 project_id,
674 contexts,
675 })
676 .ok();
677 }
678
679 fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
680 let Some(project_id) = self.project.read(cx).remote_id() else {
681 return;
682 };
683
684 let contexts = self
685 .contexts
686 .iter()
687 .filter_map(|context| {
688 let context = context.upgrade()?.read(cx);
689 if context.replica_id() != ReplicaId::default() {
690 Some(context.version(cx).to_proto(context.id().clone()))
691 } else {
692 None
693 }
694 })
695 .collect();
696
697 let client = self.client.clone();
698 let request = self.client.request(proto::SynchronizeContexts {
699 project_id,
700 contexts,
701 });
702 cx.spawn(async move |this, cx| {
703 let response = request.await?;
704
705 let mut context_ids = Vec::new();
706 let mut operations = Vec::new();
707 this.read_with(cx, |this, cx| {
708 for context_version_proto in response.contexts {
709 let context_version = ContextVersion::from_proto(&context_version_proto);
710 let context_id = ContextId::from_proto(context_version_proto.context_id);
711 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
712 context_ids.push(context_id);
713 operations.push(context.read(cx).serialize_ops(&context_version, cx));
714 }
715 }
716 })?;
717
718 let operations = futures::future::join_all(operations).await;
719 for (context_id, operations) in context_ids.into_iter().zip(operations) {
720 for operation in operations {
721 client.send(proto::UpdateContext {
722 project_id,
723 context_id: context_id.to_proto(),
724 operation: Some(operation),
725 })?;
726 }
727 }
728
729 anyhow::Ok(())
730 })
731 .detach_and_log_err(cx);
732 }
733
734 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedContextMetadata>> {
735 let metadata = self.contexts_metadata.clone();
736 let executor = cx.background_executor().clone();
737 cx.background_spawn(async move {
738 if query.is_empty() {
739 metadata
740 } else {
741 let candidates = metadata
742 .iter()
743 .enumerate()
744 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
745 .collect::<Vec<_>>();
746 let matches = fuzzy::match_strings(
747 &candidates,
748 &query,
749 false,
750 100,
751 &Default::default(),
752 executor,
753 )
754 .await;
755
756 matches
757 .into_iter()
758 .map(|mat| metadata[mat.candidate_id].clone())
759 .collect()
760 }
761 })
762 }
763
764 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
765 &self.host_contexts
766 }
767
768 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
769 let fs = self.fs.clone();
770 cx.spawn(async move |this, cx| {
771 fs.create_dir(contexts_dir()).await?;
772
773 let mut paths = fs.read_dir(contexts_dir()).await?;
774 let mut contexts = Vec::<SavedContextMetadata>::new();
775 while let Some(path) = paths.next().await {
776 let path = path?;
777 if path.extension() != Some(OsStr::new("json")) {
778 continue;
779 }
780
781 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
782 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
783
784 let metadata = fs.metadata(&path).await?;
785 if let Some((file_name, metadata)) = path
786 .file_name()
787 .and_then(|name| name.to_str())
788 .zip(metadata)
789 {
790 // This is used to filter out contexts saved by the new assistant.
791 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
792 continue;
793 }
794
795 if let Some(title) = ASSISTANT_CONTEXT_REGEX
796 .replace(file_name, "")
797 .lines()
798 .next()
799 {
800 contexts.push(SavedContextMetadata {
801 title: title.to_string(),
802 path,
803 mtime: metadata.mtime.timestamp_for_user().into(),
804 });
805 }
806 }
807 }
808 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
809
810 this.update(cx, |this, cx| {
811 this.contexts_metadata = contexts;
812 cx.notify();
813 })
814 })
815 }
816
817 pub fn restart_context_servers(&mut self, cx: &mut Context<Self>) {
818 cx.update_entity(
819 &self.context_server_manager,
820 |context_server_manager, cx| {
821 for server in context_server_manager.running_servers() {
822 context_server_manager
823 .restart_server(&server.id(), cx)
824 .detach_and_log_err(cx);
825 }
826 },
827 );
828 }
829
830 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
831 cx.subscribe(
832 &self.context_server_manager.clone(),
833 Self::handle_context_server_event,
834 )
835 .detach();
836 }
837
838 fn handle_context_server_event(
839 &mut self,
840 context_server_manager: Entity<ContextServerManager>,
841 event: &context_server::manager::Event,
842 cx: &mut Context<Self>,
843 ) {
844 let slash_command_working_set = self.slash_commands.clone();
845 match event {
846 context_server::manager::Event::ServerStarted { server_id } => {
847 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
848 let context_server_manager = context_server_manager.clone();
849 cx.spawn({
850 let server = server.clone();
851 let server_id = server_id.clone();
852 async move |this, cx| {
853 let Some(protocol) = server.client() else {
854 return;
855 };
856
857 if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
858 if let Some(prompts) = protocol.list_prompts().await.log_err() {
859 let slash_command_ids = prompts
860 .into_iter()
861 .filter(assistant_slash_commands::acceptable_prompt)
862 .map(|prompt| {
863 log::info!(
864 "registering context server command: {:?}",
865 prompt.name
866 );
867 slash_command_working_set.insert(Arc::new(
868 assistant_slash_commands::ContextServerSlashCommand::new(
869 context_server_manager.clone(),
870 &server,
871 prompt,
872 ),
873 ))
874 })
875 .collect::<Vec<_>>();
876
877 this.update( cx, |this, _cx| {
878 this.context_server_slash_command_ids
879 .insert(server_id.clone(), slash_command_ids);
880 })
881 .log_err();
882 }
883 }
884 }
885 })
886 .detach();
887 }
888 }
889 context_server::manager::Event::ServerStopped { server_id } => {
890 if let Some(slash_command_ids) =
891 self.context_server_slash_command_ids.remove(server_id)
892 {
893 slash_command_working_set.remove(&slash_command_ids);
894 }
895 }
896 }
897 }
898}