context_store.rs

  1use crate::{
  2    prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion,
  3    SavedContext, SavedContextMetadata,
  4};
  5use anyhow::{anyhow, Context as _, Result};
  6use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
  7use clock::ReplicaId;
  8use fs::Fs;
  9use futures::StreamExt;
 10use fuzzy::StringMatchCandidate;
 11use gpui::{
 12    AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
 13};
 14use language::LanguageRegistry;
 15use paths::contexts_dir;
 16use project::Project;
 17use regex::Regex;
 18use rpc::AnyProtoClient;
 19use std::{
 20    cmp::Reverse,
 21    ffi::OsStr,
 22    mem,
 23    path::{Path, PathBuf},
 24    sync::Arc,
 25    time::Duration,
 26};
 27use util::{ResultExt, TryFutureExt};
 28
 29pub fn init(client: &AnyProtoClient) {
 30    client.add_model_message_handler(ContextStore::handle_advertise_contexts);
 31    client.add_model_request_handler(ContextStore::handle_open_context);
 32    client.add_model_request_handler(ContextStore::handle_create_context);
 33    client.add_model_message_handler(ContextStore::handle_update_context);
 34    client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
 35}
 36
 37#[derive(Clone)]
 38pub struct RemoteContextMetadata {
 39    pub id: ContextId,
 40    pub summary: Option<String>,
 41}
 42
 43pub struct ContextStore {
 44    contexts: Vec<ContextHandle>,
 45    contexts_metadata: Vec<SavedContextMetadata>,
 46    host_contexts: Vec<RemoteContextMetadata>,
 47    fs: Arc<dyn Fs>,
 48    languages: Arc<LanguageRegistry>,
 49    telemetry: Arc<Telemetry>,
 50    _watch_updates: Task<Option<()>>,
 51    client: Arc<Client>,
 52    project: Model<Project>,
 53    project_is_shared: bool,
 54    client_subscription: Option<client::Subscription>,
 55    _project_subscriptions: Vec<gpui::Subscription>,
 56    prompt_builder: Arc<PromptBuilder>,
 57}
 58
 59pub enum ContextStoreEvent {
 60    ContextCreated(ContextId),
 61}
 62
 63impl EventEmitter<ContextStoreEvent> for ContextStore {}
 64
 65enum ContextHandle {
 66    Weak(WeakModel<Context>),
 67    Strong(Model<Context>),
 68}
 69
 70impl ContextHandle {
 71    fn upgrade(&self) -> Option<Model<Context>> {
 72        match self {
 73            ContextHandle::Weak(weak) => weak.upgrade(),
 74            ContextHandle::Strong(strong) => Some(strong.clone()),
 75        }
 76    }
 77
 78    fn downgrade(&self) -> WeakModel<Context> {
 79        match self {
 80            ContextHandle::Weak(weak) => weak.clone(),
 81            ContextHandle::Strong(strong) => strong.downgrade(),
 82        }
 83    }
 84}
 85
 86impl ContextStore {
 87    pub fn new(
 88        project: Model<Project>,
 89        prompt_builder: Arc<PromptBuilder>,
 90        cx: &mut AppContext,
 91    ) -> Task<Result<Model<Self>>> {
 92        let fs = project.read(cx).fs().clone();
 93        let languages = project.read(cx).languages().clone();
 94        let telemetry = project.read(cx).client().telemetry().clone();
 95        cx.spawn(|mut cx| async move {
 96            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
 97            let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
 98
 99            let this = cx.new_model(|cx: &mut ModelContext<Self>| {
100                let mut this = Self {
101                    contexts: Vec::new(),
102                    contexts_metadata: Vec::new(),
103                    host_contexts: Vec::new(),
104                    fs,
105                    languages,
106                    telemetry,
107                    _watch_updates: cx.spawn(|this, mut cx| {
108                        async move {
109                            while events.next().await.is_some() {
110                                this.update(&mut cx, |this, cx| this.reload(cx))?
111                                    .await
112                                    .log_err();
113                            }
114                            anyhow::Ok(())
115                        }
116                        .log_err()
117                    }),
118                    client_subscription: None,
119                    _project_subscriptions: vec![
120                        cx.observe(&project, Self::handle_project_changed),
121                        cx.subscribe(&project, Self::handle_project_event),
122                    ],
123                    project_is_shared: false,
124                    client: project.read(cx).client(),
125                    project: project.clone(),
126                    prompt_builder,
127                };
128                this.handle_project_changed(project, cx);
129                this.synchronize_contexts(cx);
130                this
131            })?;
132            this.update(&mut cx, |this, cx| this.reload(cx))?
133                .await
134                .log_err();
135            Ok(this)
136        })
137    }
138
139    async fn handle_advertise_contexts(
140        this: Model<Self>,
141        envelope: TypedEnvelope<proto::AdvertiseContexts>,
142        mut cx: AsyncAppContext,
143    ) -> Result<()> {
144        this.update(&mut cx, |this, cx| {
145            this.host_contexts = envelope
146                .payload
147                .contexts
148                .into_iter()
149                .map(|context| RemoteContextMetadata {
150                    id: ContextId::from_proto(context.context_id),
151                    summary: context.summary,
152                })
153                .collect();
154            cx.notify();
155        })
156    }
157
158    async fn handle_open_context(
159        this: Model<Self>,
160        envelope: TypedEnvelope<proto::OpenContext>,
161        mut cx: AsyncAppContext,
162    ) -> Result<proto::OpenContextResponse> {
163        let context_id = ContextId::from_proto(envelope.payload.context_id);
164        let operations = this.update(&mut cx, |this, cx| {
165            if this.project.read(cx).is_via_collab() {
166                return Err(anyhow!("only the host contexts can be opened"));
167            }
168
169            let context = this
170                .loaded_context_for_id(&context_id, cx)
171                .context("context not found")?;
172            if context.read(cx).replica_id() != ReplicaId::default() {
173                return Err(anyhow!("context must be opened via the host"));
174            }
175
176            anyhow::Ok(
177                context
178                    .read(cx)
179                    .serialize_ops(&ContextVersion::default(), cx),
180            )
181        })??;
182        let operations = operations.await;
183        Ok(proto::OpenContextResponse {
184            context: Some(proto::Context { operations }),
185        })
186    }
187
188    async fn handle_create_context(
189        this: Model<Self>,
190        _: TypedEnvelope<proto::CreateContext>,
191        mut cx: AsyncAppContext,
192    ) -> Result<proto::CreateContextResponse> {
193        let (context_id, operations) = this.update(&mut cx, |this, cx| {
194            if this.project.read(cx).is_via_collab() {
195                return Err(anyhow!("can only create contexts as the host"));
196            }
197
198            let context = this.create(cx);
199            let context_id = context.read(cx).id().clone();
200            cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
201
202            anyhow::Ok((
203                context_id,
204                context
205                    .read(cx)
206                    .serialize_ops(&ContextVersion::default(), cx),
207            ))
208        })??;
209        let operations = operations.await;
210        Ok(proto::CreateContextResponse {
211            context_id: context_id.to_proto(),
212            context: Some(proto::Context { operations }),
213        })
214    }
215
216    async fn handle_update_context(
217        this: Model<Self>,
218        envelope: TypedEnvelope<proto::UpdateContext>,
219        mut cx: AsyncAppContext,
220    ) -> Result<()> {
221        this.update(&mut cx, |this, cx| {
222            let context_id = ContextId::from_proto(envelope.payload.context_id);
223            if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
224                let operation_proto = envelope.payload.operation.context("invalid operation")?;
225                let operation = ContextOperation::from_proto(operation_proto)?;
226                context.update(cx, |context, cx| context.apply_ops([operation], cx))?;
227            }
228            Ok(())
229        })?
230    }
231
232    async fn handle_synchronize_contexts(
233        this: Model<Self>,
234        envelope: TypedEnvelope<proto::SynchronizeContexts>,
235        mut cx: AsyncAppContext,
236    ) -> Result<proto::SynchronizeContextsResponse> {
237        this.update(&mut cx, |this, cx| {
238            if this.project.read(cx).is_via_collab() {
239                return Err(anyhow!("only the host can synchronize contexts"));
240            }
241
242            let mut local_versions = Vec::new();
243            for remote_version_proto in envelope.payload.contexts {
244                let remote_version = ContextVersion::from_proto(&remote_version_proto);
245                let context_id = ContextId::from_proto(remote_version_proto.context_id);
246                if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
247                    let context = context.read(cx);
248                    let operations = context.serialize_ops(&remote_version, cx);
249                    local_versions.push(context.version(cx).to_proto(context_id.clone()));
250                    let client = this.client.clone();
251                    let project_id = envelope.payload.project_id;
252                    cx.background_executor()
253                        .spawn(async move {
254                            let operations = operations.await;
255                            for operation in operations {
256                                client.send(proto::UpdateContext {
257                                    project_id,
258                                    context_id: context_id.to_proto(),
259                                    operation: Some(operation),
260                                })?;
261                            }
262                            anyhow::Ok(())
263                        })
264                        .detach_and_log_err(cx);
265                }
266            }
267
268            this.advertise_contexts(cx);
269
270            anyhow::Ok(proto::SynchronizeContextsResponse {
271                contexts: local_versions,
272            })
273        })?
274    }
275
276    fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
277        let is_shared = self.project.read(cx).is_shared();
278        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
279        if is_shared == was_shared {
280            return;
281        }
282
283        if is_shared {
284            self.contexts.retain_mut(|context| {
285                if let Some(strong_context) = context.upgrade() {
286                    *context = ContextHandle::Strong(strong_context);
287                    true
288                } else {
289                    false
290                }
291            });
292            let remote_id = self.project.read(cx).remote_id().unwrap();
293            self.client_subscription = self
294                .client
295                .subscribe_to_entity(remote_id)
296                .log_err()
297                .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
298            self.advertise_contexts(cx);
299        } else {
300            self.client_subscription = None;
301        }
302    }
303
304    fn handle_project_event(
305        &mut self,
306        _: Model<Project>,
307        event: &project::Event,
308        cx: &mut ModelContext<Self>,
309    ) {
310        match event {
311            project::Event::Reshared => {
312                self.advertise_contexts(cx);
313            }
314            project::Event::HostReshared | project::Event::Rejoined => {
315                self.synchronize_contexts(cx);
316            }
317            project::Event::DisconnectedFromHost => {
318                self.contexts.retain_mut(|context| {
319                    if let Some(strong_context) = context.upgrade() {
320                        *context = ContextHandle::Weak(context.downgrade());
321                        strong_context.update(cx, |context, cx| {
322                            if context.replica_id() != ReplicaId::default() {
323                                context.set_capability(language::Capability::ReadOnly, cx);
324                            }
325                        });
326                        true
327                    } else {
328                        false
329                    }
330                });
331                self.host_contexts.clear();
332                cx.notify();
333            }
334            _ => {}
335        }
336    }
337
338    pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
339        let context = cx.new_model(|cx| {
340            Context::local(
341                self.languages.clone(),
342                Some(self.project.clone()),
343                Some(self.telemetry.clone()),
344                self.prompt_builder.clone(),
345                cx,
346            )
347        });
348        self.register_context(&context, cx);
349        context
350    }
351
352    pub fn create_remote_context(
353        &mut self,
354        cx: &mut ModelContext<Self>,
355    ) -> Task<Result<Model<Context>>> {
356        let project = self.project.read(cx);
357        let Some(project_id) = project.remote_id() else {
358            return Task::ready(Err(anyhow!("project was not remote")));
359        };
360        if project.is_local_or_ssh() {
361            return Task::ready(Err(anyhow!("cannot create remote contexts as the host")));
362        }
363
364        let replica_id = project.replica_id();
365        let capability = project.capability();
366        let language_registry = self.languages.clone();
367        let project = self.project.clone();
368        let telemetry = self.telemetry.clone();
369        let prompt_builder = self.prompt_builder.clone();
370        let request = self.client.request(proto::CreateContext { project_id });
371        cx.spawn(|this, mut cx| async move {
372            let response = request.await?;
373            let context_id = ContextId::from_proto(response.context_id);
374            let context_proto = response.context.context("invalid context")?;
375            let context = cx.new_model(|cx| {
376                Context::new(
377                    context_id.clone(),
378                    replica_id,
379                    capability,
380                    language_registry,
381                    prompt_builder,
382                    Some(project),
383                    Some(telemetry),
384                    cx,
385                )
386            })?;
387            let operations = cx
388                .background_executor()
389                .spawn(async move {
390                    context_proto
391                        .operations
392                        .into_iter()
393                        .map(ContextOperation::from_proto)
394                        .collect::<Result<Vec<_>>>()
395                })
396                .await?;
397            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
398            this.update(&mut cx, |this, cx| {
399                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
400                    existing_context
401                } else {
402                    this.register_context(&context, cx);
403                    this.synchronize_contexts(cx);
404                    context
405                }
406            })
407        })
408    }
409
410    pub fn open_local_context(
411        &mut self,
412        path: PathBuf,
413        cx: &ModelContext<Self>,
414    ) -> Task<Result<Model<Context>>> {
415        if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
416            return Task::ready(Ok(existing_context));
417        }
418
419        let fs = self.fs.clone();
420        let languages = self.languages.clone();
421        let project = self.project.clone();
422        let telemetry = self.telemetry.clone();
423        let load = cx.background_executor().spawn({
424            let path = path.clone();
425            async move {
426                let saved_context = fs.load(&path).await?;
427                SavedContext::from_json(&saved_context)
428            }
429        });
430        let prompt_builder = self.prompt_builder.clone();
431
432        cx.spawn(|this, mut cx| async move {
433            let saved_context = load.await?;
434            let context = cx.new_model(|cx| {
435                Context::deserialize(
436                    saved_context,
437                    path.clone(),
438                    languages,
439                    prompt_builder,
440                    Some(project),
441                    Some(telemetry),
442                    cx,
443                )
444            })?;
445            this.update(&mut cx, |this, cx| {
446                if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
447                    existing_context
448                } else {
449                    this.register_context(&context, cx);
450                    context
451                }
452            })
453        })
454    }
455
456    fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
457        self.contexts.iter().find_map(|context| {
458            let context = context.upgrade()?;
459            if context.read(cx).path() == Some(path) {
460                Some(context)
461            } else {
462                None
463            }
464        })
465    }
466
467    pub(super) fn loaded_context_for_id(
468        &self,
469        id: &ContextId,
470        cx: &AppContext,
471    ) -> Option<Model<Context>> {
472        self.contexts.iter().find_map(|context| {
473            let context = context.upgrade()?;
474            if context.read(cx).id() == id {
475                Some(context)
476            } else {
477                None
478            }
479        })
480    }
481
482    pub fn open_remote_context(
483        &mut self,
484        context_id: ContextId,
485        cx: &mut ModelContext<Self>,
486    ) -> Task<Result<Model<Context>>> {
487        let project = self.project.read(cx);
488        let Some(project_id) = project.remote_id() else {
489            return Task::ready(Err(anyhow!("project was not remote")));
490        };
491        if project.is_local_or_ssh() {
492            return Task::ready(Err(anyhow!("cannot open remote contexts as the host")));
493        }
494
495        if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
496            return Task::ready(Ok(context));
497        }
498
499        let replica_id = project.replica_id();
500        let capability = project.capability();
501        let language_registry = self.languages.clone();
502        let project = self.project.clone();
503        let telemetry = self.telemetry.clone();
504        let request = self.client.request(proto::OpenContext {
505            project_id,
506            context_id: context_id.to_proto(),
507        });
508        let prompt_builder = self.prompt_builder.clone();
509        cx.spawn(|this, mut cx| async move {
510            let response = request.await?;
511            let context_proto = response.context.context("invalid context")?;
512            let context = cx.new_model(|cx| {
513                Context::new(
514                    context_id.clone(),
515                    replica_id,
516                    capability,
517                    language_registry,
518                    prompt_builder,
519                    Some(project),
520                    Some(telemetry),
521                    cx,
522                )
523            })?;
524            let operations = cx
525                .background_executor()
526                .spawn(async move {
527                    context_proto
528                        .operations
529                        .into_iter()
530                        .map(ContextOperation::from_proto)
531                        .collect::<Result<Vec<_>>>()
532                })
533                .await?;
534            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
535            this.update(&mut cx, |this, cx| {
536                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
537                    existing_context
538                } else {
539                    this.register_context(&context, cx);
540                    this.synchronize_contexts(cx);
541                    context
542                }
543            })
544        })
545    }
546
547    fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
548        let handle = if self.project_is_shared {
549            ContextHandle::Strong(context.clone())
550        } else {
551            ContextHandle::Weak(context.downgrade())
552        };
553        self.contexts.push(handle);
554        self.advertise_contexts(cx);
555        cx.subscribe(context, Self::handle_context_event).detach();
556    }
557
558    fn handle_context_event(
559        &mut self,
560        context: Model<Context>,
561        event: &ContextEvent,
562        cx: &mut ModelContext<Self>,
563    ) {
564        let Some(project_id) = self.project.read(cx).remote_id() else {
565            return;
566        };
567
568        match event {
569            ContextEvent::SummaryChanged => {
570                self.advertise_contexts(cx);
571            }
572            ContextEvent::Operation(operation) => {
573                let context_id = context.read(cx).id().to_proto();
574                let operation = operation.to_proto();
575                self.client
576                    .send(proto::UpdateContext {
577                        project_id,
578                        context_id,
579                        operation: Some(operation),
580                    })
581                    .log_err();
582            }
583            _ => {}
584        }
585    }
586
587    fn advertise_contexts(&self, cx: &AppContext) {
588        let Some(project_id) = self.project.read(cx).remote_id() else {
589            return;
590        };
591
592        // For now, only the host can advertise their open contexts.
593        if self.project.read(cx).is_via_collab() {
594            return;
595        }
596
597        let contexts = self
598            .contexts
599            .iter()
600            .rev()
601            .filter_map(|context| {
602                let context = context.upgrade()?.read(cx);
603                if context.replica_id() == ReplicaId::default() {
604                    Some(proto::ContextMetadata {
605                        context_id: context.id().to_proto(),
606                        summary: context.summary().map(|summary| summary.text.clone()),
607                    })
608                } else {
609                    None
610                }
611            })
612            .collect();
613        self.client
614            .send(proto::AdvertiseContexts {
615                project_id,
616                contexts,
617            })
618            .ok();
619    }
620
621    fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
622        let Some(project_id) = self.project.read(cx).remote_id() else {
623            return;
624        };
625
626        let contexts = self
627            .contexts
628            .iter()
629            .filter_map(|context| {
630                let context = context.upgrade()?.read(cx);
631                if context.replica_id() != ReplicaId::default() {
632                    Some(context.version(cx).to_proto(context.id().clone()))
633                } else {
634                    None
635                }
636            })
637            .collect();
638
639        let client = self.client.clone();
640        let request = self.client.request(proto::SynchronizeContexts {
641            project_id,
642            contexts,
643        });
644        cx.spawn(|this, cx| async move {
645            let response = request.await?;
646
647            let mut context_ids = Vec::new();
648            let mut operations = Vec::new();
649            this.read_with(&cx, |this, cx| {
650                for context_version_proto in response.contexts {
651                    let context_version = ContextVersion::from_proto(&context_version_proto);
652                    let context_id = ContextId::from_proto(context_version_proto.context_id);
653                    if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
654                        context_ids.push(context_id);
655                        operations.push(context.read(cx).serialize_ops(&context_version, cx));
656                    }
657                }
658            })?;
659
660            let operations = futures::future::join_all(operations).await;
661            for (context_id, operations) in context_ids.into_iter().zip(operations) {
662                for operation in operations {
663                    client.send(proto::UpdateContext {
664                        project_id,
665                        context_id: context_id.to_proto(),
666                        operation: Some(operation),
667                    })?;
668                }
669            }
670
671            anyhow::Ok(())
672        })
673        .detach_and_log_err(cx);
674    }
675
676    pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
677        let metadata = self.contexts_metadata.clone();
678        let executor = cx.background_executor().clone();
679        cx.background_executor().spawn(async move {
680            if query.is_empty() {
681                metadata
682            } else {
683                let candidates = metadata
684                    .iter()
685                    .enumerate()
686                    .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
687                    .collect::<Vec<_>>();
688                let matches = fuzzy::match_strings(
689                    &candidates,
690                    &query,
691                    false,
692                    100,
693                    &Default::default(),
694                    executor,
695                )
696                .await;
697
698                matches
699                    .into_iter()
700                    .map(|mat| metadata[mat.candidate_id].clone())
701                    .collect()
702            }
703        })
704    }
705
706    pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
707        &self.host_contexts
708    }
709
710    fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
711        let fs = self.fs.clone();
712        cx.spawn(|this, mut cx| async move {
713            fs.create_dir(contexts_dir()).await?;
714
715            let mut paths = fs.read_dir(contexts_dir()).await?;
716            let mut contexts = Vec::<SavedContextMetadata>::new();
717            while let Some(path) = paths.next().await {
718                let path = path?;
719                if path.extension() != Some(OsStr::new("json")) {
720                    continue;
721                }
722
723                let pattern = r" - \d+.zed.json$";
724                let re = Regex::new(pattern).unwrap();
725
726                let metadata = fs.metadata(&path).await?;
727                if let Some((file_name, metadata)) = path
728                    .file_name()
729                    .and_then(|name| name.to_str())
730                    .zip(metadata)
731                {
732                    // This is used to filter out contexts saved by the new assistant.
733                    if !re.is_match(file_name) {
734                        continue;
735                    }
736
737                    if let Some(title) = re.replace(file_name, "").lines().next() {
738                        contexts.push(SavedContextMetadata {
739                            title: title.to_string(),
740                            path,
741                            mtime: metadata.mtime.into(),
742                        });
743                    }
744                }
745            }
746            contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
747
748            this.update(&mut cx, |this, cx| {
749                this.contexts_metadata = contexts;
750                cx.notify();
751            })
752        })
753    }
754}