1use crate::{
2 AssistantContext, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
3 SavedContextMetadata,
4};
5use anyhow::{Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use client::{Client, TypedEnvelope, proto, telemetry::Telemetry};
8use clock::ReplicaId;
9use collections::HashMap;
10use context_server::ContextServerId;
11use fs::{Fs, RemoveOptions};
12use futures::StreamExt;
13use fuzzy::StringMatchCandidate;
14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
15use language::LanguageRegistry;
16use paths::contexts_dir;
17use project::{
18 Project,
19 context_server_store::{ContextServerStatus, ContextServerStore},
20};
21use prompt_store::PromptBuilder;
22use regex::Regex;
23use rpc::AnyProtoClient;
24use std::sync::LazyLock;
25use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
26use util::{ResultExt, TryFutureExt};
27
28pub(crate) fn init(client: &AnyProtoClient) {
29 client.add_entity_message_handler(ContextStore::handle_advertise_contexts);
30 client.add_entity_request_handler(ContextStore::handle_open_context);
31 client.add_entity_request_handler(ContextStore::handle_create_context);
32 client.add_entity_message_handler(ContextStore::handle_update_context);
33 client.add_entity_request_handler(ContextStore::handle_synchronize_contexts);
34}
35
/// Metadata for a context advertised by the project host.
///
/// Built from `proto::AdvertiseContexts` in `ContextStore::handle_advertise_contexts`.
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Identifier of the context on the host.
    pub id: ContextId,
    /// The summary text the host advertised, if it had one.
    pub summary: Option<String>,
}
41
/// Owns the open assistant contexts for a project, the metadata of contexts
/// saved on disk, and the collaboration plumbing used to share contexts with
/// (or open them from) the project host.
// NOTE: field order determines drop order; do not reorder casually.
pub struct ContextStore {
    /// Open contexts. They are held strongly while the project is shared and
    /// weakly otherwise (see `register_context` / `handle_project_shared`).
    contexts: Vec<ContextHandle>,
    /// Saved contexts found in the contexts directory by `reload`, newest mtime first.
    contexts_metadata: Vec<SavedContextMetadata>,
    /// Slash commands registered on behalf of each context server, so they can
    /// be removed when that server stops or errors.
    context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
    /// Contexts advertised by the host; cleared on `DisconnectedFromHost`.
    host_contexts: Vec<RemoteContextMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    telemetry: Arc<Telemetry>,
    /// Task that calls `reload` whenever the contexts-directory watcher fires.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Entity<Project>,
    /// Last observed value of `Project::is_shared`, used to detect transitions.
    project_is_shared: bool,
    /// Entity subscription for incoming RPC messages; `Some` only while shared.
    client_subscription: Option<client::Subscription>,
    /// Keeps the project event subscription alive.
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
59
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// A context was created in this store; emitted by `handle_create_context`
    /// when a collaborator asks the host to create one.
    ContextCreated(ContextId),
}

// Allows observers to `cx.subscribe` to `ContextStoreEvent`s from the store.
impl EventEmitter<ContextStoreEvent> for ContextStore {}
65
/// How the store holds on to an open context.
///
/// `register_context` and `handle_project_shared` use `Strong` while the
/// project is shared; otherwise contexts are stored as `Weak` and drop out of
/// the store once nothing else references them.
enum ContextHandle {
    /// Does not keep the context alive.
    Weak(WeakEntity<AssistantContext>),
    /// Keeps the context alive.
    Strong(Entity<AssistantContext>),
}
70
71impl ContextHandle {
72 fn upgrade(&self) -> Option<Entity<AssistantContext>> {
73 match self {
74 ContextHandle::Weak(weak) => weak.upgrade(),
75 ContextHandle::Strong(strong) => Some(strong.clone()),
76 }
77 }
78
79 fn downgrade(&self) -> WeakEntity<AssistantContext> {
80 match self {
81 ContextHandle::Weak(weak) => weak.clone(),
82 ContextHandle::Strong(strong) => strong.downgrade(),
83 }
84 }
85}
86
87impl ContextStore {
88 pub fn new(
89 project: Entity<Project>,
90 prompt_builder: Arc<PromptBuilder>,
91 slash_commands: Arc<SlashCommandWorkingSet>,
92 cx: &mut App,
93 ) -> Task<Result<Entity<Self>>> {
94 let fs = project.read(cx).fs().clone();
95 let languages = project.read(cx).languages().clone();
96 let telemetry = project.read(cx).client().telemetry().clone();
97 cx.spawn(async move |cx| {
98 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
99 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
100
101 let this = cx.new(|cx: &mut Context<Self>| {
102 let mut this = Self {
103 contexts: Vec::new(),
104 contexts_metadata: Vec::new(),
105 context_server_slash_command_ids: HashMap::default(),
106 host_contexts: Vec::new(),
107 fs,
108 languages,
109 slash_commands,
110 telemetry,
111 _watch_updates: cx.spawn(async move |this, cx| {
112 async move {
113 while events.next().await.is_some() {
114 this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
115 }
116 anyhow::Ok(())
117 }
118 .log_err()
119 .await
120 }),
121 client_subscription: None,
122 _project_subscriptions: vec![
123 cx.subscribe(&project, Self::handle_project_event),
124 ],
125 project_is_shared: false,
126 client: project.read(cx).client(),
127 project: project.clone(),
128 prompt_builder,
129 };
130 this.handle_project_shared(project.clone(), cx);
131 this.synchronize_contexts(cx);
132 this.register_context_server_handlers(cx);
133 this.reload(cx).detach_and_log_err(cx);
134 this
135 })?;
136
137 Ok(this)
138 })
139 }
140
141 async fn handle_advertise_contexts(
142 this: Entity<Self>,
143 envelope: TypedEnvelope<proto::AdvertiseContexts>,
144 mut cx: AsyncApp,
145 ) -> Result<()> {
146 this.update(&mut cx, |this, cx| {
147 this.host_contexts = envelope
148 .payload
149 .contexts
150 .into_iter()
151 .map(|context| RemoteContextMetadata {
152 id: ContextId::from_proto(context.context_id),
153 summary: context.summary,
154 })
155 .collect();
156 cx.notify();
157 })
158 }
159
160 async fn handle_open_context(
161 this: Entity<Self>,
162 envelope: TypedEnvelope<proto::OpenContext>,
163 mut cx: AsyncApp,
164 ) -> Result<proto::OpenContextResponse> {
165 let context_id = ContextId::from_proto(envelope.payload.context_id);
166 let operations = this.update(&mut cx, |this, cx| {
167 anyhow::ensure!(
168 !this.project.read(cx).is_via_collab(),
169 "only the host contexts can be opened"
170 );
171
172 let context = this
173 .loaded_context_for_id(&context_id, cx)
174 .context("context not found")?;
175 anyhow::ensure!(
176 context.read(cx).replica_id() == ReplicaId::default(),
177 "context must be opened via the host"
178 );
179
180 anyhow::Ok(
181 context
182 .read(cx)
183 .serialize_ops(&ContextVersion::default(), cx),
184 )
185 })??;
186 let operations = operations.await;
187 Ok(proto::OpenContextResponse {
188 context: Some(proto::Context { operations }),
189 })
190 }
191
192 async fn handle_create_context(
193 this: Entity<Self>,
194 _: TypedEnvelope<proto::CreateContext>,
195 mut cx: AsyncApp,
196 ) -> Result<proto::CreateContextResponse> {
197 let (context_id, operations) = this.update(&mut cx, |this, cx| {
198 anyhow::ensure!(
199 !this.project.read(cx).is_via_collab(),
200 "can only create contexts as the host"
201 );
202
203 let context = this.create(cx);
204 let context_id = context.read(cx).id().clone();
205 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
206
207 anyhow::Ok((
208 context_id,
209 context
210 .read(cx)
211 .serialize_ops(&ContextVersion::default(), cx),
212 ))
213 })??;
214 let operations = operations.await;
215 Ok(proto::CreateContextResponse {
216 context_id: context_id.to_proto(),
217 context: Some(proto::Context { operations }),
218 })
219 }
220
221 async fn handle_update_context(
222 this: Entity<Self>,
223 envelope: TypedEnvelope<proto::UpdateContext>,
224 mut cx: AsyncApp,
225 ) -> Result<()> {
226 this.update(&mut cx, |this, cx| {
227 let context_id = ContextId::from_proto(envelope.payload.context_id);
228 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
229 let operation_proto = envelope.payload.operation.context("invalid operation")?;
230 let operation = ContextOperation::from_proto(operation_proto)?;
231 context.update(cx, |context, cx| context.apply_ops([operation], cx));
232 }
233 Ok(())
234 })?
235 }
236
237 async fn handle_synchronize_contexts(
238 this: Entity<Self>,
239 envelope: TypedEnvelope<proto::SynchronizeContexts>,
240 mut cx: AsyncApp,
241 ) -> Result<proto::SynchronizeContextsResponse> {
242 this.update(&mut cx, |this, cx| {
243 anyhow::ensure!(
244 !this.project.read(cx).is_via_collab(),
245 "only the host can synchronize contexts"
246 );
247
248 let mut local_versions = Vec::new();
249 for remote_version_proto in envelope.payload.contexts {
250 let remote_version = ContextVersion::from_proto(&remote_version_proto);
251 let context_id = ContextId::from_proto(remote_version_proto.context_id);
252 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
253 let context = context.read(cx);
254 let operations = context.serialize_ops(&remote_version, cx);
255 local_versions.push(context.version(cx).to_proto(context_id.clone()));
256 let client = this.client.clone();
257 let project_id = envelope.payload.project_id;
258 cx.background_spawn(async move {
259 let operations = operations.await;
260 for operation in operations {
261 client.send(proto::UpdateContext {
262 project_id,
263 context_id: context_id.to_proto(),
264 operation: Some(operation),
265 })?;
266 }
267 anyhow::Ok(())
268 })
269 .detach_and_log_err(cx);
270 }
271 }
272
273 this.advertise_contexts(cx);
274
275 anyhow::Ok(proto::SynchronizeContextsResponse {
276 contexts: local_versions,
277 })
278 })?
279 }
280
281 fn handle_project_shared(&mut self, _: Entity<Project>, cx: &mut Context<Self>) {
282 let is_shared = self.project.read(cx).is_shared();
283 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
284 if is_shared == was_shared {
285 return;
286 }
287
288 if is_shared {
289 self.contexts.retain_mut(|context| {
290 if let Some(strong_context) = context.upgrade() {
291 *context = ContextHandle::Strong(strong_context);
292 true
293 } else {
294 false
295 }
296 });
297 let remote_id = self.project.read(cx).remote_id().unwrap();
298 self.client_subscription = self
299 .client
300 .subscribe_to_entity(remote_id)
301 .log_err()
302 .map(|subscription| subscription.set_entity(&cx.entity(), &mut cx.to_async()));
303 self.advertise_contexts(cx);
304 } else {
305 self.client_subscription = None;
306 }
307 }
308
309 fn handle_project_event(
310 &mut self,
311 project: Entity<Project>,
312 event: &project::Event,
313 cx: &mut Context<Self>,
314 ) {
315 match event {
316 project::Event::RemoteIdChanged(_) => {
317 self.handle_project_shared(project, cx);
318 }
319 project::Event::Reshared => {
320 self.advertise_contexts(cx);
321 }
322 project::Event::HostReshared | project::Event::Rejoined => {
323 self.synchronize_contexts(cx);
324 }
325 project::Event::DisconnectedFromHost => {
326 self.contexts.retain_mut(|context| {
327 if let Some(strong_context) = context.upgrade() {
328 *context = ContextHandle::Weak(context.downgrade());
329 strong_context.update(cx, |context, cx| {
330 if context.replica_id() != ReplicaId::default() {
331 context.set_capability(language::Capability::ReadOnly, cx);
332 }
333 });
334 true
335 } else {
336 false
337 }
338 });
339 self.host_contexts.clear();
340 cx.notify();
341 }
342 _ => {}
343 }
344 }
345
346 pub fn unordered_contexts(&self) -> impl Iterator<Item = &SavedContextMetadata> {
347 self.contexts_metadata.iter()
348 }
349
350 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<AssistantContext> {
351 let context = cx.new(|cx| {
352 AssistantContext::local(
353 self.languages.clone(),
354 Some(self.project.clone()),
355 Some(self.telemetry.clone()),
356 self.prompt_builder.clone(),
357 self.slash_commands.clone(),
358 cx,
359 )
360 });
361 self.register_context(&context, cx);
362 context
363 }
364
365 pub fn create_remote_context(
366 &mut self,
367 cx: &mut Context<Self>,
368 ) -> Task<Result<Entity<AssistantContext>>> {
369 let project = self.project.read(cx);
370 let Some(project_id) = project.remote_id() else {
371 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
372 };
373
374 let replica_id = project.replica_id();
375 let capability = project.capability();
376 let language_registry = self.languages.clone();
377 let project = self.project.clone();
378 let telemetry = self.telemetry.clone();
379 let prompt_builder = self.prompt_builder.clone();
380 let slash_commands = self.slash_commands.clone();
381 let request = self.client.request(proto::CreateContext { project_id });
382 cx.spawn(async move |this, cx| {
383 let response = request.await?;
384 let context_id = ContextId::from_proto(response.context_id);
385 let context_proto = response.context.context("invalid context")?;
386 let context = cx.new(|cx| {
387 AssistantContext::new(
388 context_id.clone(),
389 replica_id,
390 capability,
391 language_registry,
392 prompt_builder,
393 slash_commands,
394 Some(project),
395 Some(telemetry),
396 cx,
397 )
398 })?;
399 let operations = cx
400 .background_spawn(async move {
401 context_proto
402 .operations
403 .into_iter()
404 .map(ContextOperation::from_proto)
405 .collect::<Result<Vec<_>>>()
406 })
407 .await?;
408 context.update(cx, |context, cx| context.apply_ops(operations, cx))?;
409 this.update(cx, |this, cx| {
410 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
411 existing_context
412 } else {
413 this.register_context(&context, cx);
414 this.synchronize_contexts(cx);
415 context
416 }
417 })
418 })
419 }
420
421 pub fn open_local_context(
422 &mut self,
423 path: Arc<Path>,
424 cx: &Context<Self>,
425 ) -> Task<Result<Entity<AssistantContext>>> {
426 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
427 return Task::ready(Ok(existing_context));
428 }
429
430 let fs = self.fs.clone();
431 let languages = self.languages.clone();
432 let project = self.project.clone();
433 let telemetry = self.telemetry.clone();
434 let load = cx.background_spawn({
435 let path = path.clone();
436 async move {
437 let saved_context = fs.load(&path).await?;
438 SavedContext::from_json(&saved_context)
439 }
440 });
441 let prompt_builder = self.prompt_builder.clone();
442 let slash_commands = self.slash_commands.clone();
443
444 cx.spawn(async move |this, cx| {
445 let saved_context = load.await?;
446 let context = cx.new(|cx| {
447 AssistantContext::deserialize(
448 saved_context,
449 path.clone(),
450 languages,
451 prompt_builder,
452 slash_commands,
453 Some(project),
454 Some(telemetry),
455 cx,
456 )
457 })?;
458 this.update(cx, |this, cx| {
459 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
460 existing_context
461 } else {
462 this.register_context(&context, cx);
463 context
464 }
465 })
466 })
467 }
468
469 pub fn delete_local_context(
470 &mut self,
471 path: Arc<Path>,
472 cx: &mut Context<Self>,
473 ) -> Task<Result<()>> {
474 let fs = self.fs.clone();
475
476 cx.spawn(async move |this, cx| {
477 fs.remove_file(
478 &path,
479 RemoveOptions {
480 recursive: false,
481 ignore_if_not_exists: true,
482 },
483 )
484 .await?;
485
486 this.update(cx, |this, cx| {
487 this.contexts.retain(|context| {
488 context
489 .upgrade()
490 .and_then(|context| context.read(cx).path())
491 != Some(&path)
492 });
493 this.contexts_metadata
494 .retain(|context| context.path.as_ref() != path.as_ref());
495 })?;
496
497 Ok(())
498 })
499 }
500
501 fn loaded_context_for_path(&self, path: &Path, cx: &App) -> Option<Entity<AssistantContext>> {
502 self.contexts.iter().find_map(|context| {
503 let context = context.upgrade()?;
504 if context.read(cx).path().map(Arc::as_ref) == Some(path) {
505 Some(context)
506 } else {
507 None
508 }
509 })
510 }
511
512 pub fn loaded_context_for_id(
513 &self,
514 id: &ContextId,
515 cx: &App,
516 ) -> Option<Entity<AssistantContext>> {
517 self.contexts.iter().find_map(|context| {
518 let context = context.upgrade()?;
519 if context.read(cx).id() == id {
520 Some(context)
521 } else {
522 None
523 }
524 })
525 }
526
527 pub fn open_remote_context(
528 &mut self,
529 context_id: ContextId,
530 cx: &mut Context<Self>,
531 ) -> Task<Result<Entity<AssistantContext>>> {
532 let project = self.project.read(cx);
533 let Some(project_id) = project.remote_id() else {
534 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
535 };
536
537 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
538 return Task::ready(Ok(context));
539 }
540
541 let replica_id = project.replica_id();
542 let capability = project.capability();
543 let language_registry = self.languages.clone();
544 let project = self.project.clone();
545 let telemetry = self.telemetry.clone();
546 let request = self.client.request(proto::OpenContext {
547 project_id,
548 context_id: context_id.to_proto(),
549 });
550 let prompt_builder = self.prompt_builder.clone();
551 let slash_commands = self.slash_commands.clone();
552 cx.spawn(async move |this, cx| {
553 let response = request.await?;
554 let context_proto = response.context.context("invalid context")?;
555 let context = cx.new(|cx| {
556 AssistantContext::new(
557 context_id.clone(),
558 replica_id,
559 capability,
560 language_registry,
561 prompt_builder,
562 slash_commands,
563 Some(project),
564 Some(telemetry),
565 cx,
566 )
567 })?;
568 let operations = cx
569 .background_spawn(async move {
570 context_proto
571 .operations
572 .into_iter()
573 .map(ContextOperation::from_proto)
574 .collect::<Result<Vec<_>>>()
575 })
576 .await?;
577 context.update(cx, |context, cx| context.apply_ops(operations, cx))?;
578 this.update(cx, |this, cx| {
579 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
580 existing_context
581 } else {
582 this.register_context(&context, cx);
583 this.synchronize_contexts(cx);
584 context
585 }
586 })
587 })
588 }
589
590 fn register_context(&mut self, context: &Entity<AssistantContext>, cx: &mut Context<Self>) {
591 let handle = if self.project_is_shared {
592 ContextHandle::Strong(context.clone())
593 } else {
594 ContextHandle::Weak(context.downgrade())
595 };
596 self.contexts.push(handle);
597 self.advertise_contexts(cx);
598 cx.subscribe(context, Self::handle_context_event).detach();
599 }
600
601 fn handle_context_event(
602 &mut self,
603 context: Entity<AssistantContext>,
604 event: &ContextEvent,
605 cx: &mut Context<Self>,
606 ) {
607 let Some(project_id) = self.project.read(cx).remote_id() else {
608 return;
609 };
610
611 match event {
612 ContextEvent::SummaryChanged => {
613 self.advertise_contexts(cx);
614 }
615 ContextEvent::PathChanged { old_path, new_path } => {
616 if let Some(old_path) = old_path.as_ref() {
617 for metadata in &mut self.contexts_metadata {
618 if &metadata.path == old_path {
619 metadata.path = new_path.clone();
620 break;
621 }
622 }
623 }
624 }
625 ContextEvent::Operation(operation) => {
626 let context_id = context.read(cx).id().to_proto();
627 let operation = operation.to_proto();
628 self.client
629 .send(proto::UpdateContext {
630 project_id,
631 context_id,
632 operation: Some(operation),
633 })
634 .log_err();
635 }
636 _ => {}
637 }
638 }
639
640 fn advertise_contexts(&self, cx: &App) {
641 let Some(project_id) = self.project.read(cx).remote_id() else {
642 return;
643 };
644
645 // For now, only the host can advertise their open contexts.
646 if self.project.read(cx).is_via_collab() {
647 return;
648 }
649
650 let contexts = self
651 .contexts
652 .iter()
653 .rev()
654 .filter_map(|context| {
655 let context = context.upgrade()?.read(cx);
656 if context.replica_id() == ReplicaId::default() {
657 Some(proto::ContextMetadata {
658 context_id: context.id().to_proto(),
659 summary: context
660 .summary()
661 .content()
662 .map(|summary| summary.text.clone()),
663 })
664 } else {
665 None
666 }
667 })
668 .collect();
669 self.client
670 .send(proto::AdvertiseContexts {
671 project_id,
672 contexts,
673 })
674 .ok();
675 }
676
677 fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
678 let Some(project_id) = self.project.read(cx).remote_id() else {
679 return;
680 };
681
682 let contexts = self
683 .contexts
684 .iter()
685 .filter_map(|context| {
686 let context = context.upgrade()?.read(cx);
687 if context.replica_id() != ReplicaId::default() {
688 Some(context.version(cx).to_proto(context.id().clone()))
689 } else {
690 None
691 }
692 })
693 .collect();
694
695 let client = self.client.clone();
696 let request = self.client.request(proto::SynchronizeContexts {
697 project_id,
698 contexts,
699 });
700 cx.spawn(async move |this, cx| {
701 let response = request.await?;
702
703 let mut context_ids = Vec::new();
704 let mut operations = Vec::new();
705 this.read_with(cx, |this, cx| {
706 for context_version_proto in response.contexts {
707 let context_version = ContextVersion::from_proto(&context_version_proto);
708 let context_id = ContextId::from_proto(context_version_proto.context_id);
709 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
710 context_ids.push(context_id);
711 operations.push(context.read(cx).serialize_ops(&context_version, cx));
712 }
713 }
714 })?;
715
716 let operations = futures::future::join_all(operations).await;
717 for (context_id, operations) in context_ids.into_iter().zip(operations) {
718 for operation in operations {
719 client.send(proto::UpdateContext {
720 project_id,
721 context_id: context_id.to_proto(),
722 operation: Some(operation),
723 })?;
724 }
725 }
726
727 anyhow::Ok(())
728 })
729 .detach_and_log_err(cx);
730 }
731
732 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedContextMetadata>> {
733 let metadata = self.contexts_metadata.clone();
734 let executor = cx.background_executor().clone();
735 cx.background_spawn(async move {
736 if query.is_empty() {
737 metadata
738 } else {
739 let candidates = metadata
740 .iter()
741 .enumerate()
742 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
743 .collect::<Vec<_>>();
744 let matches = fuzzy::match_strings(
745 &candidates,
746 &query,
747 false,
748 true,
749 100,
750 &Default::default(),
751 executor,
752 )
753 .await;
754
755 matches
756 .into_iter()
757 .map(|mat| metadata[mat.candidate_id].clone())
758 .collect()
759 }
760 })
761 }
762
763 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
764 &self.host_contexts
765 }
766
767 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
768 let fs = self.fs.clone();
769 cx.spawn(async move |this, cx| {
770 pub static ZED_STATELESS: LazyLock<bool> =
771 LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
772 if *ZED_STATELESS {
773 return Ok(());
774 }
775 fs.create_dir(contexts_dir()).await?;
776
777 let mut paths = fs.read_dir(contexts_dir()).await?;
778 let mut contexts = Vec::<SavedContextMetadata>::new();
779 while let Some(path) = paths.next().await {
780 let path = path?;
781 if path.extension() != Some(OsStr::new("json")) {
782 continue;
783 }
784
785 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
786 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
787
788 let metadata = fs.metadata(&path).await?;
789 if let Some((file_name, metadata)) = path
790 .file_name()
791 .and_then(|name| name.to_str())
792 .zip(metadata)
793 {
794 // This is used to filter out contexts saved by the new assistant.
795 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
796 continue;
797 }
798
799 if let Some(title) = ASSISTANT_CONTEXT_REGEX
800 .replace(file_name, "")
801 .lines()
802 .next()
803 {
804 contexts.push(SavedContextMetadata {
805 title: title.to_string().into(),
806 path: path.into(),
807 mtime: metadata.mtime.timestamp_for_user().into(),
808 });
809 }
810 }
811 }
812 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
813
814 this.update(cx, |this, cx| {
815 this.contexts_metadata = contexts;
816 cx.notify();
817 })
818 })
819 }
820
821 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
822 let context_server_store = self.project.read(cx).context_server_store();
823 cx.subscribe(&context_server_store, Self::handle_context_server_event)
824 .detach();
825
826 // Check for any servers that were already running before the handler was registered
827 for server in context_server_store.read(cx).running_servers() {
828 self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
829 }
830 }
831
832 fn handle_context_server_event(
833 &mut self,
834 context_server_store: Entity<ContextServerStore>,
835 event: &project::context_server_store::Event,
836 cx: &mut Context<Self>,
837 ) {
838 match event {
839 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
840 match status {
841 ContextServerStatus::Running => {
842 self.load_context_server_slash_commands(
843 server_id.clone(),
844 context_server_store.clone(),
845 cx,
846 );
847 }
848 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
849 if let Some(slash_command_ids) =
850 self.context_server_slash_command_ids.remove(server_id)
851 {
852 self.slash_commands.remove(&slash_command_ids);
853 }
854 }
855 _ => {}
856 }
857 }
858 }
859 }
860
861 fn load_context_server_slash_commands(
862 &self,
863 server_id: ContextServerId,
864 context_server_store: Entity<ContextServerStore>,
865 cx: &mut Context<Self>,
866 ) {
867 let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
868 return;
869 };
870 let slash_command_working_set = self.slash_commands.clone();
871 cx.spawn(async move |this, cx| {
872 let Some(protocol) = server.client() else {
873 return;
874 };
875
876 if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
877 if let Some(response) = protocol
878 .request::<context_server::types::requests::PromptsList>(())
879 .await
880 .log_err()
881 {
882 let slash_command_ids = response
883 .prompts
884 .into_iter()
885 .filter(assistant_slash_commands::acceptable_prompt)
886 .map(|prompt| {
887 log::info!("registering context server command: {:?}", prompt.name);
888 slash_command_working_set.insert(Arc::new(
889 assistant_slash_commands::ContextServerSlashCommand::new(
890 context_server_store.clone(),
891 server.id(),
892 prompt,
893 ),
894 ))
895 })
896 .collect::<Vec<_>>();
897
898 this.update(cx, |this, _cx| {
899 this.context_server_slash_command_ids
900 .insert(server_id.clone(), slash_command_ids);
901 })
902 .log_err();
903 }
904 }
905 })
906 .detach();
907 }
908}