context_store.rs

  1use crate::{
  2    Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
  3    SavedContextMetadata,
  4};
  5use anyhow::{anyhow, Context as _, Result};
  6use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
  7use clock::ReplicaId;
  8use fs::Fs;
  9use futures::StreamExt;
 10use fuzzy::StringMatchCandidate;
 11use gpui::{
 12    AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
 13};
 14use language::LanguageRegistry;
 15use paths::contexts_dir;
 16use project::Project;
 17use regex::Regex;
 18use std::{
 19    cmp::Reverse,
 20    ffi::OsStr,
 21    mem,
 22    path::{Path, PathBuf},
 23    sync::Arc,
 24    time::Duration,
 25};
 26use util::{ResultExt, TryFutureExt};
 27
 28pub fn init(client: &Arc<Client>) {
 29    client.add_model_message_handler(ContextStore::handle_advertise_contexts);
 30    client.add_model_request_handler(ContextStore::handle_open_context);
 31    client.add_model_request_handler(ContextStore::handle_create_context);
 32    client.add_model_message_handler(ContextStore::handle_update_context);
 33    client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
 34}
 35
 36#[derive(Clone)]
 37pub struct RemoteContextMetadata {
 38    pub id: ContextId,
 39    pub summary: Option<String>,
 40}
 41
 42pub struct ContextStore {
 43    contexts: Vec<ContextHandle>,
 44    contexts_metadata: Vec<SavedContextMetadata>,
 45    host_contexts: Vec<RemoteContextMetadata>,
 46    fs: Arc<dyn Fs>,
 47    languages: Arc<LanguageRegistry>,
 48    telemetry: Arc<Telemetry>,
 49    _watch_updates: Task<Option<()>>,
 50    client: Arc<Client>,
 51    project: Model<Project>,
 52    project_is_shared: bool,
 53    client_subscription: Option<client::Subscription>,
 54    _project_subscriptions: Vec<gpui::Subscription>,
 55}
 56
 57pub enum ContextStoreEvent {
 58    ContextCreated(ContextId),
 59}
 60
 61impl EventEmitter<ContextStoreEvent> for ContextStore {}
 62
 63enum ContextHandle {
 64    Weak(WeakModel<Context>),
 65    Strong(Model<Context>),
 66}
 67
 68impl ContextHandle {
 69    fn upgrade(&self) -> Option<Model<Context>> {
 70        match self {
 71            ContextHandle::Weak(weak) => weak.upgrade(),
 72            ContextHandle::Strong(strong) => Some(strong.clone()),
 73        }
 74    }
 75
 76    fn downgrade(&self) -> WeakModel<Context> {
 77        match self {
 78            ContextHandle::Weak(weak) => weak.clone(),
 79            ContextHandle::Strong(strong) => strong.downgrade(),
 80        }
 81    }
 82}
 83
 84impl ContextStore {
 85    pub fn new(project: Model<Project>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
 86        let fs = project.read(cx).fs().clone();
 87        let languages = project.read(cx).languages().clone();
 88        let telemetry = project.read(cx).client().telemetry().clone();
 89        cx.spawn(|mut cx| async move {
 90            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
 91            let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
 92
 93            let this = cx.new_model(|cx: &mut ModelContext<Self>| {
 94                let mut this = Self {
 95                    contexts: Vec::new(),
 96                    contexts_metadata: Vec::new(),
 97                    host_contexts: Vec::new(),
 98                    fs,
 99                    languages,
100                    telemetry,
101                    _watch_updates: cx.spawn(|this, mut cx| {
102                        async move {
103                            while events.next().await.is_some() {
104                                this.update(&mut cx, |this, cx| this.reload(cx))?
105                                    .await
106                                    .log_err();
107                            }
108                            anyhow::Ok(())
109                        }
110                        .log_err()
111                    }),
112                    client_subscription: None,
113                    _project_subscriptions: vec![
114                        cx.observe(&project, Self::handle_project_changed),
115                        cx.subscribe(&project, Self::handle_project_event),
116                    ],
117                    project_is_shared: false,
118                    client: project.read(cx).client(),
119                    project: project.clone(),
120                };
121                this.handle_project_changed(project, cx);
122                this.synchronize_contexts(cx);
123                this
124            })?;
125            this.update(&mut cx, |this, cx| this.reload(cx))?
126                .await
127                .log_err();
128            Ok(this)
129        })
130    }
131
132    async fn handle_advertise_contexts(
133        this: Model<Self>,
134        envelope: TypedEnvelope<proto::AdvertiseContexts>,
135        mut cx: AsyncAppContext,
136    ) -> Result<()> {
137        this.update(&mut cx, |this, cx| {
138            this.host_contexts = envelope
139                .payload
140                .contexts
141                .into_iter()
142                .map(|context| RemoteContextMetadata {
143                    id: ContextId::from_proto(context.context_id),
144                    summary: context.summary,
145                })
146                .collect();
147            cx.notify();
148        })
149    }
150
151    async fn handle_open_context(
152        this: Model<Self>,
153        envelope: TypedEnvelope<proto::OpenContext>,
154        mut cx: AsyncAppContext,
155    ) -> Result<proto::OpenContextResponse> {
156        let context_id = ContextId::from_proto(envelope.payload.context_id);
157        let operations = this.update(&mut cx, |this, cx| {
158            if this.project.read(cx).is_remote() {
159                return Err(anyhow!("only the host contexts can be opened"));
160            }
161
162            let context = this
163                .loaded_context_for_id(&context_id, cx)
164                .context("context not found")?;
165            if context.read(cx).replica_id() != ReplicaId::default() {
166                return Err(anyhow!("context must be opened via the host"));
167            }
168
169            anyhow::Ok(
170                context
171                    .read(cx)
172                    .serialize_ops(&ContextVersion::default(), cx),
173            )
174        })??;
175        let operations = operations.await;
176        Ok(proto::OpenContextResponse {
177            context: Some(proto::Context { operations }),
178        })
179    }
180
181    async fn handle_create_context(
182        this: Model<Self>,
183        _: TypedEnvelope<proto::CreateContext>,
184        mut cx: AsyncAppContext,
185    ) -> Result<proto::CreateContextResponse> {
186        let (context_id, operations) = this.update(&mut cx, |this, cx| {
187            if this.project.read(cx).is_remote() {
188                return Err(anyhow!("can only create contexts as the host"));
189            }
190
191            let context = this.create(cx);
192            let context_id = context.read(cx).id().clone();
193            cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
194
195            anyhow::Ok((
196                context_id,
197                context
198                    .read(cx)
199                    .serialize_ops(&ContextVersion::default(), cx),
200            ))
201        })??;
202        let operations = operations.await;
203        Ok(proto::CreateContextResponse {
204            context_id: context_id.to_proto(),
205            context: Some(proto::Context { operations }),
206        })
207    }
208
209    async fn handle_update_context(
210        this: Model<Self>,
211        envelope: TypedEnvelope<proto::UpdateContext>,
212        mut cx: AsyncAppContext,
213    ) -> Result<()> {
214        this.update(&mut cx, |this, cx| {
215            let context_id = ContextId::from_proto(envelope.payload.context_id);
216            if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
217                let operation_proto = envelope.payload.operation.context("invalid operation")?;
218                let operation = ContextOperation::from_proto(operation_proto)?;
219                context.update(cx, |context, cx| context.apply_ops([operation], cx))?;
220            }
221            Ok(())
222        })?
223    }
224
225    async fn handle_synchronize_contexts(
226        this: Model<Self>,
227        envelope: TypedEnvelope<proto::SynchronizeContexts>,
228        mut cx: AsyncAppContext,
229    ) -> Result<proto::SynchronizeContextsResponse> {
230        this.update(&mut cx, |this, cx| {
231            if this.project.read(cx).is_remote() {
232                return Err(anyhow!("only the host can synchronize contexts"));
233            }
234
235            let mut local_versions = Vec::new();
236            for remote_version_proto in envelope.payload.contexts {
237                let remote_version = ContextVersion::from_proto(&remote_version_proto);
238                let context_id = ContextId::from_proto(remote_version_proto.context_id);
239                if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
240                    let context = context.read(cx);
241                    let operations = context.serialize_ops(&remote_version, cx);
242                    local_versions.push(context.version(cx).to_proto(context_id.clone()));
243                    let client = this.client.clone();
244                    let project_id = envelope.payload.project_id;
245                    cx.background_executor()
246                        .spawn(async move {
247                            let operations = operations.await;
248                            for operation in operations {
249                                client.send(proto::UpdateContext {
250                                    project_id,
251                                    context_id: context_id.to_proto(),
252                                    operation: Some(operation),
253                                })?;
254                            }
255                            anyhow::Ok(())
256                        })
257                        .detach_and_log_err(cx);
258                }
259            }
260
261            this.advertise_contexts(cx);
262
263            anyhow::Ok(proto::SynchronizeContextsResponse {
264                contexts: local_versions,
265            })
266        })?
267    }
268
269    fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
270        let is_shared = self.project.read(cx).is_shared();
271        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
272        if is_shared == was_shared {
273            return;
274        }
275
276        if is_shared {
277            self.contexts.retain_mut(|context| {
278                if let Some(strong_context) = context.upgrade() {
279                    *context = ContextHandle::Strong(strong_context);
280                    true
281                } else {
282                    false
283                }
284            });
285            let remote_id = self.project.read(cx).remote_id().unwrap();
286            self.client_subscription = self
287                .client
288                .subscribe_to_entity(remote_id)
289                .log_err()
290                .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
291            self.advertise_contexts(cx);
292        } else {
293            self.client_subscription = None;
294        }
295    }
296
297    fn handle_project_event(
298        &mut self,
299        _: Model<Project>,
300        event: &project::Event,
301        cx: &mut ModelContext<Self>,
302    ) {
303        match event {
304            project::Event::Reshared => {
305                self.advertise_contexts(cx);
306            }
307            project::Event::HostReshared | project::Event::Rejoined => {
308                self.synchronize_contexts(cx);
309            }
310            project::Event::DisconnectedFromHost => {
311                self.contexts.retain_mut(|context| {
312                    if let Some(strong_context) = context.upgrade() {
313                        *context = ContextHandle::Weak(context.downgrade());
314                        strong_context.update(cx, |context, cx| {
315                            if context.replica_id() != ReplicaId::default() {
316                                context.set_capability(language::Capability::ReadOnly, cx);
317                            }
318                        });
319                        true
320                    } else {
321                        false
322                    }
323                });
324                self.host_contexts.clear();
325                cx.notify();
326            }
327            _ => {}
328        }
329    }
330
331    pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
332        let context = cx.new_model(|cx| {
333            Context::local(
334                self.languages.clone(),
335                Some(self.project.clone()),
336                Some(self.telemetry.clone()),
337                cx,
338            )
339        });
340        self.register_context(&context, cx);
341        context
342    }
343
344    pub fn create_remote_context(
345        &mut self,
346        cx: &mut ModelContext<Self>,
347    ) -> Task<Result<Model<Context>>> {
348        let project = self.project.read(cx);
349        let Some(project_id) = project.remote_id() else {
350            return Task::ready(Err(anyhow!("project was not remote")));
351        };
352        if project.is_local() {
353            return Task::ready(Err(anyhow!("cannot create remote contexts as the host")));
354        }
355
356        let replica_id = project.replica_id();
357        let capability = project.capability();
358        let language_registry = self.languages.clone();
359        let project = self.project.clone();
360        let telemetry = self.telemetry.clone();
361        let request = self.client.request(proto::CreateContext { project_id });
362        cx.spawn(|this, mut cx| async move {
363            let response = request.await?;
364            let context_id = ContextId::from_proto(response.context_id);
365            let context_proto = response.context.context("invalid context")?;
366            let context = cx.new_model(|cx| {
367                Context::new(
368                    context_id.clone(),
369                    replica_id,
370                    capability,
371                    language_registry,
372                    Some(project),
373                    Some(telemetry),
374                    cx,
375                )
376            })?;
377            let operations = cx
378                .background_executor()
379                .spawn(async move {
380                    context_proto
381                        .operations
382                        .into_iter()
383                        .map(|op| ContextOperation::from_proto(op))
384                        .collect::<Result<Vec<_>>>()
385                })
386                .await?;
387            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
388            this.update(&mut cx, |this, cx| {
389                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
390                    existing_context
391                } else {
392                    this.register_context(&context, cx);
393                    this.synchronize_contexts(cx);
394                    context
395                }
396            })
397        })
398    }
399
400    pub fn open_local_context(
401        &mut self,
402        path: PathBuf,
403        cx: &ModelContext<Self>,
404    ) -> Task<Result<Model<Context>>> {
405        if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
406            return Task::ready(Ok(existing_context));
407        }
408
409        let fs = self.fs.clone();
410        let languages = self.languages.clone();
411        let project = self.project.clone();
412        let telemetry = self.telemetry.clone();
413        let load = cx.background_executor().spawn({
414            let path = path.clone();
415            async move {
416                let saved_context = fs.load(&path).await?;
417                SavedContext::from_json(&saved_context)
418            }
419        });
420
421        cx.spawn(|this, mut cx| async move {
422            let saved_context = load.await?;
423            let context = cx.new_model(|cx| {
424                Context::deserialize(
425                    saved_context,
426                    path.clone(),
427                    languages,
428                    Some(project),
429                    Some(telemetry),
430                    cx,
431                )
432            })?;
433            this.update(&mut cx, |this, cx| {
434                if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
435                    existing_context
436                } else {
437                    this.register_context(&context, cx);
438                    context
439                }
440            })
441        })
442    }
443
444    fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
445        self.contexts.iter().find_map(|context| {
446            let context = context.upgrade()?;
447            if context.read(cx).path() == Some(path) {
448                Some(context)
449            } else {
450                None
451            }
452        })
453    }
454
455    pub(super) fn loaded_context_for_id(
456        &self,
457        id: &ContextId,
458        cx: &AppContext,
459    ) -> Option<Model<Context>> {
460        self.contexts.iter().find_map(|context| {
461            let context = context.upgrade()?;
462            if context.read(cx).id() == id {
463                Some(context)
464            } else {
465                None
466            }
467        })
468    }
469
470    pub fn open_remote_context(
471        &mut self,
472        context_id: ContextId,
473        cx: &mut ModelContext<Self>,
474    ) -> Task<Result<Model<Context>>> {
475        let project = self.project.read(cx);
476        let Some(project_id) = project.remote_id() else {
477            return Task::ready(Err(anyhow!("project was not remote")));
478        };
479        if project.is_local() {
480            return Task::ready(Err(anyhow!("cannot open remote contexts as the host")));
481        }
482
483        if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
484            return Task::ready(Ok(context));
485        }
486
487        let replica_id = project.replica_id();
488        let capability = project.capability();
489        let language_registry = self.languages.clone();
490        let project = self.project.clone();
491        let telemetry = self.telemetry.clone();
492        let request = self.client.request(proto::OpenContext {
493            project_id,
494            context_id: context_id.to_proto(),
495        });
496        cx.spawn(|this, mut cx| async move {
497            let response = request.await?;
498            let context_proto = response.context.context("invalid context")?;
499            let context = cx.new_model(|cx| {
500                Context::new(
501                    context_id.clone(),
502                    replica_id,
503                    capability,
504                    language_registry,
505                    Some(project),
506                    Some(telemetry),
507                    cx,
508                )
509            })?;
510            let operations = cx
511                .background_executor()
512                .spawn(async move {
513                    context_proto
514                        .operations
515                        .into_iter()
516                        .map(|op| ContextOperation::from_proto(op))
517                        .collect::<Result<Vec<_>>>()
518                })
519                .await?;
520            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
521            this.update(&mut cx, |this, cx| {
522                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
523                    existing_context
524                } else {
525                    this.register_context(&context, cx);
526                    this.synchronize_contexts(cx);
527                    context
528                }
529            })
530        })
531    }
532
533    fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
534        let handle = if self.project_is_shared {
535            ContextHandle::Strong(context.clone())
536        } else {
537            ContextHandle::Weak(context.downgrade())
538        };
539        self.contexts.push(handle);
540        self.advertise_contexts(cx);
541        cx.subscribe(context, Self::handle_context_event).detach();
542    }
543
544    fn handle_context_event(
545        &mut self,
546        context: Model<Context>,
547        event: &ContextEvent,
548        cx: &mut ModelContext<Self>,
549    ) {
550        let Some(project_id) = self.project.read(cx).remote_id() else {
551            return;
552        };
553
554        match event {
555            ContextEvent::SummaryChanged => {
556                self.advertise_contexts(cx);
557            }
558            ContextEvent::Operation(operation) => {
559                let context_id = context.read(cx).id().to_proto();
560                let operation = operation.to_proto();
561                self.client
562                    .send(proto::UpdateContext {
563                        project_id,
564                        context_id,
565                        operation: Some(operation),
566                    })
567                    .log_err();
568            }
569            _ => {}
570        }
571    }
572
573    fn advertise_contexts(&self, cx: &AppContext) {
574        let Some(project_id) = self.project.read(cx).remote_id() else {
575            return;
576        };
577
578        // For now, only the host can advertise their open contexts.
579        if self.project.read(cx).is_remote() {
580            return;
581        }
582
583        let contexts = self
584            .contexts
585            .iter()
586            .rev()
587            .filter_map(|context| {
588                let context = context.upgrade()?.read(cx);
589                if context.replica_id() == ReplicaId::default() {
590                    Some(proto::ContextMetadata {
591                        context_id: context.id().to_proto(),
592                        summary: context.summary().map(|summary| summary.text.clone()),
593                    })
594                } else {
595                    None
596                }
597            })
598            .collect();
599        self.client
600            .send(proto::AdvertiseContexts {
601                project_id,
602                contexts,
603            })
604            .ok();
605    }
606
607    fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
608        let Some(project_id) = self.project.read(cx).remote_id() else {
609            return;
610        };
611
612        let contexts = self
613            .contexts
614            .iter()
615            .filter_map(|context| {
616                let context = context.upgrade()?.read(cx);
617                if context.replica_id() != ReplicaId::default() {
618                    Some(context.version(cx).to_proto(context.id().clone()))
619                } else {
620                    None
621                }
622            })
623            .collect();
624
625        let client = self.client.clone();
626        let request = self.client.request(proto::SynchronizeContexts {
627            project_id,
628            contexts,
629        });
630        cx.spawn(|this, cx| async move {
631            let response = request.await?;
632
633            let mut context_ids = Vec::new();
634            let mut operations = Vec::new();
635            this.read_with(&cx, |this, cx| {
636                for context_version_proto in response.contexts {
637                    let context_version = ContextVersion::from_proto(&context_version_proto);
638                    let context_id = ContextId::from_proto(context_version_proto.context_id);
639                    if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
640                        context_ids.push(context_id);
641                        operations.push(context.read(cx).serialize_ops(&context_version, cx));
642                    }
643                }
644            })?;
645
646            let operations = futures::future::join_all(operations).await;
647            for (context_id, operations) in context_ids.into_iter().zip(operations) {
648                for operation in operations {
649                    client.send(proto::UpdateContext {
650                        project_id,
651                        context_id: context_id.to_proto(),
652                        operation: Some(operation),
653                    })?;
654                }
655            }
656
657            anyhow::Ok(())
658        })
659        .detach_and_log_err(cx);
660    }
661
662    pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
663        let metadata = self.contexts_metadata.clone();
664        let executor = cx.background_executor().clone();
665        cx.background_executor().spawn(async move {
666            if query.is_empty() {
667                metadata
668            } else {
669                let candidates = metadata
670                    .iter()
671                    .enumerate()
672                    .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
673                    .collect::<Vec<_>>();
674                let matches = fuzzy::match_strings(
675                    &candidates,
676                    &query,
677                    false,
678                    100,
679                    &Default::default(),
680                    executor,
681                )
682                .await;
683
684                matches
685                    .into_iter()
686                    .map(|mat| metadata[mat.candidate_id].clone())
687                    .collect()
688            }
689        })
690    }
691
692    pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
693        &self.host_contexts
694    }
695
696    fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
697        let fs = self.fs.clone();
698        cx.spawn(|this, mut cx| async move {
699            fs.create_dir(contexts_dir()).await?;
700
701            let mut paths = fs.read_dir(contexts_dir()).await?;
702            let mut contexts = Vec::<SavedContextMetadata>::new();
703            while let Some(path) = paths.next().await {
704                let path = path?;
705                if path.extension() != Some(OsStr::new("json")) {
706                    continue;
707                }
708
709                let pattern = r" - \d+.zed.json$";
710                let re = Regex::new(pattern).unwrap();
711
712                let metadata = fs.metadata(&path).await?;
713                if let Some((file_name, metadata)) = path
714                    .file_name()
715                    .and_then(|name| name.to_str())
716                    .zip(metadata)
717                {
718                    // This is used to filter out contexts saved by the new assistant.
719                    if !re.is_match(file_name) {
720                        continue;
721                    }
722
723                    if let Some(title) = re.replace(file_name, "").lines().next() {
724                        contexts.push(SavedContextMetadata {
725                            title: title.to_string(),
726                            path,
727                            mtime: metadata.mtime.into(),
728                        });
729                    }
730                }
731            }
732            contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
733
734            this.update(&mut cx, |this, cx| {
735                this.contexts_metadata = contexts;
736                cx.notify();
737            })
738        })
739    }
740}