1use crate::{
2 prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion,
3 SavedContext, SavedContextMetadata,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
7use clock::ReplicaId;
8use fs::Fs;
9use futures::StreamExt;
10use fuzzy::StringMatchCandidate;
11use gpui::{
12 AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
13};
14use language::LanguageRegistry;
15use paths::contexts_dir;
16use project::Project;
17use regex::Regex;
18use rpc::AnyProtoClient;
19use std::{
20 cmp::Reverse,
21 ffi::OsStr,
22 mem,
23 path::{Path, PathBuf},
24 sync::Arc,
25 time::Duration,
26};
27use util::{ResultExt, TryFutureExt};
28
29pub fn init(client: &AnyProtoClient) {
30 client.add_model_message_handler(ContextStore::handle_advertise_contexts);
31 client.add_model_request_handler(ContextStore::handle_open_context);
32 client.add_model_request_handler(ContextStore::handle_create_context);
33 client.add_model_message_handler(ContextStore::handle_update_context);
34 client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
35}
36
/// Metadata for a context that the project host advertised via an
/// `AdvertiseContexts` message.
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Identifier of the advertised context.
    pub id: ContextId,
    /// The context's summary text, if the host advertised one.
    pub summary: Option<String>,
}
42
/// Tracks a project's assistant contexts: the ones loaded in memory, the metadata
/// of contexts saved in the contexts directory, and the contexts advertised by the
/// project host.
pub struct ContextStore {
    /// Contexts registered with this store. New registrations are held strongly
    /// while the project is shared and weakly otherwise (see `register_context`).
    contexts: Vec<ContextHandle>,
    /// Metadata of saved contexts found by `reload`, sorted newest first.
    contexts_metadata: Vec<SavedContextMetadata>,
    /// Contexts most recently advertised by the host; cleared on disconnect.
    host_contexts: Vec<RemoteContextMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    telemetry: Arc<Telemetry>,
    /// Background task that calls `reload` on every contexts-directory event.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Model<Project>,
    /// Last observed value of `Project::is_shared`, used to detect transitions.
    project_is_shared: bool,
    /// Subscription to the project's remote entity; set when the project becomes
    /// shared and cleared when it stops being shared.
    client_subscription: Option<client::Subscription>,
    /// Keeps the project observation and event subscription alive.
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
58
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// A context with this id was created. In this file it is emitted by
    /// `handle_create_context`, i.e. when a collaborator requests a new context.
    ContextCreated(ContextId),
}

