1use crate::{
2 Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
3 SavedContextMetadata,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
7use clock::ReplicaId;
8use fs::Fs;
9use futures::StreamExt;
10use fuzzy::StringMatchCandidate;
11use gpui::{AppContext, AsyncAppContext, Context as _, Model, ModelContext, Task, WeakModel};
12use language::LanguageRegistry;
13use paths::contexts_dir;
14use project::Project;
15use regex::Regex;
16use std::{
17 cmp::Reverse,
18 ffi::OsStr,
19 mem,
20 path::{Path, PathBuf},
21 sync::Arc,
22 time::Duration,
23};
24use util::{ResultExt, TryFutureExt};
25
26pub fn init(client: &Arc<Client>) {
27 client.add_model_message_handler(ContextStore::handle_advertise_contexts);
28 client.add_model_request_handler(ContextStore::handle_open_context);
29 client.add_model_message_handler(ContextStore::handle_update_context);
30 client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
31}
32
33#[derive(Clone)]
34pub struct RemoteContextMetadata {
35 pub id: ContextId,
36 pub summary: Option<String>,
37}
38
/// Tracks the assistant contexts of a project: contexts loaded in memory, metadata for
/// contexts saved on disk, and contexts advertised by the host when collaborating.
pub struct ContextStore {
    // Contexts registered with this store. Held weakly unless the project is shared
    // (see `register_context` / `handle_project_changed`).
    contexts: Vec<ContextHandle>,
    // Metadata for contexts found in `contexts_dir()`, sorted newest-first by `reload`.
    contexts_metadata: Vec<SavedContextMetadata>,
    // Contexts most recently advertised by the host; cleared when disconnected from it.
    host_contexts: Vec<RemoteContextMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    telemetry: Arc<Telemetry>,
    // Background task that calls `reload` whenever the contexts directory watcher fires.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Model<Project>,
    // Last observed `Project::is_shared` value, used to detect sharing transitions.
    project_is_shared: bool,
    // Subscription to the project's remote entity; only set while the project is shared.
    client_subscription: Option<client::Subscription>,
    // Keeps the project observation and event subscription alive.
    _project_subscriptions: Vec<gpui::Subscription>,
}
53
54enum ContextHandle {
55 Weak(WeakModel<Context>),
56 Strong(Model<Context>),
57}
58
59impl ContextHandle {
60 fn upgrade(&self) -> Option<Model<Context>> {
61 match self {
62 ContextHandle::Weak(weak) => weak.upgrade(),
63 ContextHandle::Strong(strong) => Some(strong.clone()),
64 }
65 }
66
67 fn downgrade(&self) -> WeakModel<Context> {
68 match self {
69 ContextHandle::Weak(weak) => weak.clone(),
70 ContextHandle::Strong(strong) => strong.downgrade(),
71 }
72 }
73}
74
75impl ContextStore {
76 pub fn new(project: Model<Project>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
77 let fs = project.read(cx).fs().clone();
78 let languages = project.read(cx).languages().clone();
79 let telemetry = project.read(cx).client().telemetry().clone();
80 cx.spawn(|mut cx| async move {
81 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
82 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
83
84 let this = cx.new_model(|cx: &mut ModelContext<Self>| {
85 let mut this = Self {
86 contexts: Vec::new(),
87 contexts_metadata: Vec::new(),
88 host_contexts: Vec::new(),
89 fs,
90 languages,
91 telemetry,
92 _watch_updates: cx.spawn(|this, mut cx| {
93 async move {
94 while events.next().await.is_some() {
95 this.update(&mut cx, |this, cx| this.reload(cx))?
96 .await
97 .log_err();
98 }
99 anyhow::Ok(())
100 }
101 .log_err()
102 }),
103 client_subscription: None,
104 _project_subscriptions: vec![
105 cx.observe(&project, Self::handle_project_changed),
106 cx.subscribe(&project, Self::handle_project_event),
107 ],
108 project_is_shared: false,
109 client: project.read(cx).client(),
110 project: project.clone(),
111 };
112 this.handle_project_changed(project, cx);
113 this.synchronize_contexts(cx);
114 this
115 })?;
116 this.update(&mut cx, |this, cx| this.reload(cx))?
117 .await
118 .log_err();
119 Ok(this)
120 })
121 }
122
123 async fn handle_advertise_contexts(
124 this: Model<Self>,
125 envelope: TypedEnvelope<proto::AdvertiseContexts>,
126 mut cx: AsyncAppContext,
127 ) -> Result<()> {
128 this.update(&mut cx, |this, cx| {
129 this.host_contexts = envelope
130 .payload
131 .contexts
132 .into_iter()
133 .map(|context| RemoteContextMetadata {
134 id: ContextId::from_proto(context.context_id),
135 summary: context.summary,
136 })
137 .collect();
138 cx.notify();
139 })
140 }
141
142 async fn handle_open_context(
143 this: Model<Self>,
144 envelope: TypedEnvelope<proto::OpenContext>,
145 mut cx: AsyncAppContext,
146 ) -> Result<proto::OpenContextResponse> {
147 let context_id = ContextId::from_proto(envelope.payload.context_id);
148 let operations = this.update(&mut cx, |this, cx| {
149 if this.project.read(cx).is_remote() {
150 return Err(anyhow!("only the host contexts can be opened"));
151 }
152
153 let context = this
154 .loaded_context_for_id(&context_id, cx)
155 .context("context not found")?;
156 if context.read(cx).replica_id() != ReplicaId::default() {
157 return Err(anyhow!("context must be opened via the host"));
158 }
159
160 anyhow::Ok(
161 context
162 .read(cx)
163 .serialize_ops(&ContextVersion::default(), cx),
164 )
165 })??;
166 let operations = operations.await;
167 Ok(proto::OpenContextResponse {
168 context: Some(proto::Context { operations }),
169 })
170 }
171
172 async fn handle_update_context(
173 this: Model<Self>,
174 envelope: TypedEnvelope<proto::UpdateContext>,
175 mut cx: AsyncAppContext,
176 ) -> Result<()> {
177 this.update(&mut cx, |this, cx| {
178 let context_id = ContextId::from_proto(envelope.payload.context_id);
179 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
180 let operation_proto = envelope.payload.operation.context("invalid operation")?;
181 let operation = ContextOperation::from_proto(operation_proto)?;
182 context.update(cx, |context, cx| context.apply_ops([operation], cx))?;
183 }
184 Ok(())
185 })?
186 }
187
188 async fn handle_synchronize_contexts(
189 this: Model<Self>,
190 envelope: TypedEnvelope<proto::SynchronizeContexts>,
191 mut cx: AsyncAppContext,
192 ) -> Result<proto::SynchronizeContextsResponse> {
193 this.update(&mut cx, |this, cx| {
194 if this.project.read(cx).is_remote() {
195 return Err(anyhow!("only the host can synchronize contexts"));
196 }
197
198 let mut local_versions = Vec::new();
199 for remote_version_proto in envelope.payload.contexts {
200 let remote_version = ContextVersion::from_proto(&remote_version_proto);
201 let context_id = ContextId::from_proto(remote_version_proto.context_id);
202 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
203 let context = context.read(cx);
204 let operations = context.serialize_ops(&remote_version, cx);
205 local_versions.push(context.version(cx).to_proto(context_id.clone()));
206 let client = this.client.clone();
207 let project_id = envelope.payload.project_id;
208 cx.background_executor()
209 .spawn(async move {
210 let operations = operations.await;
211 for operation in operations {
212 client.send(proto::UpdateContext {
213 project_id,
214 context_id: context_id.to_proto(),
215 operation: Some(operation),
216 })?;
217 }
218 anyhow::Ok(())
219 })
220 .detach_and_log_err(cx);
221 }
222 }
223
224 this.advertise_contexts(cx);
225
226 anyhow::Ok(proto::SynchronizeContextsResponse {
227 contexts: local_versions,
228 })
229 })?
230 }
231
232 fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
233 let is_shared = self.project.read(cx).is_shared();
234 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
235 if is_shared == was_shared {
236 return;
237 }
238
239 if is_shared {
240 self.contexts.retain_mut(|context| {
241 if let Some(strong_context) = context.upgrade() {
242 *context = ContextHandle::Strong(strong_context);
243 true
244 } else {
245 false
246 }
247 });
248 let remote_id = self.project.read(cx).remote_id().unwrap();
249 self.client_subscription = self
250 .client
251 .subscribe_to_entity(remote_id)
252 .log_err()
253 .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
254 self.advertise_contexts(cx);
255 } else {
256 self.client_subscription = None;
257 }
258 }
259
260 fn handle_project_event(
261 &mut self,
262 _: Model<Project>,
263 event: &project::Event,
264 cx: &mut ModelContext<Self>,
265 ) {
266 match event {
267 project::Event::Reshared => {
268 self.advertise_contexts(cx);
269 }
270 project::Event::HostReshared | project::Event::Rejoined => {
271 self.synchronize_contexts(cx);
272 }
273 project::Event::DisconnectedFromHost => {
274 self.contexts.retain_mut(|context| {
275 if let Some(strong_context) = context.upgrade() {
276 *context = ContextHandle::Weak(context.downgrade());
277 strong_context.update(cx, |context, cx| {
278 if context.replica_id() != ReplicaId::default() {
279 context.set_capability(language::Capability::ReadOnly, cx);
280 }
281 });
282 true
283 } else {
284 false
285 }
286 });
287 self.host_contexts.clear();
288 cx.notify();
289 }
290 _ => {}
291 }
292 }
293
294 pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
295 let context = cx.new_model(|cx| {
296 Context::local(self.languages.clone(), Some(self.telemetry.clone()), cx)
297 });
298 self.register_context(&context, cx);
299 context
300 }
301
302 pub fn open_local_context(
303 &mut self,
304 path: PathBuf,
305 cx: &ModelContext<Self>,
306 ) -> Task<Result<Model<Context>>> {
307 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
308 return Task::ready(Ok(existing_context));
309 }
310
311 let fs = self.fs.clone();
312 let languages = self.languages.clone();
313 let telemetry = self.telemetry.clone();
314 let load = cx.background_executor().spawn({
315 let path = path.clone();
316 async move {
317 let saved_context = fs.load(&path).await?;
318 SavedContext::from_json(&saved_context)
319 }
320 });
321
322 cx.spawn(|this, mut cx| async move {
323 let saved_context = load.await?;
324 let context = cx.new_model(|cx| {
325 Context::deserialize(saved_context, path.clone(), languages, Some(telemetry), cx)
326 })?;
327 this.update(&mut cx, |this, cx| {
328 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
329 existing_context
330 } else {
331 this.register_context(&context, cx);
332 context
333 }
334 })
335 })
336 }
337
338 fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
339 self.contexts.iter().find_map(|context| {
340 let context = context.upgrade()?;
341 if context.read(cx).path() == Some(path) {
342 Some(context)
343 } else {
344 None
345 }
346 })
347 }
348
349 fn loaded_context_for_id(&self, id: &ContextId, cx: &AppContext) -> Option<Model<Context>> {
350 self.contexts.iter().find_map(|context| {
351 let context = context.upgrade()?;
352 if context.read(cx).id() == id {
353 Some(context)
354 } else {
355 None
356 }
357 })
358 }
359
360 pub fn open_remote_context(
361 &mut self,
362 context_id: ContextId,
363 cx: &mut ModelContext<Self>,
364 ) -> Task<Result<Model<Context>>> {
365 let project = self.project.read(cx);
366 let Some(project_id) = project.remote_id() else {
367 return Task::ready(Err(anyhow!("project was not remote")));
368 };
369 if project.is_local() {
370 return Task::ready(Err(anyhow!("cannot open remote contexts as the host")));
371 }
372
373 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
374 return Task::ready(Ok(context));
375 }
376
377 let replica_id = project.replica_id();
378 let capability = project.capability();
379 let language_registry = self.languages.clone();
380 let telemetry = self.telemetry.clone();
381 let request = self.client.request(proto::OpenContext {
382 project_id,
383 context_id: context_id.to_proto(),
384 });
385 cx.spawn(|this, mut cx| async move {
386 let response = request.await?;
387 let context_proto = response.context.context("invalid context")?;
388 let context = cx.new_model(|cx| {
389 Context::new(
390 context_id.clone(),
391 replica_id,
392 capability,
393 language_registry,
394 Some(telemetry),
395 cx,
396 )
397 })?;
398 let operations = cx
399 .background_executor()
400 .spawn(async move {
401 context_proto
402 .operations
403 .into_iter()
404 .map(|op| ContextOperation::from_proto(op))
405 .collect::<Result<Vec<_>>>()
406 })
407 .await?;
408 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
409 this.update(&mut cx, |this, cx| {
410 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
411 existing_context
412 } else {
413 this.register_context(&context, cx);
414 this.synchronize_contexts(cx);
415 context
416 }
417 })
418 })
419 }
420
421 fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
422 let handle = if self.project_is_shared {
423 ContextHandle::Strong(context.clone())
424 } else {
425 ContextHandle::Weak(context.downgrade())
426 };
427 self.contexts.push(handle);
428 self.advertise_contexts(cx);
429 cx.subscribe(context, Self::handle_context_event).detach();
430 }
431
432 fn handle_context_event(
433 &mut self,
434 context: Model<Context>,
435 event: &ContextEvent,
436 cx: &mut ModelContext<Self>,
437 ) {
438 let Some(project_id) = self.project.read(cx).remote_id() else {
439 return;
440 };
441
442 match event {
443 ContextEvent::SummaryChanged => {
444 self.advertise_contexts(cx);
445 }
446 ContextEvent::Operation(operation) => {
447 let context_id = context.read(cx).id().to_proto();
448 let operation = operation.to_proto();
449 self.client
450 .send(proto::UpdateContext {
451 project_id,
452 context_id,
453 operation: Some(operation),
454 })
455 .log_err();
456 }
457 _ => {}
458 }
459 }
460
461 fn advertise_contexts(&self, cx: &AppContext) {
462 let Some(project_id) = self.project.read(cx).remote_id() else {
463 return;
464 };
465
466 // For now, only the host can advertise their open contexts.
467 if self.project.read(cx).is_remote() {
468 return;
469 }
470
471 let contexts = self
472 .contexts
473 .iter()
474 .rev()
475 .filter_map(|context| {
476 let context = context.upgrade()?.read(cx);
477 if context.replica_id() == ReplicaId::default() {
478 Some(proto::ContextMetadata {
479 context_id: context.id().to_proto(),
480 summary: context.summary().map(|summary| summary.text.clone()),
481 })
482 } else {
483 None
484 }
485 })
486 .collect();
487 self.client
488 .send(proto::AdvertiseContexts {
489 project_id,
490 contexts,
491 })
492 .ok();
493 }
494
495 fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
496 let Some(project_id) = self.project.read(cx).remote_id() else {
497 return;
498 };
499
500 let contexts = self
501 .contexts
502 .iter()
503 .filter_map(|context| {
504 let context = context.upgrade()?.read(cx);
505 if context.replica_id() != ReplicaId::default() {
506 Some(context.version(cx).to_proto(context.id().clone()))
507 } else {
508 None
509 }
510 })
511 .collect();
512
513 let client = self.client.clone();
514 let request = self.client.request(proto::SynchronizeContexts {
515 project_id,
516 contexts,
517 });
518 cx.spawn(|this, cx| async move {
519 let response = request.await?;
520
521 let mut context_ids = Vec::new();
522 let mut operations = Vec::new();
523 this.read_with(&cx, |this, cx| {
524 for context_version_proto in response.contexts {
525 let context_version = ContextVersion::from_proto(&context_version_proto);
526 let context_id = ContextId::from_proto(context_version_proto.context_id);
527 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
528 context_ids.push(context_id);
529 operations.push(context.read(cx).serialize_ops(&context_version, cx));
530 }
531 }
532 })?;
533
534 let operations = futures::future::join_all(operations).await;
535 for (context_id, operations) in context_ids.into_iter().zip(operations) {
536 for operation in operations {
537 client.send(proto::UpdateContext {
538 project_id,
539 context_id: context_id.to_proto(),
540 operation: Some(operation),
541 })?;
542 }
543 }
544
545 anyhow::Ok(())
546 })
547 .detach_and_log_err(cx);
548 }
549
550 pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
551 let metadata = self.contexts_metadata.clone();
552 let executor = cx.background_executor().clone();
553 cx.background_executor().spawn(async move {
554 if query.is_empty() {
555 metadata
556 } else {
557 let candidates = metadata
558 .iter()
559 .enumerate()
560 .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
561 .collect::<Vec<_>>();
562 let matches = fuzzy::match_strings(
563 &candidates,
564 &query,
565 false,
566 100,
567 &Default::default(),
568 executor,
569 )
570 .await;
571
572 matches
573 .into_iter()
574 .map(|mat| metadata[mat.candidate_id].clone())
575 .collect()
576 }
577 })
578 }
579
580 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
581 &self.host_contexts
582 }
583
584 fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
585 let fs = self.fs.clone();
586 cx.spawn(|this, mut cx| async move {
587 fs.create_dir(contexts_dir()).await?;
588
589 let mut paths = fs.read_dir(contexts_dir()).await?;
590 let mut contexts = Vec::<SavedContextMetadata>::new();
591 while let Some(path) = paths.next().await {
592 let path = path?;
593 if path.extension() != Some(OsStr::new("json")) {
594 continue;
595 }
596
597 let pattern = r" - \d+.zed.json$";
598 let re = Regex::new(pattern).unwrap();
599
600 let metadata = fs.metadata(&path).await?;
601 if let Some((file_name, metadata)) = path
602 .file_name()
603 .and_then(|name| name.to_str())
604 .zip(metadata)
605 {
606 // This is used to filter out contexts saved by the new assistant.
607 if !re.is_match(file_name) {
608 continue;
609 }
610
611 if let Some(title) = re.replace(file_name, "").lines().next() {
612 contexts.push(SavedContextMetadata {
613 title: title.to_string(),
614 path,
615 mtime: metadata.mtime.into(),
616 });
617 }
618 }
619 }
620 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
621
622 this.update(&mut cx, |this, cx| {
623 this.contexts_metadata = contexts;
624 cx.notify();
625 })
626 })
627 }
628}