1use crate::{
2 SavedTextThread, SavedTextThreadMetadata, TextThread, TextThreadEvent, TextThreadId,
3 TextThreadOperation, TextThreadVersion,
4};
5use anyhow::{Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use client::{Client, TypedEnvelope, proto};
8use clock::ReplicaId;
9use collections::HashMap;
10use context_server::ContextServerId;
11use fs::{Fs, RemoveOptions};
12use futures::StreamExt;
13use fuzzy::StringMatchCandidate;
14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
15use language::LanguageRegistry;
16use paths::text_threads_dir;
17use project::{
18 Project,
19 context_server_store::{ContextServerStatus, ContextServerStore},
20};
21use prompt_store::PromptBuilder;
22use regex::Regex;
23use rpc::AnyProtoClient;
24use std::sync::LazyLock;
25use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
26use util::{ResultExt, TryFutureExt};
27use zed_env_vars::ZED_STATELESS;
28
29pub(crate) fn init(client: &AnyProtoClient) {
30 client.add_entity_message_handler(TextThreadStore::handle_advertise_contexts);
31 client.add_entity_request_handler(TextThreadStore::handle_open_context);
32 client.add_entity_request_handler(TextThreadStore::handle_create_context);
33 client.add_entity_message_handler(TextThreadStore::handle_update_context);
34 client.add_entity_request_handler(TextThreadStore::handle_synchronize_contexts);
35}
36
/// Description of a text thread advertised by the host of a shared project.
///
/// Built from the payload of a `proto::AdvertiseContexts` message in
/// `TextThreadStore::handle_advertise_contexts`.
#[derive(Clone)]
pub struct RemoteTextThreadMetadata {
    /// Identifier of the text thread.
    pub id: TextThreadId,
    /// The thread's summary text, if the host reported one.
    pub summary: Option<String>,
}
42
/// Tracks the text threads of a project.
///
/// It holds the threads currently loaded in memory and the metadata of threads
/// saved on disk. For collaborators, it also holds the threads advertised by
/// the project host.
pub struct TextThreadStore {
    /// Loaded text threads. New handles are strong while the project is shared
    /// and weak otherwise (see `register_text_thread`).
    text_threads: Vec<TextThreadHandle>,
    /// Metadata of saved threads found in `text_threads_dir()`, sorted newest
    /// first by `reload`.
    text_threads_metadata: Vec<SavedTextThreadMetadata>,
    /// Slash commands registered for each running context server, kept so they
    /// can be removed from `slash_commands` when that server stops.
    context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
    /// Threads most recently advertised by the host of a shared project.
    host_text_threads: Vec<RemoteTextThreadMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    /// Task running the directory-watch loop that calls `reload`.
    /// NOTE(review): held presumably so the loop lives as long as the store;
    /// assumes dropping a gpui `Task` cancels it.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: WeakEntity<Project>,
    /// Whether the project was shared the last time `handle_project_shared` ran.
    project_is_shared: bool,
    /// RPC subscription for the project's remote id; `Some` only while shared.
    client_subscription: Option<client::Subscription>,
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
59
/// A reference to a loaded text thread that is either owning (`Strong`) or
/// non-owning (`Weak`).
///
/// The store switches handles to `Strong` when the project becomes shared, so
/// threads remain available to collaborators even if no local view holds them.
enum TextThreadHandle {
    /// Non-owning handle; the thread may already have been released.
    Weak(WeakEntity<TextThread>),
    /// Owning handle that keeps the thread alive.
    Strong(Entity<TextThread>),
}
64
65impl TextThreadHandle {
66 fn upgrade(&self) -> Option<Entity<TextThread>> {
67 match self {
68 TextThreadHandle::Weak(weak) => weak.upgrade(),
69 TextThreadHandle::Strong(strong) => Some(strong.clone()),
70 }
71 }
72
73 fn downgrade(&self) -> WeakEntity<TextThread> {
74 match self {
75 TextThreadHandle::Weak(weak) => weak.clone(),
76 TextThreadHandle::Strong(strong) => strong.downgrade(),
77 }
78 }
79}
80
81impl TextThreadStore {
82 pub fn new(
83 project: Entity<Project>,
84 prompt_builder: Arc<PromptBuilder>,
85 slash_commands: Arc<SlashCommandWorkingSet>,
86 cx: &mut App,
87 ) -> Task<Result<Entity<Self>>> {
88 let fs = project.read(cx).fs().clone();
89 let languages = project.read(cx).languages().clone();
90 cx.spawn(async move |cx| {
91 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
92 let (mut events, _) = fs.watch(text_threads_dir(), CONTEXT_WATCH_DURATION).await;
93
94 let this = cx.new(|cx: &mut Context<Self>| {
95 let mut this = Self {
96 text_threads: Vec::new(),
97 text_threads_metadata: Vec::new(),
98 context_server_slash_command_ids: HashMap::default(),
99 host_text_threads: Vec::new(),
100 fs,
101 languages,
102 slash_commands,
103 _watch_updates: cx.spawn(async move |this, cx| {
104 async move {
105 while events.next().await.is_some() {
106 this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
107 }
108 anyhow::Ok(())
109 }
110 .log_err()
111 .await
112 }),
113 client_subscription: None,
114 _project_subscriptions: vec![
115 cx.subscribe(&project, Self::handle_project_event),
116 ],
117 project_is_shared: false,
118 client: project.read(cx).client(),
119 project: project.downgrade(),
120 prompt_builder,
121 };
122 this.handle_project_shared(cx);
123 this.synchronize_contexts(cx);
124 this.register_context_server_handlers(cx);
125 this.reload(cx).detach_and_log_err(cx);
126 this
127 });
128
129 Ok(this)
130 })
131 }
132
133 #[cfg(any(test, feature = "test-support"))]
134 pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
135 Self {
136 text_threads: Default::default(),
137 text_threads_metadata: Default::default(),
138 context_server_slash_command_ids: Default::default(),
139 host_text_threads: Default::default(),
140 fs: project.read(cx).fs().clone(),
141 languages: project.read(cx).languages().clone(),
142 slash_commands: Arc::default(),
143 _watch_updates: Task::ready(None),
144 client: project.read(cx).client(),
145 project: project.downgrade(),
146 project_is_shared: false,
147 client_subscription: None,
148 _project_subscriptions: Default::default(),
149 prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
150 }
151 }
152
153 async fn handle_advertise_contexts(
154 this: Entity<Self>,
155 envelope: TypedEnvelope<proto::AdvertiseContexts>,
156 mut cx: AsyncApp,
157 ) -> Result<()> {
158 this.update(&mut cx, |this, cx| {
159 this.host_text_threads = envelope
160 .payload
161 .contexts
162 .into_iter()
163 .map(|text_thread| RemoteTextThreadMetadata {
164 id: TextThreadId::from_proto(text_thread.context_id),
165 summary: text_thread.summary,
166 })
167 .collect();
168 cx.notify();
169 });
170 Ok(())
171 }
172
173 async fn handle_open_context(
174 this: Entity<Self>,
175 envelope: TypedEnvelope<proto::OpenContext>,
176 mut cx: AsyncApp,
177 ) -> Result<proto::OpenContextResponse> {
178 let context_id = TextThreadId::from_proto(envelope.payload.context_id);
179 let operations = this.update(&mut cx, |this, cx| {
180 let project = this.project.upgrade().context("project not found")?;
181
182 anyhow::ensure!(
183 !project.read(cx).is_via_collab(),
184 "only the host contexts can be opened"
185 );
186
187 let text_thread = this
188 .loaded_text_thread_for_id(&context_id, cx)
189 .context("context not found")?;
190 anyhow::ensure!(
191 text_thread.read(cx).replica_id() == ReplicaId::default(),
192 "context must be opened via the host"
193 );
194
195 anyhow::Ok(
196 text_thread
197 .read(cx)
198 .serialize_ops(&TextThreadVersion::default(), cx),
199 )
200 })?;
201 let operations = operations.await;
202 Ok(proto::OpenContextResponse {
203 context: Some(proto::Context { operations }),
204 })
205 }
206
207 async fn handle_create_context(
208 this: Entity<Self>,
209 _: TypedEnvelope<proto::CreateContext>,
210 mut cx: AsyncApp,
211 ) -> Result<proto::CreateContextResponse> {
212 let (context_id, operations) = this.update(&mut cx, |this, cx| {
213 let project = this.project.upgrade().context("project not found")?;
214 anyhow::ensure!(
215 !project.read(cx).is_via_collab(),
216 "can only create contexts as the host"
217 );
218
219 let text_thread = this.create(cx);
220 let context_id = text_thread.read(cx).id().clone();
221
222 anyhow::Ok((
223 context_id,
224 text_thread
225 .read(cx)
226 .serialize_ops(&TextThreadVersion::default(), cx),
227 ))
228 })?;
229 let operations = operations.await;
230 Ok(proto::CreateContextResponse {
231 context_id: context_id.to_proto(),
232 context: Some(proto::Context { operations }),
233 })
234 }
235
236 async fn handle_update_context(
237 this: Entity<Self>,
238 envelope: TypedEnvelope<proto::UpdateContext>,
239 mut cx: AsyncApp,
240 ) -> Result<()> {
241 this.update(&mut cx, |this, cx| {
242 let context_id = TextThreadId::from_proto(envelope.payload.context_id);
243 if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
244 let operation_proto = envelope.payload.operation.context("invalid operation")?;
245 let operation = TextThreadOperation::from_proto(operation_proto)?;
246 text_thread.update(cx, |text_thread, cx| text_thread.apply_ops([operation], cx));
247 }
248 Ok(())
249 })
250 }
251
252 async fn handle_synchronize_contexts(
253 this: Entity<Self>,
254 envelope: TypedEnvelope<proto::SynchronizeContexts>,
255 mut cx: AsyncApp,
256 ) -> Result<proto::SynchronizeContextsResponse> {
257 this.update(&mut cx, |this, cx| {
258 let project = this.project.upgrade().context("project not found")?;
259 anyhow::ensure!(
260 !project.read(cx).is_via_collab(),
261 "only the host can synchronize contexts"
262 );
263
264 let mut local_versions = Vec::new();
265 for remote_version_proto in envelope.payload.contexts {
266 let remote_version = TextThreadVersion::from_proto(&remote_version_proto);
267 let context_id = TextThreadId::from_proto(remote_version_proto.context_id);
268 if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
269 let text_thread = text_thread.read(cx);
270 let operations = text_thread.serialize_ops(&remote_version, cx);
271 local_versions.push(text_thread.version(cx).to_proto(context_id.clone()));
272 let client = this.client.clone();
273 let project_id = envelope.payload.project_id;
274 cx.background_spawn(async move {
275 let operations = operations.await;
276 for operation in operations {
277 client.send(proto::UpdateContext {
278 project_id,
279 context_id: context_id.to_proto(),
280 operation: Some(operation),
281 })?;
282 }
283 anyhow::Ok(())
284 })
285 .detach_and_log_err(cx);
286 }
287 }
288
289 this.advertise_contexts(cx);
290
291 anyhow::Ok(proto::SynchronizeContextsResponse {
292 contexts: local_versions,
293 })
294 })
295 }
296
297 fn handle_project_shared(&mut self, cx: &mut Context<Self>) {
298 let Some(project) = self.project.upgrade() else {
299 return;
300 };
301
302 let is_shared = project.read(cx).is_shared();
303 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
304 if is_shared == was_shared {
305 return;
306 }
307
308 if is_shared {
309 self.text_threads.retain_mut(|text_thread| {
310 if let Some(strong_context) = text_thread.upgrade() {
311 *text_thread = TextThreadHandle::Strong(strong_context);
312 true
313 } else {
314 false
315 }
316 });
317 let remote_id = project.read(cx).remote_id().unwrap();
318 self.client_subscription = self
319 .client
320 .subscribe_to_entity(remote_id)
321 .log_err()
322 .map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async()));
323 self.advertise_contexts(cx);
324 } else {
325 self.client_subscription = None;
326 }
327 }
328
329 fn handle_project_event(
330 &mut self,
331 _project: Entity<Project>,
332 event: &project::Event,
333 cx: &mut Context<Self>,
334 ) {
335 match event {
336 project::Event::RemoteIdChanged(_) => {
337 self.handle_project_shared(cx);
338 }
339 project::Event::Reshared => {
340 self.advertise_contexts(cx);
341 }
342 project::Event::HostReshared | project::Event::Rejoined => {
343 self.synchronize_contexts(cx);
344 }
345 project::Event::DisconnectedFromHost => {
346 self.text_threads.retain_mut(|text_thread| {
347 if let Some(strong_context) = text_thread.upgrade() {
348 *text_thread = TextThreadHandle::Weak(text_thread.downgrade());
349 strong_context.update(cx, |text_thread, cx| {
350 if text_thread.replica_id() != ReplicaId::default() {
351 text_thread.set_capability(language::Capability::ReadOnly, cx);
352 }
353 });
354 true
355 } else {
356 false
357 }
358 });
359 self.host_text_threads.clear();
360 cx.notify();
361 }
362 _ => {}
363 }
364 }
365
366 pub fn unordered_text_threads(&self) -> impl Iterator<Item = &SavedTextThreadMetadata> {
367 self.text_threads_metadata.iter()
368 }
369
370 pub fn host_text_threads(&self) -> impl Iterator<Item = &RemoteTextThreadMetadata> {
371 self.host_text_threads.iter()
372 }
373
374 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<TextThread> {
375 let context = cx.new(|cx| {
376 TextThread::local(
377 self.languages.clone(),
378 Some(self.project.clone()),
379 self.prompt_builder.clone(),
380 self.slash_commands.clone(),
381 cx,
382 )
383 });
384 self.register_text_thread(&context, cx);
385 context
386 }
387
388 pub fn create_remote(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<TextThread>>> {
389 let Some(project) = self.project.upgrade() else {
390 return Task::ready(Err(anyhow::anyhow!("project was dropped")));
391 };
392 let project = project.read(cx);
393 let Some(project_id) = project.remote_id() else {
394 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
395 };
396
397 let replica_id = project.replica_id();
398 let capability = project.capability();
399 let language_registry = self.languages.clone();
400 let project = self.project.clone();
401
402 let prompt_builder = self.prompt_builder.clone();
403 let slash_commands = self.slash_commands.clone();
404 let request = self.client.request(proto::CreateContext { project_id });
405 cx.spawn(async move |this, cx| {
406 let response = request.await?;
407 let context_id = TextThreadId::from_proto(response.context_id);
408 let context_proto = response.context.context("invalid context")?;
409 let text_thread = cx.new(|cx| {
410 TextThread::new(
411 context_id.clone(),
412 replica_id,
413 capability,
414 language_registry,
415 prompt_builder,
416 slash_commands,
417 Some(project),
418 cx,
419 )
420 });
421 let operations = cx
422 .background_spawn(async move {
423 context_proto
424 .operations
425 .into_iter()
426 .map(TextThreadOperation::from_proto)
427 .collect::<Result<Vec<_>>>()
428 })
429 .await?;
430 text_thread.update(cx, |context, cx| context.apply_ops(operations, cx));
431 this.update(cx, |this, cx| {
432 if let Some(existing_context) = this.loaded_text_thread_for_id(&context_id, cx) {
433 existing_context
434 } else {
435 this.register_text_thread(&text_thread, cx);
436 this.synchronize_contexts(cx);
437 text_thread
438 }
439 })
440 })
441 }
442
443 pub fn open_local(
444 &mut self,
445 path: Arc<Path>,
446 cx: &Context<Self>,
447 ) -> Task<Result<Entity<TextThread>>> {
448 if let Some(existing_context) = self.loaded_text_thread_for_path(&path, cx) {
449 return Task::ready(Ok(existing_context));
450 }
451
452 let fs = self.fs.clone();
453 let languages = self.languages.clone();
454 let project = self.project.clone();
455 let load = cx.background_spawn({
456 let path = path.clone();
457 async move {
458 let saved_context = fs.load(&path).await?;
459 SavedTextThread::from_json(&saved_context)
460 }
461 });
462 let prompt_builder = self.prompt_builder.clone();
463 let slash_commands = self.slash_commands.clone();
464
465 cx.spawn(async move |this, cx| {
466 let saved_context = load.await?;
467 let context = cx.new(|cx| {
468 TextThread::deserialize(
469 saved_context,
470 path.clone(),
471 languages,
472 prompt_builder,
473 slash_commands,
474 Some(project),
475 cx,
476 )
477 });
478 this.update(cx, |this, cx| {
479 if let Some(existing_context) = this.loaded_text_thread_for_path(&path, cx) {
480 existing_context
481 } else {
482 this.register_text_thread(&context, cx);
483 context
484 }
485 })
486 })
487 }
488
489 pub fn delete_local(&mut self, path: Arc<Path>, cx: &mut Context<Self>) -> Task<Result<()>> {
490 let fs = self.fs.clone();
491
492 cx.spawn(async move |this, cx| {
493 fs.remove_file(
494 &path,
495 RemoveOptions {
496 recursive: false,
497 ignore_if_not_exists: true,
498 },
499 )
500 .await?;
501
502 this.update(cx, |this, cx| {
503 this.text_threads.retain(|text_thread| {
504 text_thread
505 .upgrade()
506 .and_then(|text_thread| text_thread.read(cx).path())
507 != Some(&path)
508 });
509 this.text_threads_metadata
510 .retain(|text_thread| text_thread.path.as_ref() != path.as_ref());
511 })?;
512
513 Ok(())
514 })
515 }
516
517 fn loaded_text_thread_for_path(&self, path: &Path, cx: &App) -> Option<Entity<TextThread>> {
518 self.text_threads.iter().find_map(|text_thread| {
519 let text_thread = text_thread.upgrade()?;
520 if text_thread.read(cx).path().map(Arc::as_ref) == Some(path) {
521 Some(text_thread)
522 } else {
523 None
524 }
525 })
526 }
527
528 pub fn loaded_text_thread_for_id(
529 &self,
530 id: &TextThreadId,
531 cx: &App,
532 ) -> Option<Entity<TextThread>> {
533 self.text_threads.iter().find_map(|text_thread| {
534 let text_thread = text_thread.upgrade()?;
535 if text_thread.read(cx).id() == id {
536 Some(text_thread)
537 } else {
538 None
539 }
540 })
541 }
542
543 pub fn open_remote(
544 &mut self,
545 text_thread_id: TextThreadId,
546 cx: &mut Context<Self>,
547 ) -> Task<Result<Entity<TextThread>>> {
548 let Some(project) = self.project.upgrade() else {
549 return Task::ready(Err(anyhow::anyhow!("project was dropped")));
550 };
551 let project = project.read(cx);
552 let Some(project_id) = project.remote_id() else {
553 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
554 };
555
556 if let Some(context) = self.loaded_text_thread_for_id(&text_thread_id, cx) {
557 return Task::ready(Ok(context));
558 }
559
560 let replica_id = project.replica_id();
561 let capability = project.capability();
562 let language_registry = self.languages.clone();
563 let project = self.project.clone();
564 let request = self.client.request(proto::OpenContext {
565 project_id,
566 context_id: text_thread_id.to_proto(),
567 });
568 let prompt_builder = self.prompt_builder.clone();
569 let slash_commands = self.slash_commands.clone();
570 cx.spawn(async move |this, cx| {
571 let response = request.await?;
572 let context_proto = response.context.context("invalid context")?;
573 let text_thread = cx.new(|cx| {
574 TextThread::new(
575 text_thread_id.clone(),
576 replica_id,
577 capability,
578 language_registry,
579 prompt_builder,
580 slash_commands,
581 Some(project),
582 cx,
583 )
584 });
585 let operations = cx
586 .background_spawn(async move {
587 context_proto
588 .operations
589 .into_iter()
590 .map(TextThreadOperation::from_proto)
591 .collect::<Result<Vec<_>>>()
592 })
593 .await?;
594 text_thread.update(cx, |context, cx| context.apply_ops(operations, cx));
595 this.update(cx, |this, cx| {
596 if let Some(existing_context) = this.loaded_text_thread_for_id(&text_thread_id, cx)
597 {
598 existing_context
599 } else {
600 this.register_text_thread(&text_thread, cx);
601 this.synchronize_contexts(cx);
602 text_thread
603 }
604 })
605 })
606 }
607
608 fn register_text_thread(&mut self, text_thread: &Entity<TextThread>, cx: &mut Context<Self>) {
609 let handle = if self.project_is_shared {
610 TextThreadHandle::Strong(text_thread.clone())
611 } else {
612 TextThreadHandle::Weak(text_thread.downgrade())
613 };
614 self.text_threads.push(handle);
615 self.advertise_contexts(cx);
616 cx.subscribe(text_thread, Self::handle_context_event)
617 .detach();
618 }
619
620 fn handle_context_event(
621 &mut self,
622 text_thread: Entity<TextThread>,
623 event: &TextThreadEvent,
624 cx: &mut Context<Self>,
625 ) {
626 let Some(project) = self.project.upgrade() else {
627 return;
628 };
629 let Some(project_id) = project.read(cx).remote_id() else {
630 return;
631 };
632
633 match event {
634 TextThreadEvent::SummaryChanged => {
635 self.advertise_contexts(cx);
636 }
637 TextThreadEvent::PathChanged { old_path, new_path } => {
638 if let Some(old_path) = old_path.as_ref() {
639 for metadata in &mut self.text_threads_metadata {
640 if &metadata.path == old_path {
641 metadata.path = new_path.clone();
642 break;
643 }
644 }
645 }
646 }
647 TextThreadEvent::Operation(operation) => {
648 let context_id = text_thread.read(cx).id().to_proto();
649 let operation = operation.to_proto();
650 self.client
651 .send(proto::UpdateContext {
652 project_id,
653 context_id,
654 operation: Some(operation),
655 })
656 .log_err();
657 }
658 _ => {}
659 }
660 }
661
662 fn advertise_contexts(&self, cx: &App) {
663 let Some(project) = self.project.upgrade() else {
664 return;
665 };
666 let Some(project_id) = project.read(cx).remote_id() else {
667 return;
668 };
669 // For now, only the host can advertise their open contexts.
670 if project.read(cx).is_via_collab() {
671 return;
672 }
673
674 let contexts = self
675 .text_threads
676 .iter()
677 .rev()
678 .filter_map(|text_thread| {
679 let text_thread = text_thread.upgrade()?.read(cx);
680 if text_thread.replica_id() == ReplicaId::default() {
681 Some(proto::ContextMetadata {
682 context_id: text_thread.id().to_proto(),
683 summary: text_thread
684 .summary()
685 .content()
686 .map(|summary| summary.text.clone()),
687 })
688 } else {
689 None
690 }
691 })
692 .collect();
693 self.client
694 .send(proto::AdvertiseContexts {
695 project_id,
696 contexts,
697 })
698 .ok();
699 }
700
701 fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
702 let Some(project) = self.project.upgrade() else {
703 return;
704 };
705 let Some(project_id) = project.read(cx).remote_id() else {
706 return;
707 };
708
709 let text_threads = self
710 .text_threads
711 .iter()
712 .filter_map(|text_thread| {
713 let text_thread = text_thread.upgrade()?.read(cx);
714 if text_thread.replica_id() != ReplicaId::default() {
715 Some(text_thread.version(cx).to_proto(text_thread.id().clone()))
716 } else {
717 None
718 }
719 })
720 .collect();
721
722 let client = self.client.clone();
723 let request = self.client.request(proto::SynchronizeContexts {
724 project_id,
725 contexts: text_threads,
726 });
727 cx.spawn(async move |this, cx| {
728 let response = request.await?;
729
730 let mut text_thread_ids = Vec::new();
731 let mut operations = Vec::new();
732 this.read_with(cx, |this, cx| {
733 for context_version_proto in response.contexts {
734 let text_thread_version = TextThreadVersion::from_proto(&context_version_proto);
735 let text_thread_id = TextThreadId::from_proto(context_version_proto.context_id);
736 if let Some(text_thread) = this.loaded_text_thread_for_id(&text_thread_id, cx) {
737 text_thread_ids.push(text_thread_id);
738 operations
739 .push(text_thread.read(cx).serialize_ops(&text_thread_version, cx));
740 }
741 }
742 })?;
743
744 let operations = futures::future::join_all(operations).await;
745 for (context_id, operations) in text_thread_ids.into_iter().zip(operations) {
746 for operation in operations {
747 client.send(proto::UpdateContext {
748 project_id,
749 context_id: context_id.to_proto(),
750 operation: Some(operation),
751 })?;
752 }
753 }
754
755 anyhow::Ok(())
756 })
757 .detach_and_log_err(cx);
758 }
759
760 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedTextThreadMetadata>> {
761 let metadata = self.text_threads_metadata.clone();
762 let executor = cx.background_executor().clone();
763 cx.background_spawn(async move {
764 if query.is_empty() {
765 metadata
766 } else {
767 let candidates = metadata
768 .iter()
769 .enumerate()
770 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
771 .collect::<Vec<_>>();
772 let matches = fuzzy::match_strings(
773 &candidates,
774 &query,
775 false,
776 true,
777 100,
778 &Default::default(),
779 executor,
780 )
781 .await;
782
783 matches
784 .into_iter()
785 .map(|mat| metadata[mat.candidate_id].clone())
786 .collect()
787 }
788 })
789 }
790
791 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
792 let fs = self.fs.clone();
793 cx.spawn(async move |this, cx| {
794 if *ZED_STATELESS {
795 return Ok(());
796 }
797 fs.create_dir(text_threads_dir()).await?;
798
799 let mut paths = fs.read_dir(text_threads_dir()).await?;
800 let mut contexts = Vec::<SavedTextThreadMetadata>::new();
801 while let Some(path) = paths.next().await {
802 let path = path?;
803 if path.extension() != Some(OsStr::new("json")) {
804 continue;
805 }
806
807 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
808 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
809
810 let metadata = fs.metadata(&path).await?;
811 if let Some((file_name, metadata)) = path
812 .file_name()
813 .and_then(|name| name.to_str())
814 .zip(metadata)
815 {
816 // This is used to filter out contexts saved by the new assistant.
817 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
818 continue;
819 }
820
821 if let Some(title) = ASSISTANT_CONTEXT_REGEX
822 .replace(file_name, "")
823 .lines()
824 .next()
825 {
826 contexts.push(SavedTextThreadMetadata {
827 title: title.to_string().into(),
828 path: path.into(),
829 mtime: metadata.mtime.timestamp_for_user().into(),
830 });
831 }
832 }
833 }
834 contexts.sort_unstable_by_key(|text_thread| Reverse(text_thread.mtime));
835
836 this.update(cx, |this, cx| {
837 this.text_threads_metadata = contexts;
838 cx.notify();
839 })
840 })
841 }
842
843 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
844 let Some(project) = self.project.upgrade() else {
845 return;
846 };
847 let context_server_store = project.read(cx).context_server_store();
848 cx.subscribe(&context_server_store, Self::handle_context_server_event)
849 .detach();
850
851 // Check for any servers that were already running before the handler was registered
852 for server in context_server_store.read(cx).running_servers() {
853 self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
854 }
855 }
856
857 fn handle_context_server_event(
858 &mut self,
859 context_server_store: Entity<ContextServerStore>,
860 event: &project::context_server_store::Event,
861 cx: &mut Context<Self>,
862 ) {
863 match event {
864 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
865 match status {
866 ContextServerStatus::Running => {
867 self.load_context_server_slash_commands(
868 server_id.clone(),
869 context_server_store,
870 cx,
871 );
872 }
873 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
874 if let Some(slash_command_ids) =
875 self.context_server_slash_command_ids.remove(server_id)
876 {
877 self.slash_commands.remove(&slash_command_ids);
878 }
879 }
880 _ => {}
881 }
882 }
883 }
884 }
885
886 fn load_context_server_slash_commands(
887 &self,
888 server_id: ContextServerId,
889 context_server_store: Entity<ContextServerStore>,
890 cx: &mut Context<Self>,
891 ) {
892 let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
893 return;
894 };
895 let slash_command_working_set = self.slash_commands.clone();
896 cx.spawn(async move |this, cx| {
897 let Some(protocol) = server.client() else {
898 return;
899 };
900
901 if protocol.capable(context_server::protocol::ServerCapability::Prompts)
902 && let Some(response) = protocol
903 .request::<context_server::types::requests::PromptsList>(())
904 .await
905 .log_err()
906 {
907 let slash_command_ids = response
908 .prompts
909 .into_iter()
910 .filter(assistant_slash_commands::acceptable_prompt)
911 .map(|prompt| {
912 log::info!("registering context server command: {:?}", prompt.name);
913 slash_command_working_set.insert(Arc::new(
914 assistant_slash_commands::ContextServerSlashCommand::new(
915 context_server_store.clone(),
916 server.id(),
917 prompt,
918 ),
919 ))
920 })
921 .collect::<Vec<_>>();
922
923 this.update(cx, |this, _cx| {
924 this.context_server_slash_command_ids
925 .insert(server_id.clone(), slash_command_ids);
926 })
927 .log_err();
928 }
929 })
930 .detach();
931 }
932}