context_store.rs

  1use crate::{
  2    prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion,
  3    SavedContext, SavedContextMetadata,
  4};
  5use anyhow::{anyhow, Context as _, Result};
  6use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
  7use clock::ReplicaId;
  8use fs::Fs;
  9use futures::StreamExt;
 10use fuzzy::StringMatchCandidate;
 11use gpui::{
 12    AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
 13};
 14use language::LanguageRegistry;
 15use paths::contexts_dir;
 16use project::Project;
 17use regex::Regex;
 18use std::{
 19    cmp::Reverse,
 20    ffi::OsStr,
 21    mem,
 22    path::{Path, PathBuf},
 23    sync::Arc,
 24    time::Duration,
 25};
 26use util::{ResultExt, TryFutureExt};
 27
 28pub fn init(client: &Arc<Client>) {
 29    client.add_model_message_handler(ContextStore::handle_advertise_contexts);
 30    client.add_model_request_handler(ContextStore::handle_open_context);
 31    client.add_model_request_handler(ContextStore::handle_create_context);
 32    client.add_model_message_handler(ContextStore::handle_update_context);
 33    client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
 34}
 35
 36#[derive(Clone)]
 37pub struct RemoteContextMetadata {
 38    pub id: ContextId,
 39    pub summary: Option<String>,
 40}
 41
 42pub struct ContextStore {
 43    contexts: Vec<ContextHandle>,
 44    contexts_metadata: Vec<SavedContextMetadata>,
 45    host_contexts: Vec<RemoteContextMetadata>,
 46    fs: Arc<dyn Fs>,
 47    languages: Arc<LanguageRegistry>,
 48    telemetry: Arc<Telemetry>,
 49    _watch_updates: Task<Option<()>>,
 50    client: Arc<Client>,
 51    project: Model<Project>,
 52    project_is_shared: bool,
 53    client_subscription: Option<client::Subscription>,
 54    _project_subscriptions: Vec<gpui::Subscription>,
 55    prompt_builder: Arc<PromptBuilder>,
 56}
 57
 58pub enum ContextStoreEvent {
 59    ContextCreated(ContextId),
 60}
 61
 62impl EventEmitter<ContextStoreEvent> for ContextStore {}
 63
 64enum ContextHandle {
 65    Weak(WeakModel<Context>),
 66    Strong(Model<Context>),
 67}
 68
 69impl ContextHandle {
 70    fn upgrade(&self) -> Option<Model<Context>> {
 71        match self {
 72            ContextHandle::Weak(weak) => weak.upgrade(),
 73            ContextHandle::Strong(strong) => Some(strong.clone()),
 74        }
 75    }
 76
 77    fn downgrade(&self) -> WeakModel<Context> {
 78        match self {
 79            ContextHandle::Weak(weak) => weak.clone(),
 80            ContextHandle::Strong(strong) => strong.downgrade(),
 81        }
 82    }
 83}
 84
 85impl ContextStore {
 86    pub fn new(
 87        project: Model<Project>,
 88        prompt_builder: Arc<PromptBuilder>,
 89        cx: &mut AppContext,
 90    ) -> Task<Result<Model<Self>>> {
 91        let fs = project.read(cx).fs().clone();
 92        let languages = project.read(cx).languages().clone();
 93        let telemetry = project.read(cx).client().telemetry().clone();
 94        cx.spawn(|mut cx| async move {
 95            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
 96            let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
 97
 98            let this = cx.new_model(|cx: &mut ModelContext<Self>| {
 99                let mut this = Self {
100                    contexts: Vec::new(),
101                    contexts_metadata: Vec::new(),
102                    host_contexts: Vec::new(),
103                    fs,
104                    languages,
105                    telemetry,
106                    _watch_updates: cx.spawn(|this, mut cx| {
107                        async move {
108                            while events.next().await.is_some() {
109                                this.update(&mut cx, |this, cx| this.reload(cx))?
110                                    .await
111                                    .log_err();
112                            }
113                            anyhow::Ok(())
114                        }
115                        .log_err()
116                    }),
117                    client_subscription: None,
118                    _project_subscriptions: vec![
119                        cx.observe(&project, Self::handle_project_changed),
120                        cx.subscribe(&project, Self::handle_project_event),
121                    ],
122                    project_is_shared: false,
123                    client: project.read(cx).client(),
124                    project: project.clone(),
125                    prompt_builder,
126                };
127                this.handle_project_changed(project, cx);
128                this.synchronize_contexts(cx);
129                this
130            })?;
131            this.update(&mut cx, |this, cx| this.reload(cx))?
132                .await
133                .log_err();
134            Ok(this)
135        })
136    }
137
138    async fn handle_advertise_contexts(
139        this: Model<Self>,
140        envelope: TypedEnvelope<proto::AdvertiseContexts>,
141        mut cx: AsyncAppContext,
142    ) -> Result<()> {
143        this.update(&mut cx, |this, cx| {
144            this.host_contexts = envelope
145                .payload
146                .contexts
147                .into_iter()
148                .map(|context| RemoteContextMetadata {
149                    id: ContextId::from_proto(context.context_id),
150                    summary: context.summary,
151                })
152                .collect();
153            cx.notify();
154        })
155    }
156
157    async fn handle_open_context(
158        this: Model<Self>,
159        envelope: TypedEnvelope<proto::OpenContext>,
160        mut cx: AsyncAppContext,
161    ) -> Result<proto::OpenContextResponse> {
162        let context_id = ContextId::from_proto(envelope.payload.context_id);
163        let operations = this.update(&mut cx, |this, cx| {
164            if this.project.read(cx).is_remote() {
165                return Err(anyhow!("only the host contexts can be opened"));
166            }
167
168            let context = this
169                .loaded_context_for_id(&context_id, cx)
170                .context("context not found")?;
171            if context.read(cx).replica_id() != ReplicaId::default() {
172                return Err(anyhow!("context must be opened via the host"));
173            }
174
175            anyhow::Ok(
176                context
177                    .read(cx)
178                    .serialize_ops(&ContextVersion::default(), cx),
179            )
180        })??;
181        let operations = operations.await;
182        Ok(proto::OpenContextResponse {
183            context: Some(proto::Context { operations }),
184        })
185    }
186
187    async fn handle_create_context(
188        this: Model<Self>,
189        _: TypedEnvelope<proto::CreateContext>,
190        mut cx: AsyncAppContext,
191    ) -> Result<proto::CreateContextResponse> {
192        let (context_id, operations) = this.update(&mut cx, |this, cx| {
193            if this.project.read(cx).is_remote() {
194                return Err(anyhow!("can only create contexts as the host"));
195            }
196
197            let context = this.create(cx);
198            let context_id = context.read(cx).id().clone();
199            cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
200
201            anyhow::Ok((
202                context_id,
203                context
204                    .read(cx)
205                    .serialize_ops(&ContextVersion::default(), cx),
206            ))
207        })??;
208        let operations = operations.await;
209        Ok(proto::CreateContextResponse {
210            context_id: context_id.to_proto(),
211            context: Some(proto::Context { operations }),
212        })
213    }
214
215    async fn handle_update_context(
216        this: Model<Self>,
217        envelope: TypedEnvelope<proto::UpdateContext>,
218        mut cx: AsyncAppContext,
219    ) -> Result<()> {
220        this.update(&mut cx, |this, cx| {
221            let context_id = ContextId::from_proto(envelope.payload.context_id);
222            if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
223                let operation_proto = envelope.payload.operation.context("invalid operation")?;
224                let operation = ContextOperation::from_proto(operation_proto)?;
225                context.update(cx, |context, cx| context.apply_ops([operation], cx))?;
226            }
227            Ok(())
228        })?
229    }
230
231    async fn handle_synchronize_contexts(
232        this: Model<Self>,
233        envelope: TypedEnvelope<proto::SynchronizeContexts>,
234        mut cx: AsyncAppContext,
235    ) -> Result<proto::SynchronizeContextsResponse> {
236        this.update(&mut cx, |this, cx| {
237            if this.project.read(cx).is_remote() {
238                return Err(anyhow!("only the host can synchronize contexts"));
239            }
240
241            let mut local_versions = Vec::new();
242            for remote_version_proto in envelope.payload.contexts {
243                let remote_version = ContextVersion::from_proto(&remote_version_proto);
244                let context_id = ContextId::from_proto(remote_version_proto.context_id);
245                if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
246                    let context = context.read(cx);
247                    let operations = context.serialize_ops(&remote_version, cx);
248                    local_versions.push(context.version(cx).to_proto(context_id.clone()));
249                    let client = this.client.clone();
250                    let project_id = envelope.payload.project_id;
251                    cx.background_executor()
252                        .spawn(async move {
253                            let operations = operations.await;
254                            for operation in operations {
255                                client.send(proto::UpdateContext {
256                                    project_id,
257                                    context_id: context_id.to_proto(),
258                                    operation: Some(operation),
259                                })?;
260                            }
261                            anyhow::Ok(())
262                        })
263                        .detach_and_log_err(cx);
264                }
265            }
266
267            this.advertise_contexts(cx);
268
269            anyhow::Ok(proto::SynchronizeContextsResponse {
270                contexts: local_versions,
271            })
272        })?
273    }
274
275    fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
276        let is_shared = self.project.read(cx).is_shared();
277        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
278        if is_shared == was_shared {
279            return;
280        }
281
282        if is_shared {
283            self.contexts.retain_mut(|context| {
284                if let Some(strong_context) = context.upgrade() {
285                    *context = ContextHandle::Strong(strong_context);
286                    true
287                } else {
288                    false
289                }
290            });
291            let remote_id = self.project.read(cx).remote_id().unwrap();
292            self.client_subscription = self
293                .client
294                .subscribe_to_entity(remote_id)
295                .log_err()
296                .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
297            self.advertise_contexts(cx);
298        } else {
299            self.client_subscription = None;
300        }
301    }
302
303    fn handle_project_event(
304        &mut self,
305        _: Model<Project>,
306        event: &project::Event,
307        cx: &mut ModelContext<Self>,
308    ) {
309        match event {
310            project::Event::Reshared => {
311                self.advertise_contexts(cx);
312            }
313            project::Event::HostReshared | project::Event::Rejoined => {
314                self.synchronize_contexts(cx);
315            }
316            project::Event::DisconnectedFromHost => {
317                self.contexts.retain_mut(|context| {
318                    if let Some(strong_context) = context.upgrade() {
319                        *context = ContextHandle::Weak(context.downgrade());
320                        strong_context.update(cx, |context, cx| {
321                            if context.replica_id() != ReplicaId::default() {
322                                context.set_capability(language::Capability::ReadOnly, cx);
323                            }
324                        });
325                        true
326                    } else {
327                        false
328                    }
329                });
330                self.host_contexts.clear();
331                cx.notify();
332            }
333            _ => {}
334        }
335    }
336
337    pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
338        let context = cx.new_model(|cx| {
339            Context::local(
340                self.languages.clone(),
341                Some(self.project.clone()),
342                Some(self.telemetry.clone()),
343                self.prompt_builder.clone(),
344                cx,
345            )
346        });
347        self.register_context(&context, cx);
348        context
349    }
350
351    pub fn create_remote_context(
352        &mut self,
353        cx: &mut ModelContext<Self>,
354    ) -> Task<Result<Model<Context>>> {
355        let project = self.project.read(cx);
356        let Some(project_id) = project.remote_id() else {
357            return Task::ready(Err(anyhow!("project was not remote")));
358        };
359        if project.is_local() {
360            return Task::ready(Err(anyhow!("cannot create remote contexts as the host")));
361        }
362
363        let replica_id = project.replica_id();
364        let capability = project.capability();
365        let language_registry = self.languages.clone();
366        let project = self.project.clone();
367        let telemetry = self.telemetry.clone();
368        let prompt_builder = self.prompt_builder.clone();
369        let request = self.client.request(proto::CreateContext { project_id });
370        cx.spawn(|this, mut cx| async move {
371            let response = request.await?;
372            let context_id = ContextId::from_proto(response.context_id);
373            let context_proto = response.context.context("invalid context")?;
374            let context = cx.new_model(|cx| {
375                Context::new(
376                    context_id.clone(),
377                    replica_id,
378                    capability,
379                    language_registry,
380                    prompt_builder,
381                    Some(project),
382                    Some(telemetry),
383                    cx,
384                )
385            })?;
386            let operations = cx
387                .background_executor()
388                .spawn(async move {
389                    context_proto
390                        .operations
391                        .into_iter()
392                        .map(|op| ContextOperation::from_proto(op))
393                        .collect::<Result<Vec<_>>>()
394                })
395                .await?;
396            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
397            this.update(&mut cx, |this, cx| {
398                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
399                    existing_context
400                } else {
401                    this.register_context(&context, cx);
402                    this.synchronize_contexts(cx);
403                    context
404                }
405            })
406        })
407    }
408
409    pub fn open_local_context(
410        &mut self,
411        path: PathBuf,
412        cx: &ModelContext<Self>,
413    ) -> Task<Result<Model<Context>>> {
414        if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
415            return Task::ready(Ok(existing_context));
416        }
417
418        let fs = self.fs.clone();
419        let languages = self.languages.clone();
420        let project = self.project.clone();
421        let telemetry = self.telemetry.clone();
422        let load = cx.background_executor().spawn({
423            let path = path.clone();
424            async move {
425                let saved_context = fs.load(&path).await?;
426                SavedContext::from_json(&saved_context)
427            }
428        });
429        let prompt_builder = self.prompt_builder.clone();
430
431        cx.spawn(|this, mut cx| async move {
432            let saved_context = load.await?;
433            let context = cx.new_model(|cx| {
434                Context::deserialize(
435                    saved_context,
436                    path.clone(),
437                    languages,
438                    prompt_builder,
439                    Some(project),
440                    Some(telemetry),
441                    cx,
442                )
443            })?;
444            this.update(&mut cx, |this, cx| {
445                if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
446                    existing_context
447                } else {
448                    this.register_context(&context, cx);
449                    context
450                }
451            })
452        })
453    }
454
455    fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
456        self.contexts.iter().find_map(|context| {
457            let context = context.upgrade()?;
458            if context.read(cx).path() == Some(path) {
459                Some(context)
460            } else {
461                None
462            }
463        })
464    }
465
466    pub(super) fn loaded_context_for_id(
467        &self,
468        id: &ContextId,
469        cx: &AppContext,
470    ) -> Option<Model<Context>> {
471        self.contexts.iter().find_map(|context| {
472            let context = context.upgrade()?;
473            if context.read(cx).id() == id {
474                Some(context)
475            } else {
476                None
477            }
478        })
479    }
480
481    pub fn open_remote_context(
482        &mut self,
483        context_id: ContextId,
484        cx: &mut ModelContext<Self>,
485    ) -> Task<Result<Model<Context>>> {
486        let project = self.project.read(cx);
487        let Some(project_id) = project.remote_id() else {
488            return Task::ready(Err(anyhow!("project was not remote")));
489        };
490        if project.is_local() {
491            return Task::ready(Err(anyhow!("cannot open remote contexts as the host")));
492        }
493
494        if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
495            return Task::ready(Ok(context));
496        }
497
498        let replica_id = project.replica_id();
499        let capability = project.capability();
500        let language_registry = self.languages.clone();
501        let project = self.project.clone();
502        let telemetry = self.telemetry.clone();
503        let request = self.client.request(proto::OpenContext {
504            project_id,
505            context_id: context_id.to_proto(),
506        });
507        let prompt_builder = self.prompt_builder.clone();
508        cx.spawn(|this, mut cx| async move {
509            let response = request.await?;
510            let context_proto = response.context.context("invalid context")?;
511            let context = cx.new_model(|cx| {
512                Context::new(
513                    context_id.clone(),
514                    replica_id,
515                    capability,
516                    language_registry,
517                    prompt_builder,
518                    Some(project),
519                    Some(telemetry),
520                    cx,
521                )
522            })?;
523            let operations = cx
524                .background_executor()
525                .spawn(async move {
526                    context_proto
527                        .operations
528                        .into_iter()
529                        .map(|op| ContextOperation::from_proto(op))
530                        .collect::<Result<Vec<_>>>()
531                })
532                .await?;
533            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
534            this.update(&mut cx, |this, cx| {
535                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
536                    existing_context
537                } else {
538                    this.register_context(&context, cx);
539                    this.synchronize_contexts(cx);
540                    context
541                }
542            })
543        })
544    }
545
546    fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
547        let handle = if self.project_is_shared {
548            ContextHandle::Strong(context.clone())
549        } else {
550            ContextHandle::Weak(context.downgrade())
551        };
552        self.contexts.push(handle);
553        self.advertise_contexts(cx);
554        cx.subscribe(context, Self::handle_context_event).detach();
555    }
556
557    fn handle_context_event(
558        &mut self,
559        context: Model<Context>,
560        event: &ContextEvent,
561        cx: &mut ModelContext<Self>,
562    ) {
563        let Some(project_id) = self.project.read(cx).remote_id() else {
564            return;
565        };
566
567        match event {
568            ContextEvent::SummaryChanged => {
569                self.advertise_contexts(cx);
570            }
571            ContextEvent::Operation(operation) => {
572                let context_id = context.read(cx).id().to_proto();
573                let operation = operation.to_proto();
574                self.client
575                    .send(proto::UpdateContext {
576                        project_id,
577                        context_id,
578                        operation: Some(operation),
579                    })
580                    .log_err();
581            }
582            _ => {}
583        }
584    }
585
586    fn advertise_contexts(&self, cx: &AppContext) {
587        let Some(project_id) = self.project.read(cx).remote_id() else {
588            return;
589        };
590
591        // For now, only the host can advertise their open contexts.
592        if self.project.read(cx).is_remote() {
593            return;
594        }
595
596        let contexts = self
597            .contexts
598            .iter()
599            .rev()
600            .filter_map(|context| {
601                let context = context.upgrade()?.read(cx);
602                if context.replica_id() == ReplicaId::default() {
603                    Some(proto::ContextMetadata {
604                        context_id: context.id().to_proto(),
605                        summary: context.summary().map(|summary| summary.text.clone()),
606                    })
607                } else {
608                    None
609                }
610            })
611            .collect();
612        self.client
613            .send(proto::AdvertiseContexts {
614                project_id,
615                contexts,
616            })
617            .ok();
618    }
619
620    fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
621        let Some(project_id) = self.project.read(cx).remote_id() else {
622            return;
623        };
624
625        let contexts = self
626            .contexts
627            .iter()
628            .filter_map(|context| {
629                let context = context.upgrade()?.read(cx);
630                if context.replica_id() != ReplicaId::default() {
631                    Some(context.version(cx).to_proto(context.id().clone()))
632                } else {
633                    None
634                }
635            })
636            .collect();
637
638        let client = self.client.clone();
639        let request = self.client.request(proto::SynchronizeContexts {
640            project_id,
641            contexts,
642        });
643        cx.spawn(|this, cx| async move {
644            let response = request.await?;
645
646            let mut context_ids = Vec::new();
647            let mut operations = Vec::new();
648            this.read_with(&cx, |this, cx| {
649                for context_version_proto in response.contexts {
650                    let context_version = ContextVersion::from_proto(&context_version_proto);
651                    let context_id = ContextId::from_proto(context_version_proto.context_id);
652                    if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
653                        context_ids.push(context_id);
654                        operations.push(context.read(cx).serialize_ops(&context_version, cx));
655                    }
656                }
657            })?;
658
659            let operations = futures::future::join_all(operations).await;
660            for (context_id, operations) in context_ids.into_iter().zip(operations) {
661                for operation in operations {
662                    client.send(proto::UpdateContext {
663                        project_id,
664                        context_id: context_id.to_proto(),
665                        operation: Some(operation),
666                    })?;
667                }
668            }
669
670            anyhow::Ok(())
671        })
672        .detach_and_log_err(cx);
673    }
674
675    pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
676        let metadata = self.contexts_metadata.clone();
677        let executor = cx.background_executor().clone();
678        cx.background_executor().spawn(async move {
679            if query.is_empty() {
680                metadata
681            } else {
682                let candidates = metadata
683                    .iter()
684                    .enumerate()
685                    .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
686                    .collect::<Vec<_>>();
687                let matches = fuzzy::match_strings(
688                    &candidates,
689                    &query,
690                    false,
691                    100,
692                    &Default::default(),
693                    executor,
694                )
695                .await;
696
697                matches
698                    .into_iter()
699                    .map(|mat| metadata[mat.candidate_id].clone())
700                    .collect()
701            }
702        })
703    }
704
705    pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
706        &self.host_contexts
707    }
708
709    fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
710        let fs = self.fs.clone();
711        cx.spawn(|this, mut cx| async move {
712            fs.create_dir(contexts_dir()).await?;
713
714            let mut paths = fs.read_dir(contexts_dir()).await?;
715            let mut contexts = Vec::<SavedContextMetadata>::new();
716            while let Some(path) = paths.next().await {
717                let path = path?;
718                if path.extension() != Some(OsStr::new("json")) {
719                    continue;
720                }
721
722                let pattern = r" - \d+.zed.json$";
723                let re = Regex::new(pattern).unwrap();
724
725                let metadata = fs.metadata(&path).await?;
726                if let Some((file_name, metadata)) = path
727                    .file_name()
728                    .and_then(|name| name.to_str())
729                    .zip(metadata)
730                {
731                    // This is used to filter out contexts saved by the new assistant.
732                    if !re.is_match(file_name) {
733                        continue;
734                    }
735
736                    if let Some(title) = re.replace(file_name, "").lines().next() {
737                        contexts.push(SavedContextMetadata {
738                            title: title.to_string(),
739                            path,
740                            mtime: metadata.mtime.into(),
741                        });
742                    }
743                }
744            }
745            contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
746
747            this.update(&mut cx, |this, cx| {
748                this.contexts_metadata = contexts;
749                cx.notify();
750            })
751        })
752    }
753}