1use crate::{
2 AssistantContext, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
3 SavedContextMetadata,
4};
5use anyhow::{Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use client::{Client, TypedEnvelope, proto, telemetry::Telemetry};
8use clock::ReplicaId;
9use collections::HashMap;
10use context_server::ContextServerId;
11use fs::{Fs, RemoveOptions};
12use futures::StreamExt;
13use fuzzy::StringMatchCandidate;
14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
15use language::LanguageRegistry;
16use paths::contexts_dir;
17use project::{
18 Project,
19 context_server_store::{ContextServerStatus, ContextServerStore},
20};
21use prompt_store::PromptBuilder;
22use regex::Regex;
23use rpc::AnyProtoClient;
24use std::sync::LazyLock;
25use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
26use util::{ResultExt, TryFutureExt};
27
28pub(crate) fn init(client: &AnyProtoClient) {
29 client.add_entity_message_handler(ContextStore::handle_advertise_contexts);
30 client.add_entity_request_handler(ContextStore::handle_open_context);
31 client.add_entity_request_handler(ContextStore::handle_create_context);
32 client.add_entity_message_handler(ContextStore::handle_update_context);
33 client.add_entity_request_handler(ContextStore::handle_synchronize_contexts);
34}
35
/// Metadata for a context that the host of a shared project has advertised.
///
/// Built from `proto::AdvertiseContexts` payloads in `ContextStore::handle_advertise_contexts`.
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Identifier of the advertised context.
    pub id: ContextId,
    /// The context's summary text, if the host had one when it advertised.
    pub summary: Option<String>,
}
41
/// Tracks the assistant contexts of a project: contexts loaded in memory, metadata
/// for contexts saved on disk, and the collaboration state used to share contexts
/// with (or receive them from) the project's host.
pub struct ContextStore {
    /// Loaded contexts. `register_context` stores strong handles while the project is
    /// shared and weak handles otherwise.
    contexts: Vec<ContextHandle>,
    /// Metadata for saved contexts discovered in the contexts directory by `reload`.
    contexts_metadata: Vec<SavedContextMetadata>,
    /// Slash command ids registered per context server, removed again when that
    /// server stops or errors.
    context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
    /// Contexts most recently advertised by the host; cleared on disconnect.
    host_contexts: Vec<RemoteContextMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    telemetry: Arc<Telemetry>,
    /// Background task that reloads metadata when the contexts directory changes.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Entity<Project>,
    /// The project's shared state as last observed by `handle_project_shared`.
    project_is_shared: bool,
    /// Subscription to messages for the project's remote id; set when the project
    /// becomes shared and cleared when it stops being shared.
    client_subscription: Option<client::Subscription>,
    /// Keeps the subscription to project events alive.
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
59
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// A new context was created. In this file it is emitted by
    /// `handle_create_context` after a collaborator's `CreateContext` request.
    ContextCreated(ContextId),
}

impl EventEmitter<ContextStoreEvent> for ContextStore {}
65
66enum ContextHandle {
67 Weak(WeakEntity<AssistantContext>),
68 Strong(Entity<AssistantContext>),
69}
70
71impl ContextHandle {
72 fn upgrade(&self) -> Option<Entity<AssistantContext>> {
73 match self {
74 ContextHandle::Weak(weak) => weak.upgrade(),
75 ContextHandle::Strong(strong) => Some(strong.clone()),
76 }
77 }
78
79 fn downgrade(&self) -> WeakEntity<AssistantContext> {
80 match self {
81 ContextHandle::Weak(weak) => weak.clone(),
82 ContextHandle::Strong(strong) => strong.downgrade(),
83 }
84 }
85}
86
87impl ContextStore {
88 pub fn new(
89 project: Entity<Project>,
90 prompt_builder: Arc<PromptBuilder>,
91 slash_commands: Arc<SlashCommandWorkingSet>,
92 cx: &mut App,
93 ) -> Task<Result<Entity<Self>>> {
94 let fs = project.read(cx).fs().clone();
95 let languages = project.read(cx).languages().clone();
96 let telemetry = project.read(cx).client().telemetry().clone();
97 cx.spawn(async move |cx| {
98 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
99 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
100
101 let this = cx.new(|cx: &mut Context<Self>| {
102 let mut this = Self {
103 contexts: Vec::new(),
104 contexts_metadata: Vec::new(),
105 context_server_slash_command_ids: HashMap::default(),
106 host_contexts: Vec::new(),
107 fs,
108 languages,
109 slash_commands,
110 telemetry,
111 _watch_updates: cx.spawn(async move |this, cx| {
112 async move {
113 while events.next().await.is_some() {
114 this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
115 }
116 anyhow::Ok(())
117 }
118 .log_err()
119 .await
120 }),
121 client_subscription: None,
122 _project_subscriptions: vec![
123 cx.subscribe(&project, Self::handle_project_event),
124 ],
125 project_is_shared: false,
126 client: project.read(cx).client(),
127 project: project.clone(),
128 prompt_builder,
129 };
130 this.handle_project_shared(project.clone(), cx);
131 this.synchronize_contexts(cx);
132 this.register_context_server_handlers(cx);
133 this.reload(cx).detach_and_log_err(cx);
134 this
135 })?;
136
137 Ok(this)
138 })
139 }
140
141 #[cfg(any(test, feature = "test-support"))]
142 pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
143 Self {
144 contexts: Default::default(),
145 contexts_metadata: Default::default(),
146 context_server_slash_command_ids: Default::default(),
147 host_contexts: Default::default(),
148 fs: project.read(cx).fs().clone(),
149 languages: project.read(cx).languages().clone(),
150 slash_commands: Arc::default(),
151 telemetry: project.read(cx).client().telemetry().clone(),
152 _watch_updates: Task::ready(None),
153 client: project.read(cx).client(),
154 project,
155 project_is_shared: false,
156 client_subscription: None,
157 _project_subscriptions: Default::default(),
158 prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
159 }
160 }
161
162 async fn handle_advertise_contexts(
163 this: Entity<Self>,
164 envelope: TypedEnvelope<proto::AdvertiseContexts>,
165 mut cx: AsyncApp,
166 ) -> Result<()> {
167 this.update(&mut cx, |this, cx| {
168 this.host_contexts = envelope
169 .payload
170 .contexts
171 .into_iter()
172 .map(|context| RemoteContextMetadata {
173 id: ContextId::from_proto(context.context_id),
174 summary: context.summary,
175 })
176 .collect();
177 cx.notify();
178 })
179 }
180
181 async fn handle_open_context(
182 this: Entity<Self>,
183 envelope: TypedEnvelope<proto::OpenContext>,
184 mut cx: AsyncApp,
185 ) -> Result<proto::OpenContextResponse> {
186 let context_id = ContextId::from_proto(envelope.payload.context_id);
187 let operations = this.update(&mut cx, |this, cx| {
188 anyhow::ensure!(
189 !this.project.read(cx).is_via_collab(),
190 "only the host contexts can be opened"
191 );
192
193 let context = this
194 .loaded_context_for_id(&context_id, cx)
195 .context("context not found")?;
196 anyhow::ensure!(
197 context.read(cx).replica_id() == ReplicaId::default(),
198 "context must be opened via the host"
199 );
200
201 anyhow::Ok(
202 context
203 .read(cx)
204 .serialize_ops(&ContextVersion::default(), cx),
205 )
206 })??;
207 let operations = operations.await;
208 Ok(proto::OpenContextResponse {
209 context: Some(proto::Context { operations }),
210 })
211 }
212
213 async fn handle_create_context(
214 this: Entity<Self>,
215 _: TypedEnvelope<proto::CreateContext>,
216 mut cx: AsyncApp,
217 ) -> Result<proto::CreateContextResponse> {
218 let (context_id, operations) = this.update(&mut cx, |this, cx| {
219 anyhow::ensure!(
220 !this.project.read(cx).is_via_collab(),
221 "can only create contexts as the host"
222 );
223
224 let context = this.create(cx);
225 let context_id = context.read(cx).id().clone();
226 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
227
228 anyhow::Ok((
229 context_id,
230 context
231 .read(cx)
232 .serialize_ops(&ContextVersion::default(), cx),
233 ))
234 })??;
235 let operations = operations.await;
236 Ok(proto::CreateContextResponse {
237 context_id: context_id.to_proto(),
238 context: Some(proto::Context { operations }),
239 })
240 }
241
242 async fn handle_update_context(
243 this: Entity<Self>,
244 envelope: TypedEnvelope<proto::UpdateContext>,
245 mut cx: AsyncApp,
246 ) -> Result<()> {
247 this.update(&mut cx, |this, cx| {
248 let context_id = ContextId::from_proto(envelope.payload.context_id);
249 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
250 let operation_proto = envelope.payload.operation.context("invalid operation")?;
251 let operation = ContextOperation::from_proto(operation_proto)?;
252 context.update(cx, |context, cx| context.apply_ops([operation], cx));
253 }
254 Ok(())
255 })?
256 }
257
258 async fn handle_synchronize_contexts(
259 this: Entity<Self>,
260 envelope: TypedEnvelope<proto::SynchronizeContexts>,
261 mut cx: AsyncApp,
262 ) -> Result<proto::SynchronizeContextsResponse> {
263 this.update(&mut cx, |this, cx| {
264 anyhow::ensure!(
265 !this.project.read(cx).is_via_collab(),
266 "only the host can synchronize contexts"
267 );
268
269 let mut local_versions = Vec::new();
270 for remote_version_proto in envelope.payload.contexts {
271 let remote_version = ContextVersion::from_proto(&remote_version_proto);
272 let context_id = ContextId::from_proto(remote_version_proto.context_id);
273 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
274 let context = context.read(cx);
275 let operations = context.serialize_ops(&remote_version, cx);
276 local_versions.push(context.version(cx).to_proto(context_id.clone()));
277 let client = this.client.clone();
278 let project_id = envelope.payload.project_id;
279 cx.background_spawn(async move {
280 let operations = operations.await;
281 for operation in operations {
282 client.send(proto::UpdateContext {
283 project_id,
284 context_id: context_id.to_proto(),
285 operation: Some(operation),
286 })?;
287 }
288 anyhow::Ok(())
289 })
290 .detach_and_log_err(cx);
291 }
292 }
293
294 this.advertise_contexts(cx);
295
296 anyhow::Ok(proto::SynchronizeContextsResponse {
297 contexts: local_versions,
298 })
299 })?
300 }
301
302 fn handle_project_shared(&mut self, _: Entity<Project>, cx: &mut Context<Self>) {
303 let is_shared = self.project.read(cx).is_shared();
304 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
305 if is_shared == was_shared {
306 return;
307 }
308
309 if is_shared {
310 self.contexts.retain_mut(|context| {
311 if let Some(strong_context) = context.upgrade() {
312 *context = ContextHandle::Strong(strong_context);
313 true
314 } else {
315 false
316 }
317 });
318 let remote_id = self.project.read(cx).remote_id().unwrap();
319 self.client_subscription = self
320 .client
321 .subscribe_to_entity(remote_id)
322 .log_err()
323 .map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async()));
324 self.advertise_contexts(cx);
325 } else {
326 self.client_subscription = None;
327 }
328 }
329
330 fn handle_project_event(
331 &mut self,
332 project: Entity<Project>,
333 event: &project::Event,
334 cx: &mut Context<Self>,
335 ) {
336 match event {
337 project::Event::RemoteIdChanged(_) => {
338 self.handle_project_shared(project, cx);
339 }
340 project::Event::Reshared => {
341 self.advertise_contexts(cx);
342 }
343 project::Event::HostReshared | project::Event::Rejoined => {
344 self.synchronize_contexts(cx);
345 }
346 project::Event::DisconnectedFromHost => {
347 self.contexts.retain_mut(|context| {
348 if let Some(strong_context) = context.upgrade() {
349 *context = ContextHandle::Weak(context.downgrade());
350 strong_context.update(cx, |context, cx| {
351 if context.replica_id() != ReplicaId::default() {
352 context.set_capability(language::Capability::ReadOnly, cx);
353 }
354 });
355 true
356 } else {
357 false
358 }
359 });
360 self.host_contexts.clear();
361 cx.notify();
362 }
363 _ => {}
364 }
365 }
366
367 pub fn unordered_contexts(&self) -> impl Iterator<Item = &SavedContextMetadata> {
368 self.contexts_metadata.iter()
369 }
370
371 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<AssistantContext> {
372 let context = cx.new(|cx| {
373 AssistantContext::local(
374 self.languages.clone(),
375 Some(self.project.clone()),
376 Some(self.telemetry.clone()),
377 self.prompt_builder.clone(),
378 self.slash_commands.clone(),
379 cx,
380 )
381 });
382 self.register_context(&context, cx);
383 context
384 }
385
386 pub fn create_remote_context(
387 &mut self,
388 cx: &mut Context<Self>,
389 ) -> Task<Result<Entity<AssistantContext>>> {
390 let project = self.project.read(cx);
391 let Some(project_id) = project.remote_id() else {
392 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
393 };
394
395 let replica_id = project.replica_id();
396 let capability = project.capability();
397 let language_registry = self.languages.clone();
398 let project = self.project.clone();
399 let telemetry = self.telemetry.clone();
400 let prompt_builder = self.prompt_builder.clone();
401 let slash_commands = self.slash_commands.clone();
402 let request = self.client.request(proto::CreateContext { project_id });
403 cx.spawn(async move |this, cx| {
404 let response = request.await?;
405 let context_id = ContextId::from_proto(response.context_id);
406 let context_proto = response.context.context("invalid context")?;
407 let context = cx.new(|cx| {
408 AssistantContext::new(
409 context_id.clone(),
410 replica_id,
411 capability,
412 language_registry,
413 prompt_builder,
414 slash_commands,
415 Some(project),
416 Some(telemetry),
417 cx,
418 )
419 })?;
420 let operations = cx
421 .background_spawn(async move {
422 context_proto
423 .operations
424 .into_iter()
425 .map(ContextOperation::from_proto)
426 .collect::<Result<Vec<_>>>()
427 })
428 .await?;
429 context.update(cx, |context, cx| context.apply_ops(operations, cx))?;
430 this.update(cx, |this, cx| {
431 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
432 existing_context
433 } else {
434 this.register_context(&context, cx);
435 this.synchronize_contexts(cx);
436 context
437 }
438 })
439 })
440 }
441
442 pub fn open_local_context(
443 &mut self,
444 path: Arc<Path>,
445 cx: &Context<Self>,
446 ) -> Task<Result<Entity<AssistantContext>>> {
447 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
448 return Task::ready(Ok(existing_context));
449 }
450
451 let fs = self.fs.clone();
452 let languages = self.languages.clone();
453 let project = self.project.clone();
454 let telemetry = self.telemetry.clone();
455 let load = cx.background_spawn({
456 let path = path.clone();
457 async move {
458 let saved_context = fs.load(&path).await?;
459 SavedContext::from_json(&saved_context)
460 }
461 });
462 let prompt_builder = self.prompt_builder.clone();
463 let slash_commands = self.slash_commands.clone();
464
465 cx.spawn(async move |this, cx| {
466 let saved_context = load.await?;
467 let context = cx.new(|cx| {
468 AssistantContext::deserialize(
469 saved_context,
470 path.clone(),
471 languages,
472 prompt_builder,
473 slash_commands,
474 Some(project),
475 Some(telemetry),
476 cx,
477 )
478 })?;
479 this.update(cx, |this, cx| {
480 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
481 existing_context
482 } else {
483 this.register_context(&context, cx);
484 context
485 }
486 })
487 })
488 }
489
490 pub fn delete_local_context(
491 &mut self,
492 path: Arc<Path>,
493 cx: &mut Context<Self>,
494 ) -> Task<Result<()>> {
495 let fs = self.fs.clone();
496
497 cx.spawn(async move |this, cx| {
498 fs.remove_file(
499 &path,
500 RemoveOptions {
501 recursive: false,
502 ignore_if_not_exists: true,
503 },
504 )
505 .await?;
506
507 this.update(cx, |this, cx| {
508 this.contexts.retain(|context| {
509 context
510 .upgrade()
511 .and_then(|context| context.read(cx).path())
512 != Some(&path)
513 });
514 this.contexts_metadata
515 .retain(|context| context.path.as_ref() != path.as_ref());
516 })?;
517
518 Ok(())
519 })
520 }
521
522 fn loaded_context_for_path(&self, path: &Path, cx: &App) -> Option<Entity<AssistantContext>> {
523 self.contexts.iter().find_map(|context| {
524 let context = context.upgrade()?;
525 if context.read(cx).path().map(Arc::as_ref) == Some(path) {
526 Some(context)
527 } else {
528 None
529 }
530 })
531 }
532
533 pub fn loaded_context_for_id(
534 &self,
535 id: &ContextId,
536 cx: &App,
537 ) -> Option<Entity<AssistantContext>> {
538 self.contexts.iter().find_map(|context| {
539 let context = context.upgrade()?;
540 if context.read(cx).id() == id {
541 Some(context)
542 } else {
543 None
544 }
545 })
546 }
547
548 pub fn open_remote_context(
549 &mut self,
550 context_id: ContextId,
551 cx: &mut Context<Self>,
552 ) -> Task<Result<Entity<AssistantContext>>> {
553 let project = self.project.read(cx);
554 let Some(project_id) = project.remote_id() else {
555 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
556 };
557
558 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
559 return Task::ready(Ok(context));
560 }
561
562 let replica_id = project.replica_id();
563 let capability = project.capability();
564 let language_registry = self.languages.clone();
565 let project = self.project.clone();
566 let telemetry = self.telemetry.clone();
567 let request = self.client.request(proto::OpenContext {
568 project_id,
569 context_id: context_id.to_proto(),
570 });
571 let prompt_builder = self.prompt_builder.clone();
572 let slash_commands = self.slash_commands.clone();
573 cx.spawn(async move |this, cx| {
574 let response = request.await?;
575 let context_proto = response.context.context("invalid context")?;
576 let context = cx.new(|cx| {
577 AssistantContext::new(
578 context_id.clone(),
579 replica_id,
580 capability,
581 language_registry,
582 prompt_builder,
583 slash_commands,
584 Some(project),
585 Some(telemetry),
586 cx,
587 )
588 })?;
589 let operations = cx
590 .background_spawn(async move {
591 context_proto
592 .operations
593 .into_iter()
594 .map(ContextOperation::from_proto)
595 .collect::<Result<Vec<_>>>()
596 })
597 .await?;
598 context.update(cx, |context, cx| context.apply_ops(operations, cx))?;
599 this.update(cx, |this, cx| {
600 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
601 existing_context
602 } else {
603 this.register_context(&context, cx);
604 this.synchronize_contexts(cx);
605 context
606 }
607 })
608 })
609 }
610
611 fn register_context(&mut self, context: &Entity<AssistantContext>, cx: &mut Context<Self>) {
612 let handle = if self.project_is_shared {
613 ContextHandle::Strong(context.clone())
614 } else {
615 ContextHandle::Weak(context.downgrade())
616 };
617 self.contexts.push(handle);
618 self.advertise_contexts(cx);
619 cx.subscribe(context, Self::handle_context_event).detach();
620 }
621
622 fn handle_context_event(
623 &mut self,
624 context: Entity<AssistantContext>,
625 event: &ContextEvent,
626 cx: &mut Context<Self>,
627 ) {
628 let Some(project_id) = self.project.read(cx).remote_id() else {
629 return;
630 };
631
632 match event {
633 ContextEvent::SummaryChanged => {
634 self.advertise_contexts(cx);
635 }
636 ContextEvent::PathChanged { old_path, new_path } => {
637 if let Some(old_path) = old_path.as_ref() {
638 for metadata in &mut self.contexts_metadata {
639 if &metadata.path == old_path {
640 metadata.path = new_path.clone();
641 break;
642 }
643 }
644 }
645 }
646 ContextEvent::Operation(operation) => {
647 let context_id = context.read(cx).id().to_proto();
648 let operation = operation.to_proto();
649 self.client
650 .send(proto::UpdateContext {
651 project_id,
652 context_id,
653 operation: Some(operation),
654 })
655 .log_err();
656 }
657 _ => {}
658 }
659 }
660
661 fn advertise_contexts(&self, cx: &App) {
662 let Some(project_id) = self.project.read(cx).remote_id() else {
663 return;
664 };
665
666 // For now, only the host can advertise their open contexts.
667 if self.project.read(cx).is_via_collab() {
668 return;
669 }
670
671 let contexts = self
672 .contexts
673 .iter()
674 .rev()
675 .filter_map(|context| {
676 let context = context.upgrade()?.read(cx);
677 if context.replica_id() == ReplicaId::default() {
678 Some(proto::ContextMetadata {
679 context_id: context.id().to_proto(),
680 summary: context
681 .summary()
682 .content()
683 .map(|summary| summary.text.clone()),
684 })
685 } else {
686 None
687 }
688 })
689 .collect();
690 self.client
691 .send(proto::AdvertiseContexts {
692 project_id,
693 contexts,
694 })
695 .ok();
696 }
697
698 fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
699 let Some(project_id) = self.project.read(cx).remote_id() else {
700 return;
701 };
702
703 let contexts = self
704 .contexts
705 .iter()
706 .filter_map(|context| {
707 let context = context.upgrade()?.read(cx);
708 if context.replica_id() != ReplicaId::default() {
709 Some(context.version(cx).to_proto(context.id().clone()))
710 } else {
711 None
712 }
713 })
714 .collect();
715
716 let client = self.client.clone();
717 let request = self.client.request(proto::SynchronizeContexts {
718 project_id,
719 contexts,
720 });
721 cx.spawn(async move |this, cx| {
722 let response = request.await?;
723
724 let mut context_ids = Vec::new();
725 let mut operations = Vec::new();
726 this.read_with(cx, |this, cx| {
727 for context_version_proto in response.contexts {
728 let context_version = ContextVersion::from_proto(&context_version_proto);
729 let context_id = ContextId::from_proto(context_version_proto.context_id);
730 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
731 context_ids.push(context_id);
732 operations.push(context.read(cx).serialize_ops(&context_version, cx));
733 }
734 }
735 })?;
736
737 let operations = futures::future::join_all(operations).await;
738 for (context_id, operations) in context_ids.into_iter().zip(operations) {
739 for operation in operations {
740 client.send(proto::UpdateContext {
741 project_id,
742 context_id: context_id.to_proto(),
743 operation: Some(operation),
744 })?;
745 }
746 }
747
748 anyhow::Ok(())
749 })
750 .detach_and_log_err(cx);
751 }
752
753 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedContextMetadata>> {
754 let metadata = self.contexts_metadata.clone();
755 let executor = cx.background_executor().clone();
756 cx.background_spawn(async move {
757 if query.is_empty() {
758 metadata
759 } else {
760 let candidates = metadata
761 .iter()
762 .enumerate()
763 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
764 .collect::<Vec<_>>();
765 let matches = fuzzy::match_strings(
766 &candidates,
767 &query,
768 false,
769 true,
770 100,
771 &Default::default(),
772 executor,
773 )
774 .await;
775
776 matches
777 .into_iter()
778 .map(|mat| metadata[mat.candidate_id].clone())
779 .collect()
780 }
781 })
782 }
783
784 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
785 &self.host_contexts
786 }
787
788 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
789 let fs = self.fs.clone();
790 cx.spawn(async move |this, cx| {
791 pub static ZED_STATELESS: LazyLock<bool> =
792 LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
793 if *ZED_STATELESS {
794 return Ok(());
795 }
796 fs.create_dir(contexts_dir()).await?;
797
798 let mut paths = fs.read_dir(contexts_dir()).await?;
799 let mut contexts = Vec::<SavedContextMetadata>::new();
800 while let Some(path) = paths.next().await {
801 let path = path?;
802 if path.extension() != Some(OsStr::new("json")) {
803 continue;
804 }
805
806 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
807 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
808
809 let metadata = fs.metadata(&path).await?;
810 if let Some((file_name, metadata)) = path
811 .file_name()
812 .and_then(|name| name.to_str())
813 .zip(metadata)
814 {
815 // This is used to filter out contexts saved by the new assistant.
816 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
817 continue;
818 }
819
820 if let Some(title) = ASSISTANT_CONTEXT_REGEX
821 .replace(file_name, "")
822 .lines()
823 .next()
824 {
825 contexts.push(SavedContextMetadata {
826 title: title.to_string().into(),
827 path: path.into(),
828 mtime: metadata.mtime.timestamp_for_user().into(),
829 });
830 }
831 }
832 }
833 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
834
835 this.update(cx, |this, cx| {
836 this.contexts_metadata = contexts;
837 cx.notify();
838 })
839 })
840 }
841
842 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
843 let context_server_store = self.project.read(cx).context_server_store();
844 cx.subscribe(&context_server_store, Self::handle_context_server_event)
845 .detach();
846
847 // Check for any servers that were already running before the handler was registered
848 for server in context_server_store.read(cx).running_servers() {
849 self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
850 }
851 }
852
853 fn handle_context_server_event(
854 &mut self,
855 context_server_store: Entity<ContextServerStore>,
856 event: &project::context_server_store::Event,
857 cx: &mut Context<Self>,
858 ) {
859 match event {
860 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
861 match status {
862 ContextServerStatus::Running => {
863 self.load_context_server_slash_commands(
864 server_id.clone(),
865 context_server_store,
866 cx,
867 );
868 }
869 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
870 if let Some(slash_command_ids) =
871 self.context_server_slash_command_ids.remove(server_id)
872 {
873 self.slash_commands.remove(&slash_command_ids);
874 }
875 }
876 _ => {}
877 }
878 }
879 }
880 }
881
882 fn load_context_server_slash_commands(
883 &self,
884 server_id: ContextServerId,
885 context_server_store: Entity<ContextServerStore>,
886 cx: &mut Context<Self>,
887 ) {
888 let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
889 return;
890 };
891 let slash_command_working_set = self.slash_commands.clone();
892 cx.spawn(async move |this, cx| {
893 let Some(protocol) = server.client() else {
894 return;
895 };
896
897 if protocol.capable(context_server::protocol::ServerCapability::Prompts)
898 && let Some(response) = protocol
899 .request::<context_server::types::requests::PromptsList>(())
900 .await
901 .log_err()
902 {
903 let slash_command_ids = response
904 .prompts
905 .into_iter()
906 .filter(assistant_slash_commands::acceptable_prompt)
907 .map(|prompt| {
908 log::info!("registering context server command: {:?}", prompt.name);
909 slash_command_working_set.insert(Arc::new(
910 assistant_slash_commands::ContextServerSlashCommand::new(
911 context_server_store.clone(),
912 server.id(),
913 prompt,
914 ),
915 ))
916 })
917 .collect::<Vec<_>>();
918
919 this.update(cx, |this, _cx| {
920 this.context_server_slash_command_ids
921 .insert(server_id.clone(), slash_command_ids);
922 })
923 .log_err();
924 }
925 })
926 .detach();
927 }
928}