context_store.rs

  1use crate::{
  2    Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
  3    SavedContextMetadata,
  4};
  5use anyhow::{anyhow, Context as _, Result};
  6use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
  7use clock::ReplicaId;
  8use fs::Fs;
  9use futures::StreamExt;
 10use fuzzy::StringMatchCandidate;
 11use gpui::{AppContext, AsyncAppContext, Context as _, Model, ModelContext, Task, WeakModel};
 12use language::LanguageRegistry;
 13use paths::contexts_dir;
 14use project::Project;
 15use regex::Regex;
 16use std::{
 17    cmp::Reverse,
 18    ffi::OsStr,
 19    mem,
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22    time::Duration,
 23};
 24use util::{ResultExt, TryFutureExt};
 25
 26pub fn init(client: &Arc<Client>) {
 27    client.add_model_message_handler(ContextStore::handle_advertise_contexts);
 28    client.add_model_request_handler(ContextStore::handle_open_context);
 29    client.add_model_message_handler(ContextStore::handle_update_context);
 30    client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
 31}
 32
 33#[derive(Clone)]
 34pub struct RemoteContextMetadata {
 35    pub id: ContextId,
 36    pub summary: Option<String>,
 37}
 38
 39pub struct ContextStore {
 40    contexts: Vec<ContextHandle>,
 41    contexts_metadata: Vec<SavedContextMetadata>,
 42    host_contexts: Vec<RemoteContextMetadata>,
 43    fs: Arc<dyn Fs>,
 44    languages: Arc<LanguageRegistry>,
 45    telemetry: Arc<Telemetry>,
 46    _watch_updates: Task<Option<()>>,
 47    client: Arc<Client>,
 48    project: Model<Project>,
 49    project_is_shared: bool,
 50    client_subscription: Option<client::Subscription>,
 51    _project_subscriptions: Vec<gpui::Subscription>,
 52}
 53
 54enum ContextHandle {
 55    Weak(WeakModel<Context>),
 56    Strong(Model<Context>),
 57}
 58
 59impl ContextHandle {
 60    fn upgrade(&self) -> Option<Model<Context>> {
 61        match self {
 62            ContextHandle::Weak(weak) => weak.upgrade(),
 63            ContextHandle::Strong(strong) => Some(strong.clone()),
 64        }
 65    }
 66
 67    fn downgrade(&self) -> WeakModel<Context> {
 68        match self {
 69            ContextHandle::Weak(weak) => weak.clone(),
 70            ContextHandle::Strong(strong) => strong.downgrade(),
 71        }
 72    }
 73}
 74
 75impl ContextStore {
 76    pub fn new(project: Model<Project>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
 77        let fs = project.read(cx).fs().clone();
 78        let languages = project.read(cx).languages().clone();
 79        let telemetry = project.read(cx).client().telemetry().clone();
 80        cx.spawn(|mut cx| async move {
 81            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
 82            let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
 83
 84            let this = cx.new_model(|cx: &mut ModelContext<Self>| {
 85                let mut this = Self {
 86                    contexts: Vec::new(),
 87                    contexts_metadata: Vec::new(),
 88                    host_contexts: Vec::new(),
 89                    fs,
 90                    languages,
 91                    telemetry,
 92                    _watch_updates: cx.spawn(|this, mut cx| {
 93                        async move {
 94                            while events.next().await.is_some() {
 95                                this.update(&mut cx, |this, cx| this.reload(cx))?
 96                                    .await
 97                                    .log_err();
 98                            }
 99                            anyhow::Ok(())
100                        }
101                        .log_err()
102                    }),
103                    client_subscription: None,
104                    _project_subscriptions: vec![
105                        cx.observe(&project, Self::handle_project_changed),
106                        cx.subscribe(&project, Self::handle_project_event),
107                    ],
108                    project_is_shared: false,
109                    client: project.read(cx).client(),
110                    project: project.clone(),
111                };
112                this.handle_project_changed(project, cx);
113                this.synchronize_contexts(cx);
114                this
115            })?;
116            this.update(&mut cx, |this, cx| this.reload(cx))?
117                .await
118                .log_err();
119            Ok(this)
120        })
121    }
122
123    async fn handle_advertise_contexts(
124        this: Model<Self>,
125        envelope: TypedEnvelope<proto::AdvertiseContexts>,
126        mut cx: AsyncAppContext,
127    ) -> Result<()> {
128        this.update(&mut cx, |this, cx| {
129            this.host_contexts = envelope
130                .payload
131                .contexts
132                .into_iter()
133                .map(|context| RemoteContextMetadata {
134                    id: ContextId::from_proto(context.context_id),
135                    summary: context.summary,
136                })
137                .collect();
138            cx.notify();
139        })
140    }
141
142    async fn handle_open_context(
143        this: Model<Self>,
144        envelope: TypedEnvelope<proto::OpenContext>,
145        mut cx: AsyncAppContext,
146    ) -> Result<proto::OpenContextResponse> {
147        let context_id = ContextId::from_proto(envelope.payload.context_id);
148        let operations = this.update(&mut cx, |this, cx| {
149            if this.project.read(cx).is_remote() {
150                return Err(anyhow!("only the host contexts can be opened"));
151            }
152
153            let context = this
154                .loaded_context_for_id(&context_id, cx)
155                .context("context not found")?;
156            if context.read(cx).replica_id() != ReplicaId::default() {
157                return Err(anyhow!("context must be opened via the host"));
158            }
159
160            anyhow::Ok(
161                context
162                    .read(cx)
163                    .serialize_ops(&ContextVersion::default(), cx),
164            )
165        })??;
166        let operations = operations.await;
167        Ok(proto::OpenContextResponse {
168            context: Some(proto::Context { operations }),
169        })
170    }
171
172    async fn handle_update_context(
173        this: Model<Self>,
174        envelope: TypedEnvelope<proto::UpdateContext>,
175        mut cx: AsyncAppContext,
176    ) -> Result<()> {
177        this.update(&mut cx, |this, cx| {
178            let context_id = ContextId::from_proto(envelope.payload.context_id);
179            if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
180                let operation_proto = envelope.payload.operation.context("invalid operation")?;
181                let operation = ContextOperation::from_proto(operation_proto)?;
182                context.update(cx, |context, cx| context.apply_ops([operation], cx))?;
183            }
184            Ok(())
185        })?
186    }
187
188    async fn handle_synchronize_contexts(
189        this: Model<Self>,
190        envelope: TypedEnvelope<proto::SynchronizeContexts>,
191        mut cx: AsyncAppContext,
192    ) -> Result<proto::SynchronizeContextsResponse> {
193        this.update(&mut cx, |this, cx| {
194            if this.project.read(cx).is_remote() {
195                return Err(anyhow!("only the host can synchronize contexts"));
196            }
197
198            let mut local_versions = Vec::new();
199            for remote_version_proto in envelope.payload.contexts {
200                let remote_version = ContextVersion::from_proto(&remote_version_proto);
201                let context_id = ContextId::from_proto(remote_version_proto.context_id);
202                if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
203                    let context = context.read(cx);
204                    let operations = context.serialize_ops(&remote_version, cx);
205                    local_versions.push(context.version(cx).to_proto(context_id.clone()));
206                    let client = this.client.clone();
207                    let project_id = envelope.payload.project_id;
208                    cx.background_executor()
209                        .spawn(async move {
210                            let operations = operations.await;
211                            for operation in operations {
212                                client.send(proto::UpdateContext {
213                                    project_id,
214                                    context_id: context_id.to_proto(),
215                                    operation: Some(operation),
216                                })?;
217                            }
218                            anyhow::Ok(())
219                        })
220                        .detach_and_log_err(cx);
221                }
222            }
223
224            this.advertise_contexts(cx);
225
226            anyhow::Ok(proto::SynchronizeContextsResponse {
227                contexts: local_versions,
228            })
229        })?
230    }
231
232    fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
233        let is_shared = self.project.read(cx).is_shared();
234        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
235        if is_shared == was_shared {
236            return;
237        }
238
239        if is_shared {
240            self.contexts.retain_mut(|context| {
241                if let Some(strong_context) = context.upgrade() {
242                    *context = ContextHandle::Strong(strong_context);
243                    true
244                } else {
245                    false
246                }
247            });
248            let remote_id = self.project.read(cx).remote_id().unwrap();
249            self.client_subscription = self
250                .client
251                .subscribe_to_entity(remote_id)
252                .log_err()
253                .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
254            self.advertise_contexts(cx);
255        } else {
256            self.client_subscription = None;
257        }
258    }
259
260    fn handle_project_event(
261        &mut self,
262        _: Model<Project>,
263        event: &project::Event,
264        cx: &mut ModelContext<Self>,
265    ) {
266        match event {
267            project::Event::Reshared => {
268                self.advertise_contexts(cx);
269            }
270            project::Event::HostReshared | project::Event::Rejoined => {
271                self.synchronize_contexts(cx);
272            }
273            project::Event::DisconnectedFromHost => {
274                self.contexts.retain_mut(|context| {
275                    if let Some(strong_context) = context.upgrade() {
276                        *context = ContextHandle::Weak(context.downgrade());
277                        strong_context.update(cx, |context, cx| {
278                            if context.replica_id() != ReplicaId::default() {
279                                context.set_capability(language::Capability::ReadOnly, cx);
280                            }
281                        });
282                        true
283                    } else {
284                        false
285                    }
286                });
287                self.host_contexts.clear();
288                cx.notify();
289            }
290            _ => {}
291        }
292    }
293
294    pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
295        let context = cx.new_model(|cx| {
296            Context::local(self.languages.clone(), Some(self.telemetry.clone()), cx)
297        });
298        self.register_context(&context, cx);
299        context
300    }
301
302    pub fn open_local_context(
303        &mut self,
304        path: PathBuf,
305        cx: &ModelContext<Self>,
306    ) -> Task<Result<Model<Context>>> {
307        if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
308            return Task::ready(Ok(existing_context));
309        }
310
311        let fs = self.fs.clone();
312        let languages = self.languages.clone();
313        let telemetry = self.telemetry.clone();
314        let load = cx.background_executor().spawn({
315            let path = path.clone();
316            async move {
317                let saved_context = fs.load(&path).await?;
318                SavedContext::from_json(&saved_context)
319            }
320        });
321
322        cx.spawn(|this, mut cx| async move {
323            let saved_context = load.await?;
324            let context = cx.new_model(|cx| {
325                Context::deserialize(saved_context, path.clone(), languages, Some(telemetry), cx)
326            })?;
327            this.update(&mut cx, |this, cx| {
328                if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
329                    existing_context
330                } else {
331                    this.register_context(&context, cx);
332                    context
333                }
334            })
335        })
336    }
337
338    fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
339        self.contexts.iter().find_map(|context| {
340            let context = context.upgrade()?;
341            if context.read(cx).path() == Some(path) {
342                Some(context)
343            } else {
344                None
345            }
346        })
347    }
348
349    fn loaded_context_for_id(&self, id: &ContextId, cx: &AppContext) -> Option<Model<Context>> {
350        self.contexts.iter().find_map(|context| {
351            let context = context.upgrade()?;
352            if context.read(cx).id() == id {
353                Some(context)
354            } else {
355                None
356            }
357        })
358    }
359
360    pub fn open_remote_context(
361        &mut self,
362        context_id: ContextId,
363        cx: &mut ModelContext<Self>,
364    ) -> Task<Result<Model<Context>>> {
365        let project = self.project.read(cx);
366        let Some(project_id) = project.remote_id() else {
367            return Task::ready(Err(anyhow!("project was not remote")));
368        };
369        if project.is_local() {
370            return Task::ready(Err(anyhow!("cannot open remote contexts as the host")));
371        }
372
373        if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
374            return Task::ready(Ok(context));
375        }
376
377        let replica_id = project.replica_id();
378        let capability = project.capability();
379        let language_registry = self.languages.clone();
380        let telemetry = self.telemetry.clone();
381        let request = self.client.request(proto::OpenContext {
382            project_id,
383            context_id: context_id.to_proto(),
384        });
385        cx.spawn(|this, mut cx| async move {
386            let response = request.await?;
387            let context_proto = response.context.context("invalid context")?;
388            let context = cx.new_model(|cx| {
389                Context::new(
390                    context_id.clone(),
391                    replica_id,
392                    capability,
393                    language_registry,
394                    Some(telemetry),
395                    cx,
396                )
397            })?;
398            let operations = cx
399                .background_executor()
400                .spawn(async move {
401                    context_proto
402                        .operations
403                        .into_iter()
404                        .map(|op| ContextOperation::from_proto(op))
405                        .collect::<Result<Vec<_>>>()
406                })
407                .await?;
408            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
409            this.update(&mut cx, |this, cx| {
410                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
411                    existing_context
412                } else {
413                    this.register_context(&context, cx);
414                    this.synchronize_contexts(cx);
415                    context
416                }
417            })
418        })
419    }
420
421    fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
422        let handle = if self.project_is_shared {
423            ContextHandle::Strong(context.clone())
424        } else {
425            ContextHandle::Weak(context.downgrade())
426        };
427        self.contexts.push(handle);
428        self.advertise_contexts(cx);
429        cx.subscribe(context, Self::handle_context_event).detach();
430    }
431
432    fn handle_context_event(
433        &mut self,
434        context: Model<Context>,
435        event: &ContextEvent,
436        cx: &mut ModelContext<Self>,
437    ) {
438        let Some(project_id) = self.project.read(cx).remote_id() else {
439            return;
440        };
441
442        match event {
443            ContextEvent::SummaryChanged => {
444                self.advertise_contexts(cx);
445            }
446            ContextEvent::Operation(operation) => {
447                let context_id = context.read(cx).id().to_proto();
448                let operation = operation.to_proto();
449                self.client
450                    .send(proto::UpdateContext {
451                        project_id,
452                        context_id,
453                        operation: Some(operation),
454                    })
455                    .log_err();
456            }
457            _ => {}
458        }
459    }
460
461    fn advertise_contexts(&self, cx: &AppContext) {
462        let Some(project_id) = self.project.read(cx).remote_id() else {
463            return;
464        };
465
466        // For now, only the host can advertise their open contexts.
467        if self.project.read(cx).is_remote() {
468            return;
469        }
470
471        let contexts = self
472            .contexts
473            .iter()
474            .rev()
475            .filter_map(|context| {
476                let context = context.upgrade()?.read(cx);
477                if context.replica_id() == ReplicaId::default() {
478                    Some(proto::ContextMetadata {
479                        context_id: context.id().to_proto(),
480                        summary: context.summary().map(|summary| summary.text.clone()),
481                    })
482                } else {
483                    None
484                }
485            })
486            .collect();
487        self.client
488            .send(proto::AdvertiseContexts {
489                project_id,
490                contexts,
491            })
492            .ok();
493    }
494
495    fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
496        let Some(project_id) = self.project.read(cx).remote_id() else {
497            return;
498        };
499
500        let contexts = self
501            .contexts
502            .iter()
503            .filter_map(|context| {
504                let context = context.upgrade()?.read(cx);
505                if context.replica_id() != ReplicaId::default() {
506                    Some(context.version(cx).to_proto(context.id().clone()))
507                } else {
508                    None
509                }
510            })
511            .collect();
512
513        let client = self.client.clone();
514        let request = self.client.request(proto::SynchronizeContexts {
515            project_id,
516            contexts,
517        });
518        cx.spawn(|this, cx| async move {
519            let response = request.await?;
520
521            let mut context_ids = Vec::new();
522            let mut operations = Vec::new();
523            this.read_with(&cx, |this, cx| {
524                for context_version_proto in response.contexts {
525                    let context_version = ContextVersion::from_proto(&context_version_proto);
526                    let context_id = ContextId::from_proto(context_version_proto.context_id);
527                    if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
528                        context_ids.push(context_id);
529                        operations.push(context.read(cx).serialize_ops(&context_version, cx));
530                    }
531                }
532            })?;
533
534            let operations = futures::future::join_all(operations).await;
535            for (context_id, operations) in context_ids.into_iter().zip(operations) {
536                for operation in operations {
537                    client.send(proto::UpdateContext {
538                        project_id,
539                        context_id: context_id.to_proto(),
540                        operation: Some(operation),
541                    })?;
542                }
543            }
544
545            anyhow::Ok(())
546        })
547        .detach_and_log_err(cx);
548    }
549
550    pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
551        let metadata = self.contexts_metadata.clone();
552        let executor = cx.background_executor().clone();
553        cx.background_executor().spawn(async move {
554            if query.is_empty() {
555                metadata
556            } else {
557                let candidates = metadata
558                    .iter()
559                    .enumerate()
560                    .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
561                    .collect::<Vec<_>>();
562                let matches = fuzzy::match_strings(
563                    &candidates,
564                    &query,
565                    false,
566                    100,
567                    &Default::default(),
568                    executor,
569                )
570                .await;
571
572                matches
573                    .into_iter()
574                    .map(|mat| metadata[mat.candidate_id].clone())
575                    .collect()
576            }
577        })
578    }
579
580    pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
581        &self.host_contexts
582    }
583
584    fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
585        let fs = self.fs.clone();
586        cx.spawn(|this, mut cx| async move {
587            fs.create_dir(contexts_dir()).await?;
588
589            let mut paths = fs.read_dir(contexts_dir()).await?;
590            let mut contexts = Vec::<SavedContextMetadata>::new();
591            while let Some(path) = paths.next().await {
592                let path = path?;
593                if path.extension() != Some(OsStr::new("json")) {
594                    continue;
595                }
596
597                let pattern = r" - \d+.zed.json$";
598                let re = Regex::new(pattern).unwrap();
599
600                let metadata = fs.metadata(&path).await?;
601                if let Some((file_name, metadata)) = path
602                    .file_name()
603                    .and_then(|name| name.to_str())
604                    .zip(metadata)
605                {
606                    // This is used to filter out contexts saved by the new assistant.
607                    if !re.is_match(file_name) {
608                        continue;
609                    }
610
611                    if let Some(title) = re.replace(file_name, "").lines().next() {
612                        contexts.push(SavedContextMetadata {
613                            title: title.to_string(),
614                            path,
615                            mtime: metadata.mtime.into(),
616                        });
617                    }
618                }
619            }
620            contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
621
622            this.update(&mut cx, |this, cx| {
623                this.contexts_metadata = contexts;
624                cx.notify();
625            })
626        })
627    }
628}