impl EventEmitter<ContextStoreEvent> for ContextStore {}
64
/// How the store holds on to a context: weakly, so it can be released once
/// nothing else references it, or strongly, which keeps it alive.
enum ContextHandle {
    Weak(WeakModel<Context>),
    Strong(Model<Context>),
}
69
70impl ContextHandle {
71 fn upgrade(&self) -> Option<Model<Context>> {
72 match self {
73 ContextHandle::Weak(weak) => weak.upgrade(),
74 ContextHandle::Strong(strong) => Some(strong.clone()),
75 }
76 }
77
78 fn downgrade(&self) -> WeakModel<Context> {
79 match self {
80 ContextHandle::Weak(weak) => weak.clone(),
81 ContextHandle::Strong(strong) => strong.downgrade(),
82 }
83 }
84}
85
86impl ContextStore {
    /// Asynchronously creates a `ContextStore` for `project`.
    ///
    /// The returned task:
    /// 1. starts watching [`contexts_dir`] for file-system events;
    /// 2. creates the store model, observing and subscribing to `project`, applying
    ///    the project's current shared state via `handle_project_changed`, and
    ///    calling `synchronize_contexts` (a no-op when the project has no remote id);
    /// 3. runs an initial [`Self::reload`] of saved-context metadata, logging
    ///    rather than returning any error from it;
    ///
    /// and then resolves to the new model.
    pub fn new(
        project: Model<Project>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let fs = project.read(cx).fs().clone();
        let languages = project.read(cx).languages().clone();
        let telemetry = project.read(cx).client().telemetry().clone();
        cx.spawn(|mut cx| async move {
            // Latency argument passed to `Fs::watch`.
            // NOTE(review): presumably events are batched over this window — confirm against the Fs impl.
            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
            let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;

            let this = cx.new_model(|cx: &mut ModelContext<Self>| {
                let mut this = Self {
                    contexts: Vec::new(),
                    contexts_metadata: Vec::new(),
                    host_contexts: Vec::new(),
                    fs,
                    languages,
                    telemetry,
                    // Re-scan the contexts directory on every watcher event. The loop
                    // stops when the event stream ends or when updating the store
                    // fails (presumably because the store was released); that error
                    // is logged by the trailing `log_err`.
                    _watch_updates: cx.spawn(|this, mut cx| {
                        async move {
                            while events.next().await.is_some() {
                                this.update(&mut cx, |this, cx| this.reload(cx))?
                                    .await
                                    .log_err();
                            }
                            anyhow::Ok(())
                        }
                        .log_err()
                    }),
                    client_subscription: None,
                    _project_subscriptions: vec![
                        cx.observe(&project, Self::handle_project_changed),
                        cx.subscribe(&project, Self::handle_project_event),
                    ],
                    // Start as "not shared" so the `handle_project_changed` call below
                    // observes a transition when the project is already shared.
                    project_is_shared: false,
                    client: project.read(cx).client(),
                    project: project.clone(),
                    prompt_builder,
                };
                this.handle_project_changed(project, cx);
                this.synchronize_contexts(cx);
                this
            })?;
            // Populate `contexts_metadata` before handing the store to the caller.
            this.update(&mut cx, |this, cx| this.reload(cx))?
                .await
                .log_err();
            Ok(this)
        })
    }
138
139 async fn handle_advertise_contexts(
140 this: Model<Self>,
141 envelope: TypedEnvelope<proto::AdvertiseContexts>,
142 mut cx: AsyncAppContext,
143 ) -> Result<()> {
144 this.update(&mut cx, |this, cx| {
145 this.host_contexts = envelope
146 .payload
147 .contexts
148 .into_iter()
149 .map(|context| RemoteContextMetadata {
150 id: ContextId::from_proto(context.context_id),
151 summary: context.summary,
152 })
153 .collect();
154 cx.notify();
155 })
156 }
157
158 async fn handle_open_context(
159 this: Model<Self>,
160 envelope: TypedEnvelope<proto::OpenContext>,
161 mut cx: AsyncAppContext,
162 ) -> Result<proto::OpenContextResponse> {
163 let context_id = ContextId::from_proto(envelope.payload.context_id);
164 let operations = this.update(&mut cx, |this, cx| {
165 if this.project.read(cx).is_via_collab() {
166 return Err(anyhow!("only the host contexts can be opened"));
167 }
168
169 let context = this
170 .loaded_context_for_id(&context_id, cx)
171 .context("context not found")?;
172 if context.read(cx).replica_id() != ReplicaId::default() {
173 return Err(anyhow!("context must be opened via the host"));
174 }
175
176 anyhow::Ok(
177 context
178 .read(cx)
179 .serialize_ops(&ContextVersion::default(), cx),
180 )
181 })??;
182 let operations = operations.await;
183 Ok(proto::OpenContextResponse {
184 context: Some(proto::Context { operations }),
185 })
186 }
187
188 async fn handle_create_context(
189 this: Model<Self>,
190 _: TypedEnvelope<proto::CreateContext>,
191 mut cx: AsyncAppContext,
192 ) -> Result<proto::CreateContextResponse> {
193 let (context_id, operations) = this.update(&mut cx, |this, cx| {
194 if this.project.read(cx).is_via_collab() {
195 return Err(anyhow!("can only create contexts as the host"));
196 }
197
198 let context = this.create(cx);
199 let context_id = context.read(cx).id().clone();
200 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
201
202 anyhow::Ok((
203 context_id,
204 context
205 .read(cx)
206 .serialize_ops(&ContextVersion::default(), cx),
207 ))
208 })??;
209 let operations = operations.await;
210 Ok(proto::CreateContextResponse {
211 context_id: context_id.to_proto(),
212 context: Some(proto::Context { operations }),
213 })
214 }
215
216 async fn handle_update_context(
217 this: Model<Self>,
218 envelope: TypedEnvelope<proto::UpdateContext>,
219 mut cx: AsyncAppContext,
220 ) -> Result<()> {
221 this.update(&mut cx, |this, cx| {
222 let context_id = ContextId::from_proto(envelope.payload.context_id);
223 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
224 let operation_proto = envelope.payload.operation.context("invalid operation")?;
225 let operation = ContextOperation::from_proto(operation_proto)?;
226 context.update(cx, |context, cx| context.apply_ops([operation], cx))?;
227 }
228 Ok(())
229 })?
230 }
231
232 async fn handle_synchronize_contexts(
233 this: Model<Self>,
234 envelope: TypedEnvelope<proto::SynchronizeContexts>,
235 mut cx: AsyncAppContext,
236 ) -> Result<proto::SynchronizeContextsResponse> {
237 this.update(&mut cx, |this, cx| {
238 if this.project.read(cx).is_via_collab() {
239 return Err(anyhow!("only the host can synchronize contexts"));
240 }
241
242 let mut local_versions = Vec::new();
243 for remote_version_proto in envelope.payload.contexts {
244 let remote_version = ContextVersion::from_proto(&remote_version_proto);
245 let context_id = ContextId::from_proto(remote_version_proto.context_id);
246 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
247 let context = context.read(cx);
248 let operations = context.serialize_ops(&remote_version, cx);
249 local_versions.push(context.version(cx).to_proto(context_id.clone()));
250 let client = this.client.clone();
251 let project_id = envelope.payload.project_id;
252 cx.background_executor()
253 .spawn(async move {
254 let operations = operations.await;
255 for operation in operations {
256 client.send(proto::UpdateContext {
257 project_id,
258 context_id: context_id.to_proto(),
259 operation: Some(operation),
260 })?;
261 }
262 anyhow::Ok(())
263 })
264 .detach_and_log_err(cx);
265 }
266 }
267
268 this.advertise_contexts(cx);
269
270 anyhow::Ok(proto::SynchronizeContextsResponse {
271 contexts: local_versions,
272 })
273 })?
274 }
275
276 fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
277 let is_shared = self.project.read(cx).is_shared();
278 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
279 if is_shared == was_shared {
280 return;
281 }
282
283 if is_shared {
284 self.contexts.retain_mut(|context| {
285 if let Some(strong_context) = context.upgrade() {
286 *context = ContextHandle::Strong(strong_context);
287 true
288 } else {
289 false
290 }
291 });
292 let remote_id = self.project.read(cx).remote_id().unwrap();
293 self.client_subscription = self
294 .client
295 .subscribe_to_entity(remote_id)
296 .log_err()
297 .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
298 self.advertise_contexts(cx);
299 } else {
300 self.client_subscription = None;
301 }
302 }
303
304 fn handle_project_event(
305 &mut self,
306 _: Model<Project>,
307 event: &project::Event,
308 cx: &mut ModelContext<Self>,
309 ) {
310 match event {
311 project::Event::Reshared => {
312 self.advertise_contexts(cx);
313 }
314 project::Event::HostReshared | project::Event::Rejoined => {
315 self.synchronize_contexts(cx);
316 }
317 project::Event::DisconnectedFromHost => {
318 self.contexts.retain_mut(|context| {
319 if let Some(strong_context) = context.upgrade() {
320 *context = ContextHandle::Weak(context.downgrade());
321 strong_context.update(cx, |context, cx| {
322 if context.replica_id() != ReplicaId::default() {
323 context.set_capability(language::Capability::ReadOnly, cx);
324 }
325 });
326 true
327 } else {
328 false
329 }
330 });
331 self.host_contexts.clear();
332 cx.notify();
333 }
334 _ => {}
335 }
336 }
337
338 pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
339 let context = cx.new_model(|cx| {
340 Context::local(
341 self.languages.clone(),
342 Some(self.project.clone()),
343 Some(self.telemetry.clone()),
344 self.prompt_builder.clone(),
345 cx,
346 )
347 });
348 self.register_context(&context, cx);
349 context
350 }
351
352 pub fn create_remote_context(
353 &mut self,
354 cx: &mut ModelContext<Self>,
355 ) -> Task<Result<Model<Context>>> {
356 let project = self.project.read(cx);
357 let Some(project_id) = project.remote_id() else {
358 return Task::ready(Err(anyhow!("project was not remote")));
359 };
360 if project.is_local_or_ssh() {
361 return Task::ready(Err(anyhow!("cannot create remote contexts as the host")));
362 }
363
364 let replica_id = project.replica_id();
365 let capability = project.capability();
366 let language_registry = self.languages.clone();
367 let project = self.project.clone();
368 let telemetry = self.telemetry.clone();
369 let prompt_builder = self.prompt_builder.clone();
370 let request = self.client.request(proto::CreateContext { project_id });
371 cx.spawn(|this, mut cx| async move {
372 let response = request.await?;
373 let context_id = ContextId::from_proto(response.context_id);
374 let context_proto = response.context.context("invalid context")?;
375 let context = cx.new_model(|cx| {
376 Context::new(
377 context_id.clone(),
378 replica_id,
379 capability,
380 language_registry,
381 prompt_builder,
382 Some(project),
383 Some(telemetry),
384 cx,
385 )
386 })?;
387 let operations = cx
388 .background_executor()
389 .spawn(async move {
390 context_proto
391 .operations
392 .into_iter()
393 .map(ContextOperation::from_proto)
394 .collect::<Result<Vec<_>>>()
395 })
396 .await?;
397 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
398 this.update(&mut cx, |this, cx| {
399 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
400 existing_context
401 } else {
402 this.register_context(&context, cx);
403 this.synchronize_contexts(cx);
404 context
405 }
406 })
407 })
408 }
409
410 pub fn open_local_context(
411 &mut self,
412 path: PathBuf,
413 cx: &ModelContext<Self>,
414 ) -> Task<Result<Model<Context>>> {
415 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
416 return Task::ready(Ok(existing_context));
417 }
418
419 let fs = self.fs.clone();
420 let languages = self.languages.clone();
421 let project = self.project.clone();
422 let telemetry = self.telemetry.clone();
423 let load = cx.background_executor().spawn({
424 let path = path.clone();
425 async move {
426 let saved_context = fs.load(&path).await?;
427 SavedContext::from_json(&saved_context)
428 }
429 });
430 let prompt_builder = self.prompt_builder.clone();
431
432 cx.spawn(|this, mut cx| async move {
433 let saved_context = load.await?;
434 let context = cx.new_model(|cx| {
435 Context::deserialize(
436 saved_context,
437 path.clone(),
438 languages,
439 prompt_builder,
440 Some(project),
441 Some(telemetry),
442 cx,
443 )
444 })?;
445 this.update(&mut cx, |this, cx| {
446 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
447 existing_context
448 } else {
449 this.register_context(&context, cx);
450 context
451 }
452 })
453 })
454 }
455
456 fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
457 self.contexts.iter().find_map(|context| {
458 let context = context.upgrade()?;
459 if context.read(cx).path() == Some(path) {
460 Some(context)
461 } else {
462 None
463 }
464 })
465 }
466
467 pub(super) fn loaded_context_for_id(
468 &self,
469 id: &ContextId,
470 cx: &AppContext,
471 ) -> Option<Model<Context>> {
472 self.contexts.iter().find_map(|context| {
473 let context = context.upgrade()?;
474 if context.read(cx).id() == id {
475 Some(context)
476 } else {
477 None
478 }
479 })
480 }
481
482 pub fn open_remote_context(
483 &mut self,
484 context_id: ContextId,
485 cx: &mut ModelContext<Self>,
486 ) -> Task<Result<Model<Context>>> {
487 let project = self.project.read(cx);
488 let Some(project_id) = project.remote_id() else {
489 return Task::ready(Err(anyhow!("project was not remote")));
490 };
491 if project.is_local_or_ssh() {
492 return Task::ready(Err(anyhow!("cannot open remote contexts as the host")));
493 }
494
495 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
496 return Task::ready(Ok(context));
497 }
498
499 let replica_id = project.replica_id();
500 let capability = project.capability();
501 let language_registry = self.languages.clone();
502 let project = self.project.clone();
503 let telemetry = self.telemetry.clone();
504 let request = self.client.request(proto::OpenContext {
505 project_id,
506 context_id: context_id.to_proto(),
507 });
508 let prompt_builder = self.prompt_builder.clone();
509 cx.spawn(|this, mut cx| async move {
510 let response = request.await?;
511 let context_proto = response.context.context("invalid context")?;
512 let context = cx.new_model(|cx| {
513 Context::new(
514 context_id.clone(),
515 replica_id,
516 capability,
517 language_registry,
518 prompt_builder,
519 Some(project),
520 Some(telemetry),
521 cx,
522 )
523 })?;
524 let operations = cx
525 .background_executor()
526 .spawn(async move {
527 context_proto
528 .operations
529 .into_iter()
530 .map(ContextOperation::from_proto)
531 .collect::<Result<Vec<_>>>()
532 })
533 .await?;
534 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
535 this.update(&mut cx, |this, cx| {
536 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
537 existing_context
538 } else {
539 this.register_context(&context, cx);
540 this.synchronize_contexts(cx);
541 context
542 }
543 })
544 })
545 }
546
547 fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
548 let handle = if self.project_is_shared {
549 ContextHandle::Strong(context.clone())
550 } else {
551 ContextHandle::Weak(context.downgrade())
552 };
553 self.contexts.push(handle);
554 self.advertise_contexts(cx);
555 cx.subscribe(context, Self::handle_context_event).detach();
556 }
557
558 fn handle_context_event(
559 &mut self,
560 context: Model<Context>,
561 event: &ContextEvent,
562 cx: &mut ModelContext<Self>,
563 ) {
564 let Some(project_id) = self.project.read(cx).remote_id() else {
565 return;
566 };
567
568 match event {
569 ContextEvent::SummaryChanged => {
570 self.advertise_contexts(cx);
571 }
572 ContextEvent::Operation(operation) => {
573 let context_id = context.read(cx).id().to_proto();
574 let operation = operation.to_proto();
575 self.client
576 .send(proto::UpdateContext {
577 project_id,
578 context_id,
579 operation: Some(operation),
580 })
581 .log_err();
582 }
583 _ => {}
584 }
585 }
586
587 fn advertise_contexts(&self, cx: &AppContext) {
588 let Some(project_id) = self.project.read(cx).remote_id() else {
589 return;
590 };
591
592 // For now, only the host can advertise their open contexts.
593 if self.project.read(cx).is_via_collab() {
594 return;
595 }
596
597 let contexts = self
598 .contexts
599 .iter()
600 .rev()
601 .filter_map(|context| {
602 let context = context.upgrade()?.read(cx);
603 if context.replica_id() == ReplicaId::default() {
604 Some(proto::ContextMetadata {
605 context_id: context.id().to_proto(),
606 summary: context.summary().map(|summary| summary.text.clone()),
607 })
608 } else {
609 None
610 }
611 })
612 .collect();
613 self.client
614 .send(proto::AdvertiseContexts {
615 project_id,
616 contexts,
617 })
618 .ok();
619 }
620
621 fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
622 let Some(project_id) = self.project.read(cx).remote_id() else {
623 return;
624 };
625
626 let contexts = self
627 .contexts
628 .iter()
629 .filter_map(|context| {
630 let context = context.upgrade()?.read(cx);
631 if context.replica_id() != ReplicaId::default() {
632 Some(context.version(cx).to_proto(context.id().clone()))
633 } else {
634 None
635 }
636 })
637 .collect();
638
639 let client = self.client.clone();
640 let request = self.client.request(proto::SynchronizeContexts {
641 project_id,
642 contexts,
643 });
644 cx.spawn(|this, cx| async move {
645 let response = request.await?;
646
647 let mut context_ids = Vec::new();
648 let mut operations = Vec::new();
649 this.read_with(&cx, |this, cx| {
650 for context_version_proto in response.contexts {
651 let context_version = ContextVersion::from_proto(&context_version_proto);
652 let context_id = ContextId::from_proto(context_version_proto.context_id);
653 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
654 context_ids.push(context_id);
655 operations.push(context.read(cx).serialize_ops(&context_version, cx));
656 }
657 }
658 })?;
659
660 let operations = futures::future::join_all(operations).await;
661 for (context_id, operations) in context_ids.into_iter().zip(operations) {
662 for operation in operations {
663 client.send(proto::UpdateContext {
664 project_id,
665 context_id: context_id.to_proto(),
666 operation: Some(operation),
667 })?;
668 }
669 }
670
671 anyhow::Ok(())
672 })
673 .detach_and_log_err(cx);
674 }
675
676 pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
677 let metadata = self.contexts_metadata.clone();
678 let executor = cx.background_executor().clone();
679 cx.background_executor().spawn(async move {
680 if query.is_empty() {
681 metadata
682 } else {
683 let candidates = metadata
684 .iter()
685 .enumerate()
686 .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
687 .collect::<Vec<_>>();
688 let matches = fuzzy::match_strings(
689 &candidates,
690 &query,
691 false,
692 100,
693 &Default::default(),
694 executor,
695 )
696 .await;
697
698 matches
699 .into_iter()
700 .map(|mat| metadata[mat.candidate_id].clone())
701 .collect()
702 }
703 })
704 }
705
706 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
707 &self.host_contexts
708 }
709
710 fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
711 let fs = self.fs.clone();
712 cx.spawn(|this, mut cx| async move {
713 fs.create_dir(contexts_dir()).await?;
714
715 let mut paths = fs.read_dir(contexts_dir()).await?;
716 let mut contexts = Vec::<SavedContextMetadata>::new();
717 while let Some(path) = paths.next().await {
718 let path = path?;
719 if path.extension() != Some(OsStr::new("json")) {
720 continue;
721 }
722
723 let pattern = r" - \d+.zed.json$";
724 let re = Regex::new(pattern).unwrap();
725
726 let metadata = fs.metadata(&path).await?;
727 if let Some((file_name, metadata)) = path
728 .file_name()
729 .and_then(|name| name.to_str())
730 .zip(metadata)
731 {
732 // This is used to filter out contexts saved by the new assistant.
733 if !re.is_match(file_name) {
734 continue;
735 }
736
737 if let Some(title) = re.replace(file_name, "").lines().next() {
738 contexts.push(SavedContextMetadata {
739 title: title.to_string(),
740 path,
741 mtime: metadata.mtime.into(),
742 });
743 }
744 }
745 }
746 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
747
748 this.update(&mut cx, |this, cx| {
749 this.contexts_metadata = contexts;
750 cx.notify();
751 })
752 })
753 }
754}