1use crate::{
2 prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion,
3 SavedContext, SavedContextMetadata,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
7use clock::ReplicaId;
8use fs::Fs;
9use futures::StreamExt;
10use fuzzy::StringMatchCandidate;
11use gpui::{
12 AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
13};
14use language::LanguageRegistry;
15use paths::contexts_dir;
16use project::Project;
17use regex::Regex;
18use std::{
19 cmp::Reverse,
20 ffi::OsStr,
21 mem,
22 path::{Path, PathBuf},
23 sync::Arc,
24 time::Duration,
25};
26use util::{ResultExt, TryFutureExt};
27
28pub fn init(client: &Arc<Client>) {
29 client.add_model_message_handler(ContextStore::handle_advertise_contexts);
30 client.add_model_request_handler(ContextStore::handle_open_context);
31 client.add_model_request_handler(ContextStore::handle_create_context);
32 client.add_model_message_handler(ContextStore::handle_update_context);
33 client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
34}
35
36#[derive(Clone)]
37pub struct RemoteContextMetadata {
38 pub id: ContextId,
39 pub summary: Option<String>,
40}
41
42pub struct ContextStore {
43 contexts: Vec<ContextHandle>,
44 contexts_metadata: Vec<SavedContextMetadata>,
45 host_contexts: Vec<RemoteContextMetadata>,
46 fs: Arc<dyn Fs>,
47 languages: Arc<LanguageRegistry>,
48 telemetry: Arc<Telemetry>,
49 _watch_updates: Task<Option<()>>,
50 client: Arc<Client>,
51 project: Model<Project>,
52 project_is_shared: bool,
53 client_subscription: Option<client::Subscription>,
54 _project_subscriptions: Vec<gpui::Subscription>,
55 prompt_builder: Arc<PromptBuilder>,
56}
57
58pub enum ContextStoreEvent {
59 ContextCreated(ContextId),
60}
61
62impl EventEmitter<ContextStoreEvent> for ContextStore {}
63
64enum ContextHandle {
65 Weak(WeakModel<Context>),
66 Strong(Model<Context>),
67}
68
69impl ContextHandle {
70 fn upgrade(&self) -> Option<Model<Context>> {
71 match self {
72 ContextHandle::Weak(weak) => weak.upgrade(),
73 ContextHandle::Strong(strong) => Some(strong.clone()),
74 }
75 }
76
77 fn downgrade(&self) -> WeakModel<Context> {
78 match self {
79 ContextHandle::Weak(weak) => weak.clone(),
80 ContextHandle::Strong(strong) => strong.downgrade(),
81 }
82 }
83}
84
85impl ContextStore {
86 pub fn new(
87 project: Model<Project>,
88 prompt_builder: Arc<PromptBuilder>,
89 cx: &mut AppContext,
90 ) -> Task<Result<Model<Self>>> {
91 let fs = project.read(cx).fs().clone();
92 let languages = project.read(cx).languages().clone();
93 let telemetry = project.read(cx).client().telemetry().clone();
94 cx.spawn(|mut cx| async move {
95 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
96 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
97
98 let this = cx.new_model(|cx: &mut ModelContext<Self>| {
99 let mut this = Self {
100 contexts: Vec::new(),
101 contexts_metadata: Vec::new(),
102 host_contexts: Vec::new(),
103 fs,
104 languages,
105 telemetry,
106 _watch_updates: cx.spawn(|this, mut cx| {
107 async move {
108 while events.next().await.is_some() {
109 this.update(&mut cx, |this, cx| this.reload(cx))?
110 .await
111 .log_err();
112 }
113 anyhow::Ok(())
114 }
115 .log_err()
116 }),
117 client_subscription: None,
118 _project_subscriptions: vec![
119 cx.observe(&project, Self::handle_project_changed),
120 cx.subscribe(&project, Self::handle_project_event),
121 ],
122 project_is_shared: false,
123 client: project.read(cx).client(),
124 project: project.clone(),
125 prompt_builder,
126 };
127 this.handle_project_changed(project, cx);
128 this.synchronize_contexts(cx);
129 this
130 })?;
131 this.update(&mut cx, |this, cx| this.reload(cx))?
132 .await
133 .log_err();
134 Ok(this)
135 })
136 }
137
138 async fn handle_advertise_contexts(
139 this: Model<Self>,
140 envelope: TypedEnvelope<proto::AdvertiseContexts>,
141 mut cx: AsyncAppContext,
142 ) -> Result<()> {
143 this.update(&mut cx, |this, cx| {
144 this.host_contexts = envelope
145 .payload
146 .contexts
147 .into_iter()
148 .map(|context| RemoteContextMetadata {
149 id: ContextId::from_proto(context.context_id),
150 summary: context.summary,
151 })
152 .collect();
153 cx.notify();
154 })
155 }
156
157 async fn handle_open_context(
158 this: Model<Self>,
159 envelope: TypedEnvelope<proto::OpenContext>,
160 mut cx: AsyncAppContext,
161 ) -> Result<proto::OpenContextResponse> {
162 let context_id = ContextId::from_proto(envelope.payload.context_id);
163 let operations = this.update(&mut cx, |this, cx| {
164 if this.project.read(cx).is_via_collab() {
165 return Err(anyhow!("only the host contexts can be opened"));
166 }
167
168 let context = this
169 .loaded_context_for_id(&context_id, cx)
170 .context("context not found")?;
171 if context.read(cx).replica_id() != ReplicaId::default() {
172 return Err(anyhow!("context must be opened via the host"));
173 }
174
175 anyhow::Ok(
176 context
177 .read(cx)
178 .serialize_ops(&ContextVersion::default(), cx),
179 )
180 })??;
181 let operations = operations.await;
182 Ok(proto::OpenContextResponse {
183 context: Some(proto::Context { operations }),
184 })
185 }
186
187 async fn handle_create_context(
188 this: Model<Self>,
189 _: TypedEnvelope<proto::CreateContext>,
190 mut cx: AsyncAppContext,
191 ) -> Result<proto::CreateContextResponse> {
192 let (context_id, operations) = this.update(&mut cx, |this, cx| {
193 if this.project.read(cx).is_via_collab() {
194 return Err(anyhow!("can only create contexts as the host"));
195 }
196
197 let context = this.create(cx);
198 let context_id = context.read(cx).id().clone();
199 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
200
201 anyhow::Ok((
202 context_id,
203 context
204 .read(cx)
205 .serialize_ops(&ContextVersion::default(), cx),
206 ))
207 })??;
208 let operations = operations.await;
209 Ok(proto::CreateContextResponse {
210 context_id: context_id.to_proto(),
211 context: Some(proto::Context { operations }),
212 })
213 }
214
215 async fn handle_update_context(
216 this: Model<Self>,
217 envelope: TypedEnvelope<proto::UpdateContext>,
218 mut cx: AsyncAppContext,
219 ) -> Result<()> {
220 this.update(&mut cx, |this, cx| {
221 let context_id = ContextId::from_proto(envelope.payload.context_id);
222 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
223 let operation_proto = envelope.payload.operation.context("invalid operation")?;
224 let operation = ContextOperation::from_proto(operation_proto)?;
225 context.update(cx, |context, cx| context.apply_ops([operation], cx))?;
226 }
227 Ok(())
228 })?
229 }
230
231 async fn handle_synchronize_contexts(
232 this: Model<Self>,
233 envelope: TypedEnvelope<proto::SynchronizeContexts>,
234 mut cx: AsyncAppContext,
235 ) -> Result<proto::SynchronizeContextsResponse> {
236 this.update(&mut cx, |this, cx| {
237 if this.project.read(cx).is_via_collab() {
238 return Err(anyhow!("only the host can synchronize contexts"));
239 }
240
241 let mut local_versions = Vec::new();
242 for remote_version_proto in envelope.payload.contexts {
243 let remote_version = ContextVersion::from_proto(&remote_version_proto);
244 let context_id = ContextId::from_proto(remote_version_proto.context_id);
245 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
246 let context = context.read(cx);
247 let operations = context.serialize_ops(&remote_version, cx);
248 local_versions.push(context.version(cx).to_proto(context_id.clone()));
249 let client = this.client.clone();
250 let project_id = envelope.payload.project_id;
251 cx.background_executor()
252 .spawn(async move {
253 let operations = operations.await;
254 for operation in operations {
255 client.send(proto::UpdateContext {
256 project_id,
257 context_id: context_id.to_proto(),
258 operation: Some(operation),
259 })?;
260 }
261 anyhow::Ok(())
262 })
263 .detach_and_log_err(cx);
264 }
265 }
266
267 this.advertise_contexts(cx);
268
269 anyhow::Ok(proto::SynchronizeContextsResponse {
270 contexts: local_versions,
271 })
272 })?
273 }
274
275 fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
276 let is_shared = self.project.read(cx).is_shared();
277 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
278 if is_shared == was_shared {
279 return;
280 }
281
282 if is_shared {
283 self.contexts.retain_mut(|context| {
284 if let Some(strong_context) = context.upgrade() {
285 *context = ContextHandle::Strong(strong_context);
286 true
287 } else {
288 false
289 }
290 });
291 let remote_id = self.project.read(cx).remote_id().unwrap();
292 self.client_subscription = self
293 .client
294 .subscribe_to_entity(remote_id)
295 .log_err()
296 .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
297 self.advertise_contexts(cx);
298 } else {
299 self.client_subscription = None;
300 }
301 }
302
303 fn handle_project_event(
304 &mut self,
305 _: Model<Project>,
306 event: &project::Event,
307 cx: &mut ModelContext<Self>,
308 ) {
309 match event {
310 project::Event::Reshared => {
311 self.advertise_contexts(cx);
312 }
313 project::Event::HostReshared | project::Event::Rejoined => {
314 self.synchronize_contexts(cx);
315 }
316 project::Event::DisconnectedFromHost => {
317 self.contexts.retain_mut(|context| {
318 if let Some(strong_context) = context.upgrade() {
319 *context = ContextHandle::Weak(context.downgrade());
320 strong_context.update(cx, |context, cx| {
321 if context.replica_id() != ReplicaId::default() {
322 context.set_capability(language::Capability::ReadOnly, cx);
323 }
324 });
325 true
326 } else {
327 false
328 }
329 });
330 self.host_contexts.clear();
331 cx.notify();
332 }
333 _ => {}
334 }
335 }
336
337 pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
338 let context = cx.new_model(|cx| {
339 Context::local(
340 self.languages.clone(),
341 Some(self.project.clone()),
342 Some(self.telemetry.clone()),
343 self.prompt_builder.clone(),
344 cx,
345 )
346 });
347 self.register_context(&context, cx);
348 context
349 }
350
351 pub fn create_remote_context(
352 &mut self,
353 cx: &mut ModelContext<Self>,
354 ) -> Task<Result<Model<Context>>> {
355 let project = self.project.read(cx);
356 let Some(project_id) = project.remote_id() else {
357 return Task::ready(Err(anyhow!("project was not remote")));
358 };
359 if project.is_local_or_ssh() {
360 return Task::ready(Err(anyhow!("cannot create remote contexts as the host")));
361 }
362
363 let replica_id = project.replica_id();
364 let capability = project.capability();
365 let language_registry = self.languages.clone();
366 let project = self.project.clone();
367 let telemetry = self.telemetry.clone();
368 let prompt_builder = self.prompt_builder.clone();
369 let request = self.client.request(proto::CreateContext { project_id });
370 cx.spawn(|this, mut cx| async move {
371 let response = request.await?;
372 let context_id = ContextId::from_proto(response.context_id);
373 let context_proto = response.context.context("invalid context")?;
374 let context = cx.new_model(|cx| {
375 Context::new(
376 context_id.clone(),
377 replica_id,
378 capability,
379 language_registry,
380 prompt_builder,
381 Some(project),
382 Some(telemetry),
383 cx,
384 )
385 })?;
386 let operations = cx
387 .background_executor()
388 .spawn(async move {
389 context_proto
390 .operations
391 .into_iter()
392 .map(|op| ContextOperation::from_proto(op))
393 .collect::<Result<Vec<_>>>()
394 })
395 .await?;
396 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
397 this.update(&mut cx, |this, cx| {
398 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
399 existing_context
400 } else {
401 this.register_context(&context, cx);
402 this.synchronize_contexts(cx);
403 context
404 }
405 })
406 })
407 }
408
409 pub fn open_local_context(
410 &mut self,
411 path: PathBuf,
412 cx: &ModelContext<Self>,
413 ) -> Task<Result<Model<Context>>> {
414 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
415 return Task::ready(Ok(existing_context));
416 }
417
418 let fs = self.fs.clone();
419 let languages = self.languages.clone();
420 let project = self.project.clone();
421 let telemetry = self.telemetry.clone();
422 let load = cx.background_executor().spawn({
423 let path = path.clone();
424 async move {
425 let saved_context = fs.load(&path).await?;
426 SavedContext::from_json(&saved_context)
427 }
428 });
429 let prompt_builder = self.prompt_builder.clone();
430
431 cx.spawn(|this, mut cx| async move {
432 let saved_context = load.await?;
433 let context = cx.new_model(|cx| {
434 Context::deserialize(
435 saved_context,
436 path.clone(),
437 languages,
438 prompt_builder,
439 Some(project),
440 Some(telemetry),
441 cx,
442 )
443 })?;
444 this.update(&mut cx, |this, cx| {
445 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
446 existing_context
447 } else {
448 this.register_context(&context, cx);
449 context
450 }
451 })
452 })
453 }
454
455 fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
456 self.contexts.iter().find_map(|context| {
457 let context = context.upgrade()?;
458 if context.read(cx).path() == Some(path) {
459 Some(context)
460 } else {
461 None
462 }
463 })
464 }
465
466 pub(super) fn loaded_context_for_id(
467 &self,
468 id: &ContextId,
469 cx: &AppContext,
470 ) -> Option<Model<Context>> {
471 self.contexts.iter().find_map(|context| {
472 let context = context.upgrade()?;
473 if context.read(cx).id() == id {
474 Some(context)
475 } else {
476 None
477 }
478 })
479 }
480
481 pub fn open_remote_context(
482 &mut self,
483 context_id: ContextId,
484 cx: &mut ModelContext<Self>,
485 ) -> Task<Result<Model<Context>>> {
486 let project = self.project.read(cx);
487 let Some(project_id) = project.remote_id() else {
488 return Task::ready(Err(anyhow!("project was not remote")));
489 };
490 if project.is_local_or_ssh() {
491 return Task::ready(Err(anyhow!("cannot open remote contexts as the host")));
492 }
493
494 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
495 return Task::ready(Ok(context));
496 }
497
498 let replica_id = project.replica_id();
499 let capability = project.capability();
500 let language_registry = self.languages.clone();
501 let project = self.project.clone();
502 let telemetry = self.telemetry.clone();
503 let request = self.client.request(proto::OpenContext {
504 project_id,
505 context_id: context_id.to_proto(),
506 });
507 let prompt_builder = self.prompt_builder.clone();
508 cx.spawn(|this, mut cx| async move {
509 let response = request.await?;
510 let context_proto = response.context.context("invalid context")?;
511 let context = cx.new_model(|cx| {
512 Context::new(
513 context_id.clone(),
514 replica_id,
515 capability,
516 language_registry,
517 prompt_builder,
518 Some(project),
519 Some(telemetry),
520 cx,
521 )
522 })?;
523 let operations = cx
524 .background_executor()
525 .spawn(async move {
526 context_proto
527 .operations
528 .into_iter()
529 .map(|op| ContextOperation::from_proto(op))
530 .collect::<Result<Vec<_>>>()
531 })
532 .await?;
533 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
534 this.update(&mut cx, |this, cx| {
535 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
536 existing_context
537 } else {
538 this.register_context(&context, cx);
539 this.synchronize_contexts(cx);
540 context
541 }
542 })
543 })
544 }
545
546 fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
547 let handle = if self.project_is_shared {
548 ContextHandle::Strong(context.clone())
549 } else {
550 ContextHandle::Weak(context.downgrade())
551 };
552 self.contexts.push(handle);
553 self.advertise_contexts(cx);
554 cx.subscribe(context, Self::handle_context_event).detach();
555 }
556
557 fn handle_context_event(
558 &mut self,
559 context: Model<Context>,
560 event: &ContextEvent,
561 cx: &mut ModelContext<Self>,
562 ) {
563 let Some(project_id) = self.project.read(cx).remote_id() else {
564 return;
565 };
566
567 match event {
568 ContextEvent::SummaryChanged => {
569 self.advertise_contexts(cx);
570 }
571 ContextEvent::Operation(operation) => {
572 let context_id = context.read(cx).id().to_proto();
573 let operation = operation.to_proto();
574 self.client
575 .send(proto::UpdateContext {
576 project_id,
577 context_id,
578 operation: Some(operation),
579 })
580 .log_err();
581 }
582 _ => {}
583 }
584 }
585
586 fn advertise_contexts(&self, cx: &AppContext) {
587 let Some(project_id) = self.project.read(cx).remote_id() else {
588 return;
589 };
590
591 // For now, only the host can advertise their open contexts.
592 if self.project.read(cx).is_via_collab() {
593 return;
594 }
595
596 let contexts = self
597 .contexts
598 .iter()
599 .rev()
600 .filter_map(|context| {
601 let context = context.upgrade()?.read(cx);
602 if context.replica_id() == ReplicaId::default() {
603 Some(proto::ContextMetadata {
604 context_id: context.id().to_proto(),
605 summary: context.summary().map(|summary| summary.text.clone()),
606 })
607 } else {
608 None
609 }
610 })
611 .collect();
612 self.client
613 .send(proto::AdvertiseContexts {
614 project_id,
615 contexts,
616 })
617 .ok();
618 }
619
620 fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
621 let Some(project_id) = self.project.read(cx).remote_id() else {
622 return;
623 };
624
625 let contexts = self
626 .contexts
627 .iter()
628 .filter_map(|context| {
629 let context = context.upgrade()?.read(cx);
630 if context.replica_id() != ReplicaId::default() {
631 Some(context.version(cx).to_proto(context.id().clone()))
632 } else {
633 None
634 }
635 })
636 .collect();
637
638 let client = self.client.clone();
639 let request = self.client.request(proto::SynchronizeContexts {
640 project_id,
641 contexts,
642 });
643 cx.spawn(|this, cx| async move {
644 let response = request.await?;
645
646 let mut context_ids = Vec::new();
647 let mut operations = Vec::new();
648 this.read_with(&cx, |this, cx| {
649 for context_version_proto in response.contexts {
650 let context_version = ContextVersion::from_proto(&context_version_proto);
651 let context_id = ContextId::from_proto(context_version_proto.context_id);
652 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
653 context_ids.push(context_id);
654 operations.push(context.read(cx).serialize_ops(&context_version, cx));
655 }
656 }
657 })?;
658
659 let operations = futures::future::join_all(operations).await;
660 for (context_id, operations) in context_ids.into_iter().zip(operations) {
661 for operation in operations {
662 client.send(proto::UpdateContext {
663 project_id,
664 context_id: context_id.to_proto(),
665 operation: Some(operation),
666 })?;
667 }
668 }
669
670 anyhow::Ok(())
671 })
672 .detach_and_log_err(cx);
673 }
674
675 pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
676 let metadata = self.contexts_metadata.clone();
677 let executor = cx.background_executor().clone();
678 cx.background_executor().spawn(async move {
679 if query.is_empty() {
680 metadata
681 } else {
682 let candidates = metadata
683 .iter()
684 .enumerate()
685 .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
686 .collect::<Vec<_>>();
687 let matches = fuzzy::match_strings(
688 &candidates,
689 &query,
690 false,
691 100,
692 &Default::default(),
693 executor,
694 )
695 .await;
696
697 matches
698 .into_iter()
699 .map(|mat| metadata[mat.candidate_id].clone())
700 .collect()
701 }
702 })
703 }
704
705 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
706 &self.host_contexts
707 }
708
709 fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
710 let fs = self.fs.clone();
711 cx.spawn(|this, mut cx| async move {
712 fs.create_dir(contexts_dir()).await?;
713
714 let mut paths = fs.read_dir(contexts_dir()).await?;
715 let mut contexts = Vec::<SavedContextMetadata>::new();
716 while let Some(path) = paths.next().await {
717 let path = path?;
718 if path.extension() != Some(OsStr::new("json")) {
719 continue;
720 }
721
722 let pattern = r" - \d+.zed.json$";
723 let re = Regex::new(pattern).unwrap();
724
725 let metadata = fs.metadata(&path).await?;
726 if let Some((file_name, metadata)) = path
727 .file_name()
728 .and_then(|name| name.to_str())
729 .zip(metadata)
730 {
731 // This is used to filter out contexts saved by the new assistant.
732 if !re.is_match(file_name) {
733 continue;
734 }
735
736 if let Some(title) = re.replace(file_name, "").lines().next() {
737 contexts.push(SavedContextMetadata {
738 title: title.to_string(),
739 path,
740 mtime: metadata.mtime.into(),
741 });
742 }
743 }
744 }
745 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
746
747 this.update(&mut cx, |this, cx| {
748 this.contexts_metadata = contexts;
749 cx.notify();
750 })
751 })
752 }
753}