1use crate::{
2 Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
3 SavedContextMetadata,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
7use clock::ReplicaId;
8use fs::Fs;
9use futures::StreamExt;
10use fuzzy::StringMatchCandidate;
11use gpui::{
12 AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
13};
14use language::LanguageRegistry;
15use paths::contexts_dir;
16use project::Project;
17use regex::Regex;
18use std::{
19 cmp::Reverse,
20 ffi::OsStr,
21 mem,
22 path::{Path, PathBuf},
23 sync::Arc,
24 time::Duration,
25};
26use util::{ResultExt, TryFutureExt};
27
28pub fn init(client: &Arc<Client>) {
29 client.add_model_message_handler(ContextStore::handle_advertise_contexts);
30 client.add_model_request_handler(ContextStore::handle_open_context);
31 client.add_model_request_handler(ContextStore::handle_create_context);
32 client.add_model_message_handler(ContextStore::handle_update_context);
33 client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
34}
35
36#[derive(Clone)]
37pub struct RemoteContextMetadata {
38 pub id: ContextId,
39 pub summary: Option<String>,
40}
41
42pub struct ContextStore {
43 contexts: Vec<ContextHandle>,
44 contexts_metadata: Vec<SavedContextMetadata>,
45 host_contexts: Vec<RemoteContextMetadata>,
46 fs: Arc<dyn Fs>,
47 languages: Arc<LanguageRegistry>,
48 telemetry: Arc<Telemetry>,
49 _watch_updates: Task<Option<()>>,
50 client: Arc<Client>,
51 project: Model<Project>,
52 project_is_shared: bool,
53 client_subscription: Option<client::Subscription>,
54 _project_subscriptions: Vec<gpui::Subscription>,
55}
56
57pub enum ContextStoreEvent {
58 ContextCreated(ContextId),
59}
60
61impl EventEmitter<ContextStoreEvent> for ContextStore {}
62
63enum ContextHandle {
64 Weak(WeakModel<Context>),
65 Strong(Model<Context>),
66}
67
68impl ContextHandle {
69 fn upgrade(&self) -> Option<Model<Context>> {
70 match self {
71 ContextHandle::Weak(weak) => weak.upgrade(),
72 ContextHandle::Strong(strong) => Some(strong.clone()),
73 }
74 }
75
76 fn downgrade(&self) -> WeakModel<Context> {
77 match self {
78 ContextHandle::Weak(weak) => weak.clone(),
79 ContextHandle::Strong(strong) => strong.downgrade(),
80 }
81 }
82}
83
84impl ContextStore {
85 pub fn new(project: Model<Project>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
86 let fs = project.read(cx).fs().clone();
87 let languages = project.read(cx).languages().clone();
88 let telemetry = project.read(cx).client().telemetry().clone();
89 cx.spawn(|mut cx| async move {
90 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
91 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
92
93 let this = cx.new_model(|cx: &mut ModelContext<Self>| {
94 let mut this = Self {
95 contexts: Vec::new(),
96 contexts_metadata: Vec::new(),
97 host_contexts: Vec::new(),
98 fs,
99 languages,
100 telemetry,
101 _watch_updates: cx.spawn(|this, mut cx| {
102 async move {
103 while events.next().await.is_some() {
104 this.update(&mut cx, |this, cx| this.reload(cx))?
105 .await
106 .log_err();
107 }
108 anyhow::Ok(())
109 }
110 .log_err()
111 }),
112 client_subscription: None,
113 _project_subscriptions: vec![
114 cx.observe(&project, Self::handle_project_changed),
115 cx.subscribe(&project, Self::handle_project_event),
116 ],
117 project_is_shared: false,
118 client: project.read(cx).client(),
119 project: project.clone(),
120 };
121 this.handle_project_changed(project, cx);
122 this.synchronize_contexts(cx);
123 this
124 })?;
125 this.update(&mut cx, |this, cx| this.reload(cx))?
126 .await
127 .log_err();
128 Ok(this)
129 })
130 }
131
132 async fn handle_advertise_contexts(
133 this: Model<Self>,
134 envelope: TypedEnvelope<proto::AdvertiseContexts>,
135 mut cx: AsyncAppContext,
136 ) -> Result<()> {
137 this.update(&mut cx, |this, cx| {
138 this.host_contexts = envelope
139 .payload
140 .contexts
141 .into_iter()
142 .map(|context| RemoteContextMetadata {
143 id: ContextId::from_proto(context.context_id),
144 summary: context.summary,
145 })
146 .collect();
147 cx.notify();
148 })
149 }
150
151 async fn handle_open_context(
152 this: Model<Self>,
153 envelope: TypedEnvelope<proto::OpenContext>,
154 mut cx: AsyncAppContext,
155 ) -> Result<proto::OpenContextResponse> {
156 let context_id = ContextId::from_proto(envelope.payload.context_id);
157 let operations = this.update(&mut cx, |this, cx| {
158 if this.project.read(cx).is_remote() {
159 return Err(anyhow!("only the host contexts can be opened"));
160 }
161
162 let context = this
163 .loaded_context_for_id(&context_id, cx)
164 .context("context not found")?;
165 if context.read(cx).replica_id() != ReplicaId::default() {
166 return Err(anyhow!("context must be opened via the host"));
167 }
168
169 anyhow::Ok(
170 context
171 .read(cx)
172 .serialize_ops(&ContextVersion::default(), cx),
173 )
174 })??;
175 let operations = operations.await;
176 Ok(proto::OpenContextResponse {
177 context: Some(proto::Context { operations }),
178 })
179 }
180
181 async fn handle_create_context(
182 this: Model<Self>,
183 _: TypedEnvelope<proto::CreateContext>,
184 mut cx: AsyncAppContext,
185 ) -> Result<proto::CreateContextResponse> {
186 let (context_id, operations) = this.update(&mut cx, |this, cx| {
187 if this.project.read(cx).is_remote() {
188 return Err(anyhow!("can only create contexts as the host"));
189 }
190
191 let context = this.create(cx);
192 let context_id = context.read(cx).id().clone();
193 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
194
195 anyhow::Ok((
196 context_id,
197 context
198 .read(cx)
199 .serialize_ops(&ContextVersion::default(), cx),
200 ))
201 })??;
202 let operations = operations.await;
203 Ok(proto::CreateContextResponse {
204 context_id: context_id.to_proto(),
205 context: Some(proto::Context { operations }),
206 })
207 }
208
209 async fn handle_update_context(
210 this: Model<Self>,
211 envelope: TypedEnvelope<proto::UpdateContext>,
212 mut cx: AsyncAppContext,
213 ) -> Result<()> {
214 this.update(&mut cx, |this, cx| {
215 let context_id = ContextId::from_proto(envelope.payload.context_id);
216 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
217 let operation_proto = envelope.payload.operation.context("invalid operation")?;
218 let operation = ContextOperation::from_proto(operation_proto)?;
219 context.update(cx, |context, cx| context.apply_ops([operation], cx))?;
220 }
221 Ok(())
222 })?
223 }
224
225 async fn handle_synchronize_contexts(
226 this: Model<Self>,
227 envelope: TypedEnvelope<proto::SynchronizeContexts>,
228 mut cx: AsyncAppContext,
229 ) -> Result<proto::SynchronizeContextsResponse> {
230 this.update(&mut cx, |this, cx| {
231 if this.project.read(cx).is_remote() {
232 return Err(anyhow!("only the host can synchronize contexts"));
233 }
234
235 let mut local_versions = Vec::new();
236 for remote_version_proto in envelope.payload.contexts {
237 let remote_version = ContextVersion::from_proto(&remote_version_proto);
238 let context_id = ContextId::from_proto(remote_version_proto.context_id);
239 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
240 let context = context.read(cx);
241 let operations = context.serialize_ops(&remote_version, cx);
242 local_versions.push(context.version(cx).to_proto(context_id.clone()));
243 let client = this.client.clone();
244 let project_id = envelope.payload.project_id;
245 cx.background_executor()
246 .spawn(async move {
247 let operations = operations.await;
248 for operation in operations {
249 client.send(proto::UpdateContext {
250 project_id,
251 context_id: context_id.to_proto(),
252 operation: Some(operation),
253 })?;
254 }
255 anyhow::Ok(())
256 })
257 .detach_and_log_err(cx);
258 }
259 }
260
261 this.advertise_contexts(cx);
262
263 anyhow::Ok(proto::SynchronizeContextsResponse {
264 contexts: local_versions,
265 })
266 })?
267 }
268
269 fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
270 let is_shared = self.project.read(cx).is_shared();
271 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
272 if is_shared == was_shared {
273 return;
274 }
275
276 if is_shared {
277 self.contexts.retain_mut(|context| {
278 if let Some(strong_context) = context.upgrade() {
279 *context = ContextHandle::Strong(strong_context);
280 true
281 } else {
282 false
283 }
284 });
285 let remote_id = self.project.read(cx).remote_id().unwrap();
286 self.client_subscription = self
287 .client
288 .subscribe_to_entity(remote_id)
289 .log_err()
290 .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
291 self.advertise_contexts(cx);
292 } else {
293 self.client_subscription = None;
294 }
295 }
296
297 fn handle_project_event(
298 &mut self,
299 _: Model<Project>,
300 event: &project::Event,
301 cx: &mut ModelContext<Self>,
302 ) {
303 match event {
304 project::Event::Reshared => {
305 self.advertise_contexts(cx);
306 }
307 project::Event::HostReshared | project::Event::Rejoined => {
308 self.synchronize_contexts(cx);
309 }
310 project::Event::DisconnectedFromHost => {
311 self.contexts.retain_mut(|context| {
312 if let Some(strong_context) = context.upgrade() {
313 *context = ContextHandle::Weak(context.downgrade());
314 strong_context.update(cx, |context, cx| {
315 if context.replica_id() != ReplicaId::default() {
316 context.set_capability(language::Capability::ReadOnly, cx);
317 }
318 });
319 true
320 } else {
321 false
322 }
323 });
324 self.host_contexts.clear();
325 cx.notify();
326 }
327 _ => {}
328 }
329 }
330
331 pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
332 let context = cx.new_model(|cx| {
333 Context::local(self.languages.clone(), Some(self.telemetry.clone()), cx)
334 });
335 self.register_context(&context, cx);
336 context
337 }
338
339 pub fn create_remote_context(
340 &mut self,
341 cx: &mut ModelContext<Self>,
342 ) -> Task<Result<Model<Context>>> {
343 let project = self.project.read(cx);
344 let Some(project_id) = project.remote_id() else {
345 return Task::ready(Err(anyhow!("project was not remote")));
346 };
347 if project.is_local() {
348 return Task::ready(Err(anyhow!("cannot create remote contexts as the host")));
349 }
350
351 let replica_id = project.replica_id();
352 let capability = project.capability();
353 let language_registry = self.languages.clone();
354 let telemetry = self.telemetry.clone();
355 let request = self.client.request(proto::CreateContext { project_id });
356 cx.spawn(|this, mut cx| async move {
357 let response = request.await?;
358 let context_id = ContextId::from_proto(response.context_id);
359 let context_proto = response.context.context("invalid context")?;
360 let context = cx.new_model(|cx| {
361 Context::new(
362 context_id.clone(),
363 replica_id,
364 capability,
365 language_registry,
366 Some(telemetry),
367 cx,
368 )
369 })?;
370 let operations = cx
371 .background_executor()
372 .spawn(async move {
373 context_proto
374 .operations
375 .into_iter()
376 .map(|op| ContextOperation::from_proto(op))
377 .collect::<Result<Vec<_>>>()
378 })
379 .await?;
380 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
381 this.update(&mut cx, |this, cx| {
382 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
383 existing_context
384 } else {
385 this.register_context(&context, cx);
386 this.synchronize_contexts(cx);
387 context
388 }
389 })
390 })
391 }
392
393 pub fn open_local_context(
394 &mut self,
395 path: PathBuf,
396 cx: &ModelContext<Self>,
397 ) -> Task<Result<Model<Context>>> {
398 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
399 return Task::ready(Ok(existing_context));
400 }
401
402 let fs = self.fs.clone();
403 let languages = self.languages.clone();
404 let telemetry = self.telemetry.clone();
405 let load = cx.background_executor().spawn({
406 let path = path.clone();
407 async move {
408 let saved_context = fs.load(&path).await?;
409 SavedContext::from_json(&saved_context)
410 }
411 });
412
413 cx.spawn(|this, mut cx| async move {
414 let saved_context = load.await?;
415 let context = cx.new_model(|cx| {
416 Context::deserialize(saved_context, path.clone(), languages, Some(telemetry), cx)
417 })?;
418 this.update(&mut cx, |this, cx| {
419 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
420 existing_context
421 } else {
422 this.register_context(&context, cx);
423 context
424 }
425 })
426 })
427 }
428
429 fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
430 self.contexts.iter().find_map(|context| {
431 let context = context.upgrade()?;
432 if context.read(cx).path() == Some(path) {
433 Some(context)
434 } else {
435 None
436 }
437 })
438 }
439
440 pub(super) fn loaded_context_for_id(
441 &self,
442 id: &ContextId,
443 cx: &AppContext,
444 ) -> Option<Model<Context>> {
445 self.contexts.iter().find_map(|context| {
446 let context = context.upgrade()?;
447 if context.read(cx).id() == id {
448 Some(context)
449 } else {
450 None
451 }
452 })
453 }
454
455 pub fn open_remote_context(
456 &mut self,
457 context_id: ContextId,
458 cx: &mut ModelContext<Self>,
459 ) -> Task<Result<Model<Context>>> {
460 let project = self.project.read(cx);
461 let Some(project_id) = project.remote_id() else {
462 return Task::ready(Err(anyhow!("project was not remote")));
463 };
464 if project.is_local() {
465 return Task::ready(Err(anyhow!("cannot open remote contexts as the host")));
466 }
467
468 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
469 return Task::ready(Ok(context));
470 }
471
472 let replica_id = project.replica_id();
473 let capability = project.capability();
474 let language_registry = self.languages.clone();
475 let telemetry = self.telemetry.clone();
476 let request = self.client.request(proto::OpenContext {
477 project_id,
478 context_id: context_id.to_proto(),
479 });
480 cx.spawn(|this, mut cx| async move {
481 let response = request.await?;
482 let context_proto = response.context.context("invalid context")?;
483 let context = cx.new_model(|cx| {
484 Context::new(
485 context_id.clone(),
486 replica_id,
487 capability,
488 language_registry,
489 Some(telemetry),
490 cx,
491 )
492 })?;
493 let operations = cx
494 .background_executor()
495 .spawn(async move {
496 context_proto
497 .operations
498 .into_iter()
499 .map(|op| ContextOperation::from_proto(op))
500 .collect::<Result<Vec<_>>>()
501 })
502 .await?;
503 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
504 this.update(&mut cx, |this, cx| {
505 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
506 existing_context
507 } else {
508 this.register_context(&context, cx);
509 this.synchronize_contexts(cx);
510 context
511 }
512 })
513 })
514 }
515
516 fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
517 let handle = if self.project_is_shared {
518 ContextHandle::Strong(context.clone())
519 } else {
520 ContextHandle::Weak(context.downgrade())
521 };
522 self.contexts.push(handle);
523 self.advertise_contexts(cx);
524 cx.subscribe(context, Self::handle_context_event).detach();
525 }
526
527 fn handle_context_event(
528 &mut self,
529 context: Model<Context>,
530 event: &ContextEvent,
531 cx: &mut ModelContext<Self>,
532 ) {
533 let Some(project_id) = self.project.read(cx).remote_id() else {
534 return;
535 };
536
537 match event {
538 ContextEvent::SummaryChanged => {
539 self.advertise_contexts(cx);
540 }
541 ContextEvent::Operation(operation) => {
542 let context_id = context.read(cx).id().to_proto();
543 let operation = operation.to_proto();
544 self.client
545 .send(proto::UpdateContext {
546 project_id,
547 context_id,
548 operation: Some(operation),
549 })
550 .log_err();
551 }
552 _ => {}
553 }
554 }
555
556 fn advertise_contexts(&self, cx: &AppContext) {
557 let Some(project_id) = self.project.read(cx).remote_id() else {
558 return;
559 };
560
561 // For now, only the host can advertise their open contexts.
562 if self.project.read(cx).is_remote() {
563 return;
564 }
565
566 let contexts = self
567 .contexts
568 .iter()
569 .rev()
570 .filter_map(|context| {
571 let context = context.upgrade()?.read(cx);
572 if context.replica_id() == ReplicaId::default() {
573 Some(proto::ContextMetadata {
574 context_id: context.id().to_proto(),
575 summary: context.summary().map(|summary| summary.text.clone()),
576 })
577 } else {
578 None
579 }
580 })
581 .collect();
582 self.client
583 .send(proto::AdvertiseContexts {
584 project_id,
585 contexts,
586 })
587 .ok();
588 }
589
590 fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
591 let Some(project_id) = self.project.read(cx).remote_id() else {
592 return;
593 };
594
595 let contexts = self
596 .contexts
597 .iter()
598 .filter_map(|context| {
599 let context = context.upgrade()?.read(cx);
600 if context.replica_id() != ReplicaId::default() {
601 Some(context.version(cx).to_proto(context.id().clone()))
602 } else {
603 None
604 }
605 })
606 .collect();
607
608 let client = self.client.clone();
609 let request = self.client.request(proto::SynchronizeContexts {
610 project_id,
611 contexts,
612 });
613 cx.spawn(|this, cx| async move {
614 let response = request.await?;
615
616 let mut context_ids = Vec::new();
617 let mut operations = Vec::new();
618 this.read_with(&cx, |this, cx| {
619 for context_version_proto in response.contexts {
620 let context_version = ContextVersion::from_proto(&context_version_proto);
621 let context_id = ContextId::from_proto(context_version_proto.context_id);
622 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
623 context_ids.push(context_id);
624 operations.push(context.read(cx).serialize_ops(&context_version, cx));
625 }
626 }
627 })?;
628
629 let operations = futures::future::join_all(operations).await;
630 for (context_id, operations) in context_ids.into_iter().zip(operations) {
631 for operation in operations {
632 client.send(proto::UpdateContext {
633 project_id,
634 context_id: context_id.to_proto(),
635 operation: Some(operation),
636 })?;
637 }
638 }
639
640 anyhow::Ok(())
641 })
642 .detach_and_log_err(cx);
643 }
644
645 pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
646 let metadata = self.contexts_metadata.clone();
647 let executor = cx.background_executor().clone();
648 cx.background_executor().spawn(async move {
649 if query.is_empty() {
650 metadata
651 } else {
652 let candidates = metadata
653 .iter()
654 .enumerate()
655 .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
656 .collect::<Vec<_>>();
657 let matches = fuzzy::match_strings(
658 &candidates,
659 &query,
660 false,
661 100,
662 &Default::default(),
663 executor,
664 )
665 .await;
666
667 matches
668 .into_iter()
669 .map(|mat| metadata[mat.candidate_id].clone())
670 .collect()
671 }
672 })
673 }
674
675 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
676 &self.host_contexts
677 }
678
679 fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
680 let fs = self.fs.clone();
681 cx.spawn(|this, mut cx| async move {
682 fs.create_dir(contexts_dir()).await?;
683
684 let mut paths = fs.read_dir(contexts_dir()).await?;
685 let mut contexts = Vec::<SavedContextMetadata>::new();
686 while let Some(path) = paths.next().await {
687 let path = path?;
688 if path.extension() != Some(OsStr::new("json")) {
689 continue;
690 }
691
692 let pattern = r" - \d+.zed.json$";
693 let re = Regex::new(pattern).unwrap();
694
695 let metadata = fs.metadata(&path).await?;
696 if let Some((file_name, metadata)) = path
697 .file_name()
698 .and_then(|name| name.to_str())
699 .zip(metadata)
700 {
701 // This is used to filter out contexts saved by the new assistant.
702 if !re.is_match(file_name) {
703 continue;
704 }
705
706 if let Some(title) = re.replace(file_name, "").lines().next() {
707 contexts.push(SavedContextMetadata {
708 title: title.to_string(),
709 path,
710 mtime: metadata.mtime.into(),
711 });
712 }
713 }
714 }
715 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
716
717 this.update(&mut cx, |this, cx| {
718 this.contexts_metadata = contexts;
719 cx.notify();
720 })
721 })
722 }
723}