1use crate::{
2 Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
3 SavedContextMetadata,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
7use clock::ReplicaId;
8use fs::Fs;
9use futures::StreamExt;
10use fuzzy::StringMatchCandidate;
11use gpui::{
12 AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
13};
14use language::LanguageRegistry;
15use paths::contexts_dir;
16use project::Project;
17use regex::Regex;
18use std::{
19 cmp::Reverse,
20 ffi::OsStr,
21 mem,
22 path::{Path, PathBuf},
23 sync::Arc,
24 time::Duration,
25};
26use util::{ResultExt, TryFutureExt};
27
28pub fn init(client: &Arc<Client>) {
29 client.add_model_message_handler(ContextStore::handle_advertise_contexts);
30 client.add_model_request_handler(ContextStore::handle_open_context);
31 client.add_model_request_handler(ContextStore::handle_create_context);
32 client.add_model_message_handler(ContextStore::handle_update_context);
33 client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
34}
35
/// Metadata for a context that the host has advertised to this replica
/// (see `ContextStore::handle_advertise_contexts`).
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Identifier of the context on the host.
    pub id: ContextId,
    /// The context's summary text, if the host reported one.
    pub summary: Option<String>,
}
41
/// Tracks the assistant contexts associated with a project: locally loaded contexts, metadata
/// for contexts saved on disk, and contexts advertised by the host when collaborating.
pub struct ContextStore {
    /// Every registered context. Held strongly while the project is shared, weakly otherwise
    /// (see `register_context` / `handle_project_changed`).
    contexts: Vec<ContextHandle>,
    /// Saved-context metadata from the contexts directory, newest first; filled by `reload`.
    contexts_metadata: Vec<SavedContextMetadata>,
    /// Contexts most recently advertised by the host; cleared on `DisconnectedFromHost`.
    host_contexts: Vec<RemoteContextMetadata>,
    /// Filesystem used to watch, list, and load saved contexts.
    fs: Arc<dyn Fs>,
    /// Language registry passed to every context this store creates.
    languages: Arc<LanguageRegistry>,
    /// Telemetry handle passed to every context this store creates.
    telemetry: Arc<Telemetry>,
    /// Background task that calls `reload` whenever the contexts directory changes.
    _watch_updates: Task<Option<()>>,
    /// RPC client used to talk to the collaboration server.
    client: Arc<Client>,
    /// The project these contexts belong to.
    project: Model<Project>,
    /// Shared state last observed in `handle_project_changed`, used to detect transitions.
    project_is_shared: bool,
    /// RPC entity subscription for the project's remote id; `Some` only while shared.
    client_subscription: Option<client::Subscription>,
    /// Observation/event subscriptions on `project`.
    _project_subscriptions: Vec<gpui::Subscription>,
}
56
/// Events emitted by a `ContextStore`.
pub enum ContextStoreEvent {
    /// A context was created on behalf of a guest's `CreateContext` request
    /// (emitted by `handle_create_context`).
    ContextCreated(ContextId),
}
60
/// Lets observers subscribe to the `ContextStoreEvent`s emitted by the store.
impl EventEmitter<ContextStoreEvent> for ContextStore {}
62
/// How the store holds on to a registered context.
enum ContextHandle {
    /// Does not keep the context alive; `upgrade` fails once it has been released.
    Weak(WeakModel<Context>),
    /// Keeps the context alive for as long as the store holds this handle.
    Strong(Model<Context>),
}
67
68impl ContextHandle {
69 fn upgrade(&self) -> Option<Model<Context>> {
70 match self {
71 ContextHandle::Weak(weak) => weak.upgrade(),
72 ContextHandle::Strong(strong) => Some(strong.clone()),
73 }
74 }
75
76 fn downgrade(&self) -> WeakModel<Context> {
77 match self {
78 ContextHandle::Weak(weak) => weak.clone(),
79 ContextHandle::Strong(strong) => strong.downgrade(),
80 }
81 }
82}
83
84impl ContextStore {
85 pub fn new(project: Model<Project>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
86 let fs = project.read(cx).fs().clone();
87 let languages = project.read(cx).languages().clone();
88 let telemetry = project.read(cx).client().telemetry().clone();
89 cx.spawn(|mut cx| async move {
90 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
91 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
92
93 let this = cx.new_model(|cx: &mut ModelContext<Self>| {
94 let mut this = Self {
95 contexts: Vec::new(),
96 contexts_metadata: Vec::new(),
97 host_contexts: Vec::new(),
98 fs,
99 languages,
100 telemetry,
101 _watch_updates: cx.spawn(|this, mut cx| {
102 async move {
103 while events.next().await.is_some() {
104 this.update(&mut cx, |this, cx| this.reload(cx))?
105 .await
106 .log_err();
107 }
108 anyhow::Ok(())
109 }
110 .log_err()
111 }),
112 client_subscription: None,
113 _project_subscriptions: vec![
114 cx.observe(&project, Self::handle_project_changed),
115 cx.subscribe(&project, Self::handle_project_event),
116 ],
117 project_is_shared: false,
118 client: project.read(cx).client(),
119 project: project.clone(),
120 };
121 this.handle_project_changed(project, cx);
122 this.synchronize_contexts(cx);
123 this
124 })?;
125 this.update(&mut cx, |this, cx| this.reload(cx))?
126 .await
127 .log_err();
128 Ok(this)
129 })
130 }
131
132 async fn handle_advertise_contexts(
133 this: Model<Self>,
134 envelope: TypedEnvelope<proto::AdvertiseContexts>,
135 mut cx: AsyncAppContext,
136 ) -> Result<()> {
137 this.update(&mut cx, |this, cx| {
138 this.host_contexts = envelope
139 .payload
140 .contexts
141 .into_iter()
142 .map(|context| RemoteContextMetadata {
143 id: ContextId::from_proto(context.context_id),
144 summary: context.summary,
145 })
146 .collect();
147 cx.notify();
148 })
149 }
150
151 async fn handle_open_context(
152 this: Model<Self>,
153 envelope: TypedEnvelope<proto::OpenContext>,
154 mut cx: AsyncAppContext,
155 ) -> Result<proto::OpenContextResponse> {
156 let context_id = ContextId::from_proto(envelope.payload.context_id);
157 let operations = this.update(&mut cx, |this, cx| {
158 if this.project.read(cx).is_remote() {
159 return Err(anyhow!("only the host contexts can be opened"));
160 }
161
162 let context = this
163 .loaded_context_for_id(&context_id, cx)
164 .context("context not found")?;
165 if context.read(cx).replica_id() != ReplicaId::default() {
166 return Err(anyhow!("context must be opened via the host"));
167 }
168
169 anyhow::Ok(
170 context
171 .read(cx)
172 .serialize_ops(&ContextVersion::default(), cx),
173 )
174 })??;
175 let operations = operations.await;
176 Ok(proto::OpenContextResponse {
177 context: Some(proto::Context { operations }),
178 })
179 }
180
181 async fn handle_create_context(
182 this: Model<Self>,
183 _: TypedEnvelope<proto::CreateContext>,
184 mut cx: AsyncAppContext,
185 ) -> Result<proto::CreateContextResponse> {
186 let (context_id, operations) = this.update(&mut cx, |this, cx| {
187 if this.project.read(cx).is_remote() {
188 return Err(anyhow!("can only create contexts as the host"));
189 }
190
191 let context = this.create(cx);
192 let context_id = context.read(cx).id().clone();
193 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
194
195 anyhow::Ok((
196 context_id,
197 context
198 .read(cx)
199 .serialize_ops(&ContextVersion::default(), cx),
200 ))
201 })??;
202 let operations = operations.await;
203 Ok(proto::CreateContextResponse {
204 context_id: context_id.to_proto(),
205 context: Some(proto::Context { operations }),
206 })
207 }
208
209 async fn handle_update_context(
210 this: Model<Self>,
211 envelope: TypedEnvelope<proto::UpdateContext>,
212 mut cx: AsyncAppContext,
213 ) -> Result<()> {
214 this.update(&mut cx, |this, cx| {
215 let context_id = ContextId::from_proto(envelope.payload.context_id);
216 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
217 let operation_proto = envelope.payload.operation.context("invalid operation")?;
218 let operation = ContextOperation::from_proto(operation_proto)?;
219 context.update(cx, |context, cx| context.apply_ops([operation], cx))?;
220 }
221 Ok(())
222 })?
223 }
224
225 async fn handle_synchronize_contexts(
226 this: Model<Self>,
227 envelope: TypedEnvelope<proto::SynchronizeContexts>,
228 mut cx: AsyncAppContext,
229 ) -> Result<proto::SynchronizeContextsResponse> {
230 this.update(&mut cx, |this, cx| {
231 if this.project.read(cx).is_remote() {
232 return Err(anyhow!("only the host can synchronize contexts"));
233 }
234
235 let mut local_versions = Vec::new();
236 for remote_version_proto in envelope.payload.contexts {
237 let remote_version = ContextVersion::from_proto(&remote_version_proto);
238 let context_id = ContextId::from_proto(remote_version_proto.context_id);
239 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
240 let context = context.read(cx);
241 let operations = context.serialize_ops(&remote_version, cx);
242 local_versions.push(context.version(cx).to_proto(context_id.clone()));
243 let client = this.client.clone();
244 let project_id = envelope.payload.project_id;
245 cx.background_executor()
246 .spawn(async move {
247 let operations = operations.await;
248 for operation in operations {
249 client.send(proto::UpdateContext {
250 project_id,
251 context_id: context_id.to_proto(),
252 operation: Some(operation),
253 })?;
254 }
255 anyhow::Ok(())
256 })
257 .detach_and_log_err(cx);
258 }
259 }
260
261 this.advertise_contexts(cx);
262
263 anyhow::Ok(proto::SynchronizeContextsResponse {
264 contexts: local_versions,
265 })
266 })?
267 }
268
269 fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
270 let is_shared = self.project.read(cx).is_shared();
271 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
272 if is_shared == was_shared {
273 return;
274 }
275
276 if is_shared {
277 self.contexts.retain_mut(|context| {
278 if let Some(strong_context) = context.upgrade() {
279 *context = ContextHandle::Strong(strong_context);
280 true
281 } else {
282 false
283 }
284 });
285 let remote_id = self.project.read(cx).remote_id().unwrap();
286 self.client_subscription = self
287 .client
288 .subscribe_to_entity(remote_id)
289 .log_err()
290 .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
291 self.advertise_contexts(cx);
292 } else {
293 self.client_subscription = None;
294 }
295 }
296
297 fn handle_project_event(
298 &mut self,
299 _: Model<Project>,
300 event: &project::Event,
301 cx: &mut ModelContext<Self>,
302 ) {
303 match event {
304 project::Event::Reshared => {
305 self.advertise_contexts(cx);
306 }
307 project::Event::HostReshared | project::Event::Rejoined => {
308 self.synchronize_contexts(cx);
309 }
310 project::Event::DisconnectedFromHost => {
311 self.contexts.retain_mut(|context| {
312 if let Some(strong_context) = context.upgrade() {
313 *context = ContextHandle::Weak(context.downgrade());
314 strong_context.update(cx, |context, cx| {
315 if context.replica_id() != ReplicaId::default() {
316 context.set_capability(language::Capability::ReadOnly, cx);
317 }
318 });
319 true
320 } else {
321 false
322 }
323 });
324 self.host_contexts.clear();
325 cx.notify();
326 }
327 _ => {}
328 }
329 }
330
331 pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
332 let context = cx.new_model(|cx| {
333 Context::local(
334 self.languages.clone(),
335 Some(self.project.clone()),
336 Some(self.telemetry.clone()),
337 cx,
338 )
339 });
340 self.register_context(&context, cx);
341 context
342 }
343
344 pub fn create_remote_context(
345 &mut self,
346 cx: &mut ModelContext<Self>,
347 ) -> Task<Result<Model<Context>>> {
348 let project = self.project.read(cx);
349 let Some(project_id) = project.remote_id() else {
350 return Task::ready(Err(anyhow!("project was not remote")));
351 };
352 if project.is_local() {
353 return Task::ready(Err(anyhow!("cannot create remote contexts as the host")));
354 }
355
356 let replica_id = project.replica_id();
357 let capability = project.capability();
358 let language_registry = self.languages.clone();
359 let project = self.project.clone();
360 let telemetry = self.telemetry.clone();
361 let request = self.client.request(proto::CreateContext { project_id });
362 cx.spawn(|this, mut cx| async move {
363 let response = request.await?;
364 let context_id = ContextId::from_proto(response.context_id);
365 let context_proto = response.context.context("invalid context")?;
366 let context = cx.new_model(|cx| {
367 Context::new(
368 context_id.clone(),
369 replica_id,
370 capability,
371 language_registry,
372 Some(project),
373 Some(telemetry),
374 cx,
375 )
376 })?;
377 let operations = cx
378 .background_executor()
379 .spawn(async move {
380 context_proto
381 .operations
382 .into_iter()
383 .map(|op| ContextOperation::from_proto(op))
384 .collect::<Result<Vec<_>>>()
385 })
386 .await?;
387 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
388 this.update(&mut cx, |this, cx| {
389 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
390 existing_context
391 } else {
392 this.register_context(&context, cx);
393 this.synchronize_contexts(cx);
394 context
395 }
396 })
397 })
398 }
399
400 pub fn open_local_context(
401 &mut self,
402 path: PathBuf,
403 cx: &ModelContext<Self>,
404 ) -> Task<Result<Model<Context>>> {
405 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
406 return Task::ready(Ok(existing_context));
407 }
408
409 let fs = self.fs.clone();
410 let languages = self.languages.clone();
411 let project = self.project.clone();
412 let telemetry = self.telemetry.clone();
413 let load = cx.background_executor().spawn({
414 let path = path.clone();
415 async move {
416 let saved_context = fs.load(&path).await?;
417 SavedContext::from_json(&saved_context)
418 }
419 });
420
421 cx.spawn(|this, mut cx| async move {
422 let saved_context = load.await?;
423 let context = cx.new_model(|cx| {
424 Context::deserialize(
425 saved_context,
426 path.clone(),
427 languages,
428 Some(project),
429 Some(telemetry),
430 cx,
431 )
432 })?;
433 this.update(&mut cx, |this, cx| {
434 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
435 existing_context
436 } else {
437 this.register_context(&context, cx);
438 context
439 }
440 })
441 })
442 }
443
444 fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
445 self.contexts.iter().find_map(|context| {
446 let context = context.upgrade()?;
447 if context.read(cx).path() == Some(path) {
448 Some(context)
449 } else {
450 None
451 }
452 })
453 }
454
455 pub(super) fn loaded_context_for_id(
456 &self,
457 id: &ContextId,
458 cx: &AppContext,
459 ) -> Option<Model<Context>> {
460 self.contexts.iter().find_map(|context| {
461 let context = context.upgrade()?;
462 if context.read(cx).id() == id {
463 Some(context)
464 } else {
465 None
466 }
467 })
468 }
469
470 pub fn open_remote_context(
471 &mut self,
472 context_id: ContextId,
473 cx: &mut ModelContext<Self>,
474 ) -> Task<Result<Model<Context>>> {
475 let project = self.project.read(cx);
476 let Some(project_id) = project.remote_id() else {
477 return Task::ready(Err(anyhow!("project was not remote")));
478 };
479 if project.is_local() {
480 return Task::ready(Err(anyhow!("cannot open remote contexts as the host")));
481 }
482
483 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
484 return Task::ready(Ok(context));
485 }
486
487 let replica_id = project.replica_id();
488 let capability = project.capability();
489 let language_registry = self.languages.clone();
490 let project = self.project.clone();
491 let telemetry = self.telemetry.clone();
492 let request = self.client.request(proto::OpenContext {
493 project_id,
494 context_id: context_id.to_proto(),
495 });
496 cx.spawn(|this, mut cx| async move {
497 let response = request.await?;
498 let context_proto = response.context.context("invalid context")?;
499 let context = cx.new_model(|cx| {
500 Context::new(
501 context_id.clone(),
502 replica_id,
503 capability,
504 language_registry,
505 Some(project),
506 Some(telemetry),
507 cx,
508 )
509 })?;
510 let operations = cx
511 .background_executor()
512 .spawn(async move {
513 context_proto
514 .operations
515 .into_iter()
516 .map(|op| ContextOperation::from_proto(op))
517 .collect::<Result<Vec<_>>>()
518 })
519 .await?;
520 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
521 this.update(&mut cx, |this, cx| {
522 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
523 existing_context
524 } else {
525 this.register_context(&context, cx);
526 this.synchronize_contexts(cx);
527 context
528 }
529 })
530 })
531 }
532
533 fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
534 let handle = if self.project_is_shared {
535 ContextHandle::Strong(context.clone())
536 } else {
537 ContextHandle::Weak(context.downgrade())
538 };
539 self.contexts.push(handle);
540 self.advertise_contexts(cx);
541 cx.subscribe(context, Self::handle_context_event).detach();
542 }
543
544 fn handle_context_event(
545 &mut self,
546 context: Model<Context>,
547 event: &ContextEvent,
548 cx: &mut ModelContext<Self>,
549 ) {
550 let Some(project_id) = self.project.read(cx).remote_id() else {
551 return;
552 };
553
554 match event {
555 ContextEvent::SummaryChanged => {
556 self.advertise_contexts(cx);
557 }
558 ContextEvent::Operation(operation) => {
559 let context_id = context.read(cx).id().to_proto();
560 let operation = operation.to_proto();
561 self.client
562 .send(proto::UpdateContext {
563 project_id,
564 context_id,
565 operation: Some(operation),
566 })
567 .log_err();
568 }
569 _ => {}
570 }
571 }
572
573 fn advertise_contexts(&self, cx: &AppContext) {
574 let Some(project_id) = self.project.read(cx).remote_id() else {
575 return;
576 };
577
578 // For now, only the host can advertise their open contexts.
579 if self.project.read(cx).is_remote() {
580 return;
581 }
582
583 let contexts = self
584 .contexts
585 .iter()
586 .rev()
587 .filter_map(|context| {
588 let context = context.upgrade()?.read(cx);
589 if context.replica_id() == ReplicaId::default() {
590 Some(proto::ContextMetadata {
591 context_id: context.id().to_proto(),
592 summary: context.summary().map(|summary| summary.text.clone()),
593 })
594 } else {
595 None
596 }
597 })
598 .collect();
599 self.client
600 .send(proto::AdvertiseContexts {
601 project_id,
602 contexts,
603 })
604 .ok();
605 }
606
607 fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
608 let Some(project_id) = self.project.read(cx).remote_id() else {
609 return;
610 };
611
612 let contexts = self
613 .contexts
614 .iter()
615 .filter_map(|context| {
616 let context = context.upgrade()?.read(cx);
617 if context.replica_id() != ReplicaId::default() {
618 Some(context.version(cx).to_proto(context.id().clone()))
619 } else {
620 None
621 }
622 })
623 .collect();
624
625 let client = self.client.clone();
626 let request = self.client.request(proto::SynchronizeContexts {
627 project_id,
628 contexts,
629 });
630 cx.spawn(|this, cx| async move {
631 let response = request.await?;
632
633 let mut context_ids = Vec::new();
634 let mut operations = Vec::new();
635 this.read_with(&cx, |this, cx| {
636 for context_version_proto in response.contexts {
637 let context_version = ContextVersion::from_proto(&context_version_proto);
638 let context_id = ContextId::from_proto(context_version_proto.context_id);
639 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
640 context_ids.push(context_id);
641 operations.push(context.read(cx).serialize_ops(&context_version, cx));
642 }
643 }
644 })?;
645
646 let operations = futures::future::join_all(operations).await;
647 for (context_id, operations) in context_ids.into_iter().zip(operations) {
648 for operation in operations {
649 client.send(proto::UpdateContext {
650 project_id,
651 context_id: context_id.to_proto(),
652 operation: Some(operation),
653 })?;
654 }
655 }
656
657 anyhow::Ok(())
658 })
659 .detach_and_log_err(cx);
660 }
661
662 pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
663 let metadata = self.contexts_metadata.clone();
664 let executor = cx.background_executor().clone();
665 cx.background_executor().spawn(async move {
666 if query.is_empty() {
667 metadata
668 } else {
669 let candidates = metadata
670 .iter()
671 .enumerate()
672 .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
673 .collect::<Vec<_>>();
674 let matches = fuzzy::match_strings(
675 &candidates,
676 &query,
677 false,
678 100,
679 &Default::default(),
680 executor,
681 )
682 .await;
683
684 matches
685 .into_iter()
686 .map(|mat| metadata[mat.candidate_id].clone())
687 .collect()
688 }
689 })
690 }
691
692 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
693 &self.host_contexts
694 }
695
696 fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
697 let fs = self.fs.clone();
698 cx.spawn(|this, mut cx| async move {
699 fs.create_dir(contexts_dir()).await?;
700
701 let mut paths = fs.read_dir(contexts_dir()).await?;
702 let mut contexts = Vec::<SavedContextMetadata>::new();
703 while let Some(path) = paths.next().await {
704 let path = path?;
705 if path.extension() != Some(OsStr::new("json")) {
706 continue;
707 }
708
709 let pattern = r" - \d+.zed.json$";
710 let re = Regex::new(pattern).unwrap();
711
712 let metadata = fs.metadata(&path).await?;
713 if let Some((file_name, metadata)) = path
714 .file_name()
715 .and_then(|name| name.to_str())
716 .zip(metadata)
717 {
718 // This is used to filter out contexts saved by the new assistant.
719 if !re.is_match(file_name) {
720 continue;
721 }
722
723 if let Some(title) = re.replace(file_name, "").lines().next() {
724 contexts.push(SavedContextMetadata {
725 title: title.to_string(),
726 path,
727 mtime: metadata.mtime.into(),
728 });
729 }
730 }
731 }
732 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
733
734 this.update(&mut cx, |this, cx| {
735 this.contexts_metadata = contexts;
736 cx.notify();
737 })
738 })
739 }
740}