1use crate::{
2 SavedTextThread, SavedTextThreadMetadata, TextThread, TextThreadEvent, TextThreadId,
3 TextThreadOperation, TextThreadVersion,
4};
5use anyhow::{Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use client::{Client, TypedEnvelope, proto};
8use clock::ReplicaId;
9use collections::HashMap;
10use context_server::ContextServerId;
11use fs::{Fs, RemoveOptions};
12use futures::StreamExt;
13use fuzzy::StringMatchCandidate;
14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
15use itertools::Itertools;
16use language::LanguageRegistry;
17use paths::text_threads_dir;
18use project::{
19 Project,
20 context_server_store::{ContextServerStatus, ContextServerStore},
21};
22use prompt_store::PromptBuilder;
23use regex::Regex;
24use rpc::AnyProtoClient;
25use std::sync::LazyLock;
26use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
27use util::{ResultExt, TryFutureExt};
28use zed_env_vars::ZED_STATELESS;
29
/// Registers the RPC handlers through which collaborators exchange text threads
/// (called "contexts" in the protocol) with a [`TextThreadStore`].
///
/// Message handlers (`AdvertiseContexts`, `UpdateContext`) produce no reply; request
/// handlers (`OpenContext`, `CreateContext`, `SynchronizeContexts`) return a response
/// proto to the requester.
pub(crate) fn init(client: &AnyProtoClient) {
    client.add_entity_message_handler(TextThreadStore::handle_advertise_contexts);
    client.add_entity_request_handler(TextThreadStore::handle_open_context);
    client.add_entity_request_handler(TextThreadStore::handle_create_context);
    client.add_entity_message_handler(TextThreadStore::handle_update_context);
    client.add_entity_request_handler(TextThreadStore::handle_synchronize_contexts);
}
37
/// Metadata about a text thread that the project host advertised to this replica
/// through an `AdvertiseContexts` message.
#[derive(Clone)]
pub struct RemoteTextThreadMetadata {
    /// Identifier of the thread on the host.
    pub id: TextThreadId,
    /// The thread's summary text, when the host has one.
    pub summary: Option<String>,
}
43
/// Tracks a project's text threads. It holds metadata for threads saved on disk, the
/// threads currently loaded in memory, and the threads a collaborating host has
/// advertised.
///
/// It also registers slash commands for prompts exposed by running context servers.
pub struct TextThreadStore {
    /// Loaded threads. Held strongly while the project is shared and weakly otherwise
    /// (see `register_text_thread` and `handle_project_shared`).
    text_threads: Vec<TextThreadHandle>,
    /// Saved threads discovered in `text_threads_dir()` by `reload`.
    text_threads_metadata: Vec<SavedTextThreadMetadata>,
    /// Slash command ids registered for each context server, kept so they can be
    /// removed when that server stops or errors.
    context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
    /// Threads advertised by the host via `AdvertiseContexts`; cleared on disconnect.
    host_text_threads: Vec<RemoteTextThreadMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    /// Background task that calls `reload` whenever the threads directory changes.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: WeakEntity<Project>,
    /// Sharing state observed the last time `handle_project_shared` ran.
    project_is_shared: bool,
    /// RPC subscription for messages addressed to the shared project; `None` when the
    /// project is not shared.
    client_subscription: Option<client::Subscription>,
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
60
/// A reference to a loaded text thread that either keeps it alive (`Strong`) or
/// does not (`Weak`).
enum TextThreadHandle {
    /// Non-owning; the thread may be released by its other owners.
    Weak(WeakEntity<TextThread>),
    /// Owning; used for every loaded thread while the project is shared.
    Strong(Entity<TextThread>),
}
65
66impl TextThreadHandle {
67 fn upgrade(&self) -> Option<Entity<TextThread>> {
68 match self {
69 TextThreadHandle::Weak(weak) => weak.upgrade(),
70 TextThreadHandle::Strong(strong) => Some(strong.clone()),
71 }
72 }
73
74 fn downgrade(&self) -> WeakEntity<TextThread> {
75 match self {
76 TextThreadHandle::Weak(weak) => weak.clone(),
77 TextThreadHandle::Strong(strong) => strong.downgrade(),
78 }
79 }
80}
81
82impl TextThreadStore {
    /// Creates a store for `project`'s saved and loaded text threads.
    ///
    /// The returned task resolves once a watcher on `text_threads_dir()` is set up and
    /// the store entity exists. The initial metadata scan (`reload`) is started inside
    /// the constructor but is not awaited.
    pub fn new(
        project: Entity<Project>,
        prompt_builder: Arc<PromptBuilder>,
        slash_commands: Arc<SlashCommandWorkingSet>,
        cx: &mut App,
    ) -> Task<Result<Entity<Self>>> {
        let fs = project.read(cx).fs().clone();
        let languages = project.read(cx).languages().clone();
        cx.spawn(async move |cx| {
            // Latency argument for `fs.watch`; presumably a debounce window for batching
            // file-system events (TODO confirm against the `Fs::watch` docs).
            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
            // NOTE(review): the second tuple element is dropped immediately; confirm the
            // event stream stays alive without it.
            let (mut events, _) = fs.watch(text_threads_dir(), CONTEXT_WATCH_DURATION).await;

            let this = cx.new(|cx: &mut Context<Self>| {
                let mut this = Self {
                    text_threads: Vec::new(),
                    text_threads_metadata: Vec::new(),
                    context_server_slash_command_ids: HashMap::default(),
                    host_text_threads: Vec::new(),
                    fs,
                    languages,
                    slash_commands,
                    // Every batch of watcher events triggers a full `reload`. The loop
                    // stops (and the error is logged) once `this.update` fails,
                    // presumably when the store has been released.
                    _watch_updates: cx.spawn(async move |this, cx| {
                        async move {
                            while events.next().await.is_some() {
                                this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
                            }
                            anyhow::Ok(())
                        }
                        .log_err()
                        .await
                    }),
                    client_subscription: None,
                    _project_subscriptions: vec![
                        cx.subscribe(&project, Self::handle_project_event),
                    ],
                    project_is_shared: false,
                    client: project.read(cx).client(),
                    project: project.downgrade(),
                    prompt_builder,
                };
                // Pick up the current sharing state; the project may already be shared.
                this.handle_project_shared(cx);
                // Send a `SynchronizeContexts` request if the project has a remote id.
                this.synchronize_contexts(cx);
                this.register_context_server_handlers(cx);
                // Initial scan of saved threads; errors are only logged.
                this.reload(cx).detach_and_log_err(cx);
                this
            });

            Ok(this)
        })
    }
133
134 #[cfg(any(test, feature = "test-support"))]
135 pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
136 Self {
137 text_threads: Default::default(),
138 text_threads_metadata: Default::default(),
139 context_server_slash_command_ids: Default::default(),
140 host_text_threads: Default::default(),
141 fs: project.read(cx).fs().clone(),
142 languages: project.read(cx).languages().clone(),
143 slash_commands: Arc::default(),
144 _watch_updates: Task::ready(None),
145 client: project.read(cx).client(),
146 project: project.downgrade(),
147 project_is_shared: false,
148 client_subscription: None,
149 _project_subscriptions: Default::default(),
150 prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
151 }
152 }
153
154 async fn handle_advertise_contexts(
155 this: Entity<Self>,
156 envelope: TypedEnvelope<proto::AdvertiseContexts>,
157 mut cx: AsyncApp,
158 ) -> Result<()> {
159 this.update(&mut cx, |this, cx| {
160 this.host_text_threads = envelope
161 .payload
162 .contexts
163 .into_iter()
164 .map(|text_thread| RemoteTextThreadMetadata {
165 id: TextThreadId::from_proto(text_thread.context_id),
166 summary: text_thread.summary,
167 })
168 .collect();
169 cx.notify();
170 });
171 Ok(())
172 }
173
174 async fn handle_open_context(
175 this: Entity<Self>,
176 envelope: TypedEnvelope<proto::OpenContext>,
177 mut cx: AsyncApp,
178 ) -> Result<proto::OpenContextResponse> {
179 let context_id = TextThreadId::from_proto(envelope.payload.context_id);
180 let operations = this.update(&mut cx, |this, cx| {
181 let project = this.project.upgrade().context("project not found")?;
182
183 anyhow::ensure!(
184 !project.read(cx).is_via_collab(),
185 "only the host contexts can be opened"
186 );
187
188 let text_thread = this
189 .loaded_text_thread_for_id(&context_id, cx)
190 .context("context not found")?;
191 anyhow::ensure!(
192 text_thread.read(cx).replica_id() == ReplicaId::default(),
193 "context must be opened via the host"
194 );
195
196 anyhow::Ok(
197 text_thread
198 .read(cx)
199 .serialize_ops(&TextThreadVersion::default(), cx),
200 )
201 })?;
202 let operations = operations.await;
203 Ok(proto::OpenContextResponse {
204 context: Some(proto::Context { operations }),
205 })
206 }
207
208 async fn handle_create_context(
209 this: Entity<Self>,
210 _: TypedEnvelope<proto::CreateContext>,
211 mut cx: AsyncApp,
212 ) -> Result<proto::CreateContextResponse> {
213 let (context_id, operations) = this.update(&mut cx, |this, cx| {
214 let project = this.project.upgrade().context("project not found")?;
215 anyhow::ensure!(
216 !project.read(cx).is_via_collab(),
217 "can only create contexts as the host"
218 );
219
220 let text_thread = this.create(cx);
221 let context_id = text_thread.read(cx).id().clone();
222
223 anyhow::Ok((
224 context_id,
225 text_thread
226 .read(cx)
227 .serialize_ops(&TextThreadVersion::default(), cx),
228 ))
229 })?;
230 let operations = operations.await;
231 Ok(proto::CreateContextResponse {
232 context_id: context_id.to_proto(),
233 context: Some(proto::Context { operations }),
234 })
235 }
236
237 async fn handle_update_context(
238 this: Entity<Self>,
239 envelope: TypedEnvelope<proto::UpdateContext>,
240 mut cx: AsyncApp,
241 ) -> Result<()> {
242 this.update(&mut cx, |this, cx| {
243 let context_id = TextThreadId::from_proto(envelope.payload.context_id);
244 if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
245 let operation_proto = envelope.payload.operation.context("invalid operation")?;
246 let operation = TextThreadOperation::from_proto(operation_proto)?;
247 text_thread.update(cx, |text_thread, cx| text_thread.apply_ops([operation], cx));
248 }
249 Ok(())
250 })
251 }
252
253 async fn handle_synchronize_contexts(
254 this: Entity<Self>,
255 envelope: TypedEnvelope<proto::SynchronizeContexts>,
256 mut cx: AsyncApp,
257 ) -> Result<proto::SynchronizeContextsResponse> {
258 this.update(&mut cx, |this, cx| {
259 let project = this.project.upgrade().context("project not found")?;
260 anyhow::ensure!(
261 !project.read(cx).is_via_collab(),
262 "only the host can synchronize contexts"
263 );
264
265 let mut local_versions = Vec::new();
266 for remote_version_proto in envelope.payload.contexts {
267 let remote_version = TextThreadVersion::from_proto(&remote_version_proto);
268 let context_id = TextThreadId::from_proto(remote_version_proto.context_id);
269 if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
270 let text_thread = text_thread.read(cx);
271 let operations = text_thread.serialize_ops(&remote_version, cx);
272 local_versions.push(text_thread.version(cx).to_proto(context_id.clone()));
273 let client = this.client.clone();
274 let project_id = envelope.payload.project_id;
275 cx.background_spawn(async move {
276 let operations = operations.await;
277 for operation in operations {
278 client.send(proto::UpdateContext {
279 project_id,
280 context_id: context_id.to_proto(),
281 operation: Some(operation),
282 })?;
283 }
284 anyhow::Ok(())
285 })
286 .detach_and_log_err(cx);
287 }
288 }
289
290 this.advertise_contexts(cx);
291
292 anyhow::Ok(proto::SynchronizeContextsResponse {
293 contexts: local_versions,
294 })
295 })
296 }
297
298 fn handle_project_shared(&mut self, cx: &mut Context<Self>) {
299 let Some(project) = self.project.upgrade() else {
300 return;
301 };
302
303 let is_shared = project.read(cx).is_shared();
304 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
305 if is_shared == was_shared {
306 return;
307 }
308
309 if is_shared {
310 self.text_threads.retain_mut(|text_thread| {
311 if let Some(strong_context) = text_thread.upgrade() {
312 *text_thread = TextThreadHandle::Strong(strong_context);
313 true
314 } else {
315 false
316 }
317 });
318 let remote_id = project.read(cx).remote_id().unwrap();
319 self.client_subscription = self
320 .client
321 .subscribe_to_entity(remote_id)
322 .log_err()
323 .map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async()));
324 self.advertise_contexts(cx);
325 } else {
326 self.client_subscription = None;
327 }
328 }
329
330 fn handle_project_event(
331 &mut self,
332 _project: Entity<Project>,
333 event: &project::Event,
334 cx: &mut Context<Self>,
335 ) {
336 match event {
337 project::Event::RemoteIdChanged(_) => {
338 self.handle_project_shared(cx);
339 }
340 project::Event::Reshared => {
341 self.advertise_contexts(cx);
342 }
343 project::Event::HostReshared | project::Event::Rejoined => {
344 self.synchronize_contexts(cx);
345 }
346 project::Event::DisconnectedFromHost => {
347 self.text_threads.retain_mut(|text_thread| {
348 if let Some(strong_context) = text_thread.upgrade() {
349 *text_thread = TextThreadHandle::Weak(text_thread.downgrade());
350 strong_context.update(cx, |text_thread, cx| {
351 if text_thread.replica_id() != ReplicaId::default() {
352 text_thread.set_capability(language::Capability::ReadOnly, cx);
353 }
354 });
355 true
356 } else {
357 false
358 }
359 });
360 self.host_text_threads.clear();
361 cx.notify();
362 }
363 _ => {}
364 }
365 }
366
367 /// Returns saved threads ordered by `mtime` descending (newest first).
368 pub fn ordered_text_threads(&self) -> impl Iterator<Item = &SavedTextThreadMetadata> {
369 self.text_threads_metadata
370 .iter()
371 .sorted_by(|a, b| b.mtime.cmp(&a.mtime))
372 }
373
374 pub fn has_saved_text_threads(&self) -> bool {
375 !self.text_threads_metadata.is_empty()
376 }
377
378 pub fn host_text_threads(&self) -> impl Iterator<Item = &RemoteTextThreadMetadata> {
379 self.host_text_threads.iter()
380 }
381
382 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<TextThread> {
383 let context = cx.new(|cx| {
384 TextThread::local(
385 self.languages.clone(),
386 Some(self.project.clone()),
387 self.prompt_builder.clone(),
388 self.slash_commands.clone(),
389 cx,
390 )
391 });
392 self.register_text_thread(&context, cx);
393 context
394 }
395
396 pub fn create_remote(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<TextThread>>> {
397 let Some(project) = self.project.upgrade() else {
398 return Task::ready(Err(anyhow::anyhow!("project was dropped")));
399 };
400 let project = project.read(cx);
401 let Some(project_id) = project.remote_id() else {
402 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
403 };
404
405 let replica_id = project.replica_id();
406 let capability = project.capability();
407 let language_registry = self.languages.clone();
408 let project = self.project.clone();
409
410 let prompt_builder = self.prompt_builder.clone();
411 let slash_commands = self.slash_commands.clone();
412 let request = self.client.request(proto::CreateContext { project_id });
413 cx.spawn(async move |this, cx| {
414 let response = request.await?;
415 let context_id = TextThreadId::from_proto(response.context_id);
416 let context_proto = response.context.context("invalid context")?;
417 let text_thread = cx.new(|cx| {
418 TextThread::new(
419 context_id.clone(),
420 replica_id,
421 capability,
422 language_registry,
423 prompt_builder,
424 slash_commands,
425 Some(project),
426 cx,
427 )
428 });
429 let operations = cx
430 .background_spawn(async move {
431 context_proto
432 .operations
433 .into_iter()
434 .map(TextThreadOperation::from_proto)
435 .collect::<Result<Vec<_>>>()
436 })
437 .await?;
438 text_thread.update(cx, |context, cx| context.apply_ops(operations, cx));
439 this.update(cx, |this, cx| {
440 if let Some(existing_context) = this.loaded_text_thread_for_id(&context_id, cx) {
441 existing_context
442 } else {
443 this.register_text_thread(&text_thread, cx);
444 this.synchronize_contexts(cx);
445 text_thread
446 }
447 })
448 })
449 }
450
451 pub fn open_local(
452 &mut self,
453 path: Arc<Path>,
454 cx: &Context<Self>,
455 ) -> Task<Result<Entity<TextThread>>> {
456 if let Some(existing_context) = self.loaded_text_thread_for_path(&path, cx) {
457 return Task::ready(Ok(existing_context));
458 }
459
460 let fs = self.fs.clone();
461 let languages = self.languages.clone();
462 let project = self.project.clone();
463 let load = cx.background_spawn({
464 let path = path.clone();
465 async move {
466 let saved_context = fs.load(&path).await?;
467 SavedTextThread::from_json(&saved_context)
468 }
469 });
470 let prompt_builder = self.prompt_builder.clone();
471 let slash_commands = self.slash_commands.clone();
472
473 cx.spawn(async move |this, cx| {
474 let saved_context = load.await?;
475 let context = cx.new(|cx| {
476 TextThread::deserialize(
477 saved_context,
478 path.clone(),
479 languages,
480 prompt_builder,
481 slash_commands,
482 Some(project),
483 cx,
484 )
485 });
486 this.update(cx, |this, cx| {
487 if let Some(existing_context) = this.loaded_text_thread_for_path(&path, cx) {
488 existing_context
489 } else {
490 this.register_text_thread(&context, cx);
491 context
492 }
493 })
494 })
495 }
496
497 pub fn delete_local(&mut self, path: Arc<Path>, cx: &mut Context<Self>) -> Task<Result<()>> {
498 let fs = self.fs.clone();
499
500 cx.spawn(async move |this, cx| {
501 fs.remove_file(
502 &path,
503 RemoveOptions {
504 recursive: false,
505 ignore_if_not_exists: true,
506 },
507 )
508 .await?;
509
510 this.update(cx, |this, cx| {
511 this.text_threads.retain(|text_thread| {
512 text_thread
513 .upgrade()
514 .and_then(|text_thread| text_thread.read(cx).path())
515 != Some(&path)
516 });
517 this.text_threads_metadata
518 .retain(|text_thread| text_thread.path.as_ref() != path.as_ref());
519 })?;
520
521 Ok(())
522 })
523 }
524
525 pub fn delete_all_local(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
526 let fs = self.fs.clone();
527 let paths = self
528 .text_threads_metadata
529 .iter()
530 .map(|metadata| metadata.path.clone())
531 .collect::<Vec<_>>();
532
533 cx.spawn(async move |this, cx| {
534 for path in paths {
535 fs.remove_file(
536 &path,
537 RemoveOptions {
538 recursive: false,
539 ignore_if_not_exists: true,
540 },
541 )
542 .await?;
543 }
544
545 this.update(cx, |this, cx| {
546 this.text_threads.clear();
547 this.text_threads_metadata.clear();
548 cx.notify();
549 })?;
550
551 Ok(())
552 })
553 }
554
555 fn loaded_text_thread_for_path(&self, path: &Path, cx: &App) -> Option<Entity<TextThread>> {
556 self.text_threads.iter().find_map(|text_thread| {
557 let text_thread = text_thread.upgrade()?;
558 if text_thread.read(cx).path().map(Arc::as_ref) == Some(path) {
559 Some(text_thread)
560 } else {
561 None
562 }
563 })
564 }
565
566 pub fn loaded_text_thread_for_id(
567 &self,
568 id: &TextThreadId,
569 cx: &App,
570 ) -> Option<Entity<TextThread>> {
571 self.text_threads.iter().find_map(|text_thread| {
572 let text_thread = text_thread.upgrade()?;
573 if text_thread.read(cx).id() == id {
574 Some(text_thread)
575 } else {
576 None
577 }
578 })
579 }
580
581 pub fn open_remote(
582 &mut self,
583 text_thread_id: TextThreadId,
584 cx: &mut Context<Self>,
585 ) -> Task<Result<Entity<TextThread>>> {
586 let Some(project) = self.project.upgrade() else {
587 return Task::ready(Err(anyhow::anyhow!("project was dropped")));
588 };
589 let project = project.read(cx);
590 let Some(project_id) = project.remote_id() else {
591 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
592 };
593
594 if let Some(context) = self.loaded_text_thread_for_id(&text_thread_id, cx) {
595 return Task::ready(Ok(context));
596 }
597
598 let replica_id = project.replica_id();
599 let capability = project.capability();
600 let language_registry = self.languages.clone();
601 let project = self.project.clone();
602 let request = self.client.request(proto::OpenContext {
603 project_id,
604 context_id: text_thread_id.to_proto(),
605 });
606 let prompt_builder = self.prompt_builder.clone();
607 let slash_commands = self.slash_commands.clone();
608 cx.spawn(async move |this, cx| {
609 let response = request.await?;
610 let context_proto = response.context.context("invalid context")?;
611 let text_thread = cx.new(|cx| {
612 TextThread::new(
613 text_thread_id.clone(),
614 replica_id,
615 capability,
616 language_registry,
617 prompt_builder,
618 slash_commands,
619 Some(project),
620 cx,
621 )
622 });
623 let operations = cx
624 .background_spawn(async move {
625 context_proto
626 .operations
627 .into_iter()
628 .map(TextThreadOperation::from_proto)
629 .collect::<Result<Vec<_>>>()
630 })
631 .await?;
632 text_thread.update(cx, |context, cx| context.apply_ops(operations, cx));
633 this.update(cx, |this, cx| {
634 if let Some(existing_context) = this.loaded_text_thread_for_id(&text_thread_id, cx)
635 {
636 existing_context
637 } else {
638 this.register_text_thread(&text_thread, cx);
639 this.synchronize_contexts(cx);
640 text_thread
641 }
642 })
643 })
644 }
645
646 fn register_text_thread(&mut self, text_thread: &Entity<TextThread>, cx: &mut Context<Self>) {
647 let handle = if self.project_is_shared {
648 TextThreadHandle::Strong(text_thread.clone())
649 } else {
650 TextThreadHandle::Weak(text_thread.downgrade())
651 };
652 self.text_threads.push(handle);
653 self.advertise_contexts(cx);
654 cx.subscribe(text_thread, Self::handle_context_event)
655 .detach();
656 }
657
658 fn handle_context_event(
659 &mut self,
660 text_thread: Entity<TextThread>,
661 event: &TextThreadEvent,
662 cx: &mut Context<Self>,
663 ) {
664 let Some(project) = self.project.upgrade() else {
665 return;
666 };
667 let Some(project_id) = project.read(cx).remote_id() else {
668 return;
669 };
670
671 match event {
672 TextThreadEvent::SummaryChanged => {
673 self.advertise_contexts(cx);
674 }
675 TextThreadEvent::PathChanged { old_path, new_path } => {
676 if let Some(old_path) = old_path.as_ref() {
677 for metadata in &mut self.text_threads_metadata {
678 if &metadata.path == old_path {
679 metadata.path = new_path.clone();
680 break;
681 }
682 }
683 }
684 }
685 TextThreadEvent::Operation(operation) => {
686 let context_id = text_thread.read(cx).id().to_proto();
687 let operation = operation.to_proto();
688 self.client
689 .send(proto::UpdateContext {
690 project_id,
691 context_id,
692 operation: Some(operation),
693 })
694 .log_err();
695 }
696 _ => {}
697 }
698 }
699
700 fn advertise_contexts(&self, cx: &App) {
701 let Some(project) = self.project.upgrade() else {
702 return;
703 };
704 let Some(project_id) = project.read(cx).remote_id() else {
705 return;
706 };
707 // For now, only the host can advertise their open contexts.
708 if project.read(cx).is_via_collab() {
709 return;
710 }
711
712 let contexts = self
713 .text_threads
714 .iter()
715 .rev()
716 .filter_map(|text_thread| {
717 let text_thread = text_thread.upgrade()?.read(cx);
718 if text_thread.replica_id() == ReplicaId::default() {
719 Some(proto::ContextMetadata {
720 context_id: text_thread.id().to_proto(),
721 summary: text_thread
722 .summary()
723 .content()
724 .map(|summary| summary.text.clone()),
725 })
726 } else {
727 None
728 }
729 })
730 .collect();
731 self.client
732 .send(proto::AdvertiseContexts {
733 project_id,
734 contexts,
735 })
736 .ok();
737 }
738
739 fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
740 let Some(project) = self.project.upgrade() else {
741 return;
742 };
743 let Some(project_id) = project.read(cx).remote_id() else {
744 return;
745 };
746
747 let text_threads = self
748 .text_threads
749 .iter()
750 .filter_map(|text_thread| {
751 let text_thread = text_thread.upgrade()?.read(cx);
752 if text_thread.replica_id() != ReplicaId::default() {
753 Some(text_thread.version(cx).to_proto(text_thread.id().clone()))
754 } else {
755 None
756 }
757 })
758 .collect();
759
760 let client = self.client.clone();
761 let request = self.client.request(proto::SynchronizeContexts {
762 project_id,
763 contexts: text_threads,
764 });
765 cx.spawn(async move |this, cx| {
766 let response = request.await?;
767
768 let mut text_thread_ids = Vec::new();
769 let mut operations = Vec::new();
770 this.read_with(cx, |this, cx| {
771 for context_version_proto in response.contexts {
772 let text_thread_version = TextThreadVersion::from_proto(&context_version_proto);
773 let text_thread_id = TextThreadId::from_proto(context_version_proto.context_id);
774 if let Some(text_thread) = this.loaded_text_thread_for_id(&text_thread_id, cx) {
775 text_thread_ids.push(text_thread_id);
776 operations
777 .push(text_thread.read(cx).serialize_ops(&text_thread_version, cx));
778 }
779 }
780 })?;
781
782 let operations = futures::future::join_all(operations).await;
783 for (context_id, operations) in text_thread_ids.into_iter().zip(operations) {
784 for operation in operations {
785 client.send(proto::UpdateContext {
786 project_id,
787 context_id: context_id.to_proto(),
788 operation: Some(operation),
789 })?;
790 }
791 }
792
793 anyhow::Ok(())
794 })
795 .detach_and_log_err(cx);
796 }
797
798 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedTextThreadMetadata>> {
799 let metadata = self.text_threads_metadata.clone();
800 let executor = cx.background_executor().clone();
801 cx.background_spawn(async move {
802 if query.is_empty() {
803 metadata
804 } else {
805 let candidates = metadata
806 .iter()
807 .enumerate()
808 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
809 .collect::<Vec<_>>();
810 let matches = fuzzy::match_strings(
811 &candidates,
812 &query,
813 false,
814 true,
815 100,
816 &Default::default(),
817 executor,
818 )
819 .await;
820
821 matches
822 .into_iter()
823 .map(|mat| metadata[mat.candidate_id].clone())
824 .collect()
825 }
826 })
827 }
828
829 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
830 let fs = self.fs.clone();
831 cx.spawn(async move |this, cx| {
832 if *ZED_STATELESS {
833 return Ok(());
834 }
835 fs.create_dir(text_threads_dir()).await?;
836
837 let mut paths = fs.read_dir(text_threads_dir()).await?;
838 let mut contexts = Vec::<SavedTextThreadMetadata>::new();
839 while let Some(path) = paths.next().await {
840 let path = path?;
841 if path.extension() != Some(OsStr::new("json")) {
842 continue;
843 }
844
845 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
846 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
847
848 let metadata = fs.metadata(&path).await?;
849 if let Some((file_name, metadata)) = path
850 .file_name()
851 .and_then(|name| name.to_str())
852 .zip(metadata)
853 {
854 // This is used to filter out contexts saved by the new assistant.
855 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
856 continue;
857 }
858
859 if let Some(title) = ASSISTANT_CONTEXT_REGEX
860 .replace(file_name, "")
861 .lines()
862 .next()
863 {
864 contexts.push(SavedTextThreadMetadata {
865 title: title.to_string().into(),
866 path: path.into(),
867 mtime: metadata.mtime.timestamp_for_user().into(),
868 });
869 }
870 }
871 }
872 contexts.sort_unstable_by_key(|text_thread| Reverse(text_thread.mtime));
873
874 this.update(cx, |this, cx| {
875 this.text_threads_metadata = contexts;
876 cx.notify();
877 })
878 })
879 }
880
881 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
882 let Some(project) = self.project.upgrade() else {
883 return;
884 };
885 let context_server_store = project.read(cx).context_server_store();
886 cx.subscribe(&context_server_store, Self::handle_context_server_event)
887 .detach();
888
889 // Check for any servers that were already running before the handler was registered
890 for server in context_server_store.read(cx).running_servers() {
891 self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
892 }
893 }
894
895 fn handle_context_server_event(
896 &mut self,
897 context_server_store: Entity<ContextServerStore>,
898 event: &project::context_server_store::Event,
899 cx: &mut Context<Self>,
900 ) {
901 match event {
902 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
903 match status {
904 ContextServerStatus::Running => {
905 self.load_context_server_slash_commands(
906 server_id.clone(),
907 context_server_store,
908 cx,
909 );
910 }
911 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
912 if let Some(slash_command_ids) =
913 self.context_server_slash_command_ids.remove(server_id)
914 {
915 self.slash_commands.remove(&slash_command_ids);
916 }
917 }
918 _ => {}
919 }
920 }
921 }
922 }
923
924 fn load_context_server_slash_commands(
925 &self,
926 server_id: ContextServerId,
927 context_server_store: Entity<ContextServerStore>,
928 cx: &mut Context<Self>,
929 ) {
930 let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
931 return;
932 };
933 let slash_command_working_set = self.slash_commands.clone();
934 cx.spawn(async move |this, cx| {
935 let Some(protocol) = server.client() else {
936 return;
937 };
938
939 if protocol.capable(context_server::protocol::ServerCapability::Prompts)
940 && let Some(response) = protocol
941 .request::<context_server::types::requests::PromptsList>(())
942 .await
943 .log_err()
944 {
945 let slash_command_ids = response
946 .prompts
947 .into_iter()
948 .filter(assistant_slash_commands::acceptable_prompt)
949 .map(|prompt| {
950 log::info!("registering context server command: {:?}", prompt.name);
951 slash_command_working_set.insert(Arc::new(
952 assistant_slash_commands::ContextServerSlashCommand::new(
953 context_server_store.clone(),
954 server.id(),
955 prompt,
956 ),
957 ))
958 })
959 .collect::<Vec<_>>();
960
961 this.update(cx, |this, _cx| {
962 this.context_server_slash_command_ids
963 .insert(server_id.clone(), slash_command_ids);
964 })
965 .log_err();
966 }
967 })
968 .detach();
969 }
970}
971
972#[cfg(test)]
973mod tests {
974 use super::*;
975 use fs::FakeFs;
976 use language_model::LanguageModelRegistry;
977 use project::Project;
978 use serde_json::json;
979 use settings::SettingsStore;
980 use std::path::{Path, PathBuf};
981 use std::sync::Arc;
982
    /// Installs the globals these tests rely on. It creates a test `SettingsStore`,
    /// runs `prompt_store::init`, and sets up a test `LanguageModelRegistry`. The
    /// settings store is installed as a global last.
    fn init_test(cx: &mut gpui::TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            prompt_store::init(cx);
            LanguageModelRegistry::test(cx);
            cx.set_global(settings_store);
        });
    }
991
992 #[gpui::test]
993 async fn ordered_text_threads_sort_by_mtime(cx: &mut gpui::TestAppContext) {
994 init_test(cx);
995
996 let fs = FakeFs::new(cx.background_executor.clone());
997 fs.insert_tree("/root", json!({})).await;
998
999 let project = Project::test(fs, [Path::new("/root")], cx).await;
1000 let store = cx.new(|cx| TextThreadStore::fake(project, cx));
1001
1002 let now = chrono::Local::now();
1003 let older = SavedTextThreadMetadata {
1004 title: "older".into(),
1005 path: Arc::from(PathBuf::from("/root/older.zed.json")),
1006 mtime: now - chrono::TimeDelta::days(1),
1007 };
1008 let middle = SavedTextThreadMetadata {
1009 title: "middle".into(),
1010 path: Arc::from(PathBuf::from("/root/middle.zed.json")),
1011 mtime: now - chrono::TimeDelta::hours(1),
1012 };
1013 let newer = SavedTextThreadMetadata {
1014 title: "newer".into(),
1015 path: Arc::from(PathBuf::from("/root/newer.zed.json")),
1016 mtime: now,
1017 };
1018
1019 store.update(cx, |store, _| {
1020 store.text_threads_metadata = vec![middle, older, newer];
1021 });
1022
1023 let ordered = store.read_with(cx, |store, _| {
1024 store
1025 .ordered_text_threads()
1026 .map(|entry| entry.title.to_string())
1027 .collect::<Vec<_>>()
1028 });
1029
1030 assert_eq!(ordered, vec!["newer", "middle", "older"]);
1031 }
1032
1033 #[gpui::test]
1034 async fn has_saved_text_threads_reflects_metadata(cx: &mut gpui::TestAppContext) {
1035 init_test(cx);
1036
1037 let fs = FakeFs::new(cx.background_executor.clone());
1038 fs.insert_tree("/root", json!({})).await;
1039
1040 let project = Project::test(fs, [Path::new("/root")], cx).await;
1041 let store = cx.new(|cx| TextThreadStore::fake(project, cx));
1042
1043 assert!(!store.read_with(cx, |store, _| store.has_saved_text_threads()));
1044
1045 store.update(cx, |store, _| {
1046 store.text_threads_metadata = vec![SavedTextThreadMetadata {
1047 title: "thread".into(),
1048 path: Arc::from(PathBuf::from("/root/thread.zed.json")),
1049 mtime: chrono::Local::now(),
1050 }];
1051 });
1052
1053 assert!(store.read_with(cx, |store, _| store.has_saved_text_threads()));
1054 }
1055
1056 #[gpui::test]
1057 async fn delete_all_local_clears_metadata_and_files(cx: &mut gpui::TestAppContext) {
1058 init_test(cx);
1059
1060 let fs = FakeFs::new(cx.background_executor.clone());
1061 fs.insert_tree("/root", json!({})).await;
1062
1063 let thread_a = PathBuf::from("/root/thread-a.zed.json");
1064 let thread_b = PathBuf::from("/root/thread-b.zed.json");
1065 fs.touch_path(&thread_a).await;
1066 fs.touch_path(&thread_b).await;
1067
1068 let project = Project::test(fs.clone(), [Path::new("/root")], cx).await;
1069 let store = cx.new(|cx| TextThreadStore::fake(project, cx));
1070
1071 let now = chrono::Local::now();
1072 store.update(cx, |store, cx| {
1073 store.create(cx);
1074 store.text_threads_metadata = vec![
1075 SavedTextThreadMetadata {
1076 title: "thread-a".into(),
1077 path: Arc::from(thread_a.clone()),
1078 mtime: now,
1079 },
1080 SavedTextThreadMetadata {
1081 title: "thread-b".into(),
1082 path: Arc::from(thread_b.clone()),
1083 mtime: now - chrono::TimeDelta::seconds(1),
1084 },
1085 ];
1086 });
1087
1088 let task = store.update(cx, |store, cx| store.delete_all_local(cx));
1089 task.await.unwrap();
1090
1091 assert!(!store.read_with(cx, |store, _| store.has_saved_text_threads()));
1092 assert_eq!(store.read_with(cx, |store, _| store.text_threads.len()), 0);
1093 assert!(fs.metadata(&thread_a).await.unwrap().is_none());
1094 assert!(fs.metadata(&thread_b).await.unwrap().is_none());
1095 }
1096}