1use crate::{
2 SavedTextThread, SavedTextThreadMetadata, TextThread, TextThreadEvent, TextThreadId,
3 TextThreadOperation, TextThreadVersion,
4};
5use anyhow::{Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use client::{Client, TypedEnvelope, proto};
8use clock::ReplicaId;
9use collections::HashMap;
10use context_server::ContextServerId;
11use fs::{Fs, RemoveOptions};
12use futures::StreamExt;
13use fuzzy::StringMatchCandidate;
14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
15use language::LanguageRegistry;
16use paths::text_threads_dir;
17use project::{
18 Project,
19 context_server_store::{ContextServerStatus, ContextServerStore},
20};
21use prompt_store::PromptBuilder;
22use regex::Regex;
23use rpc::AnyProtoClient;
24use std::sync::LazyLock;
25use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
26use util::{ResultExt, TryFutureExt};
27use zed_env_vars::ZED_STATELESS;
28
/// Registers the RPC handlers [`TextThreadStore`] uses for collaboration on `client`.
///
/// Handlers that return no response (context advertisements and single-operation updates) are
/// registered with `add_entity_message_handler`; the open, create, and synchronize handlers,
/// which return a response message, are registered with `add_entity_request_handler`.
pub(crate) fn init(client: &AnyProtoClient) {
    client.add_entity_message_handler(TextThreadStore::handle_advertise_contexts);
    client.add_entity_request_handler(TextThreadStore::handle_open_context);
    client.add_entity_request_handler(TextThreadStore::handle_create_context);
    client.add_entity_message_handler(TextThreadStore::handle_update_context);
    client.add_entity_request_handler(TextThreadStore::handle_synchronize_contexts);
}
36
/// Metadata for a text thread advertised by the project host via `proto::AdvertiseContexts`.
#[derive(Clone)]
pub struct RemoteTextThreadMetadata {
    /// Identifier of the thread on the host.
    pub id: TextThreadId,
    /// The thread's summary text, if the host reported one.
    pub summary: Option<String>,
}
42
/// Tracks a project's text threads: threads loaded in memory, metadata for threads saved on
/// disk, threads advertised by a collaboration host, and slash commands contributed by running
/// context servers.
pub struct TextThreadStore {
    /// Loaded threads. Handles are made strong when the project becomes shared (and new ones are
    /// strong while it is shared); they are downgraded on `DisconnectedFromHost`.
    text_threads: Vec<TextThreadHandle>,
    /// Metadata for saved threads found in `text_threads_dir()`, refreshed by `reload`.
    text_threads_metadata: Vec<SavedTextThreadMetadata>,
    /// Slash command ids registered for each context server, so they can be removed when that
    /// server stops or errors.
    context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
    /// Threads most recently advertised by the project host.
    host_text_threads: Vec<RemoteTextThreadMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    /// Task that reloads saved-thread metadata whenever the threads directory changes.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: WeakEntity<Project>,
    /// Last observed value of `Project::is_shared`, used to detect sharing transitions.
    project_is_shared: bool,
    /// Subscription to messages for the project's remote id; set while the project is shared.
    client_subscription: Option<client::Subscription>,
    /// Keeps the subscription to project events alive.
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
59
/// How [`TextThreadStore`] holds on to a loaded text thread.
enum TextThreadHandle {
    /// Does not keep the thread alive; `upgrade` yields `None` once the thread was released.
    Weak(WeakEntity<TextThread>),
    /// Keeps the thread alive; used while the project is shared.
    Strong(Entity<TextThread>),
}
64
65impl TextThreadHandle {
66 fn upgrade(&self) -> Option<Entity<TextThread>> {
67 match self {
68 TextThreadHandle::Weak(weak) => weak.upgrade(),
69 TextThreadHandle::Strong(strong) => Some(strong.clone()),
70 }
71 }
72
73 fn downgrade(&self) -> WeakEntity<TextThread> {
74 match self {
75 TextThreadHandle::Weak(weak) => weak.clone(),
76 TextThreadHandle::Strong(strong) => strong.downgrade(),
77 }
78 }
79}
80
81impl TextThreadStore {
    /// Creates a [`TextThreadStore`] for `project`.
    ///
    /// The returned task starts watching [`text_threads_dir`], builds the store, subscribes to
    /// project events, brings sharing and synchronization state up to date, hooks up
    /// context-server slash commands, and starts an initial [`Self::reload`] of saved-thread
    /// metadata.
    pub fn new(
        project: Entity<Project>,
        prompt_builder: Arc<PromptBuilder>,
        slash_commands: Arc<SlashCommandWorkingSet>,
        cx: &mut App,
    ) -> Task<Result<Entity<Self>>> {
        let fs = project.read(cx).fs().clone();
        let languages = project.read(cx).languages().clone();
        cx.spawn(async move |cx| {
            // Latency argument passed to `fs.watch` for the text threads directory.
            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
            let (mut events, _) = fs.watch(text_threads_dir(), CONTEXT_WATCH_DURATION).await;

            let this = cx.new(|cx: &mut Context<Self>| {
                let mut this = Self {
                    text_threads: Vec::new(),
                    text_threads_metadata: Vec::new(),
                    context_server_slash_command_ids: HashMap::default(),
                    host_text_threads: Vec::new(),
                    fs,
                    languages,
                    slash_commands,
                    // Reload saved-thread metadata whenever the watcher reports a change. The
                    // loop ends when the event stream ends, or (through `?`, logged) once the
                    // store has been released and `this.update` fails.
                    _watch_updates: cx.spawn(async move |this, cx| {
                        async move {
                            while events.next().await.is_some() {
                                this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
                            }
                            anyhow::Ok(())
                        }
                        .log_err()
                        .await
                    }),
                    client_subscription: None,
                    _project_subscriptions: vec![
                        cx.subscribe(&project, Self::handle_project_event),
                    ],
                    project_is_shared: false,
                    client: project.read(cx).client(),
                    project: project.downgrade(),
                    prompt_builder,
                };
                // Bring derived state in line with the project's current status: sharing state,
                // guest-side synchronization, already-running context servers, and the on-disk
                // metadata list.
                this.handle_project_shared(cx);
                this.synchronize_contexts(cx);
                this.register_context_server_handlers(cx);
                this.reload(cx).detach_and_log_err(cx);
                this
            })?;

            Ok(this)
        })
    }
132
133 #[cfg(any(test, feature = "test-support"))]
134 pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
135 Self {
136 text_threads: Default::default(),
137 text_threads_metadata: Default::default(),
138 context_server_slash_command_ids: Default::default(),
139 host_text_threads: Default::default(),
140 fs: project.read(cx).fs().clone(),
141 languages: project.read(cx).languages().clone(),
142 slash_commands: Arc::default(),
143 _watch_updates: Task::ready(None),
144 client: project.read(cx).client(),
145 project: project.downgrade(),
146 project_is_shared: false,
147 client_subscription: None,
148 _project_subscriptions: Default::default(),
149 prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
150 }
151 }
152
153 async fn handle_advertise_contexts(
154 this: Entity<Self>,
155 envelope: TypedEnvelope<proto::AdvertiseContexts>,
156 mut cx: AsyncApp,
157 ) -> Result<()> {
158 this.update(&mut cx, |this, cx| {
159 this.host_text_threads = envelope
160 .payload
161 .contexts
162 .into_iter()
163 .map(|text_thread| RemoteTextThreadMetadata {
164 id: TextThreadId::from_proto(text_thread.context_id),
165 summary: text_thread.summary,
166 })
167 .collect();
168 cx.notify();
169 })
170 }
171
172 async fn handle_open_context(
173 this: Entity<Self>,
174 envelope: TypedEnvelope<proto::OpenContext>,
175 mut cx: AsyncApp,
176 ) -> Result<proto::OpenContextResponse> {
177 let context_id = TextThreadId::from_proto(envelope.payload.context_id);
178 let operations = this.update(&mut cx, |this, cx| {
179 let project = this.project.upgrade().context("project not found")?;
180
181 anyhow::ensure!(
182 !project.read(cx).is_via_collab(),
183 "only the host contexts can be opened"
184 );
185
186 let text_thread = this
187 .loaded_text_thread_for_id(&context_id, cx)
188 .context("context not found")?;
189 anyhow::ensure!(
190 text_thread.read(cx).replica_id() == ReplicaId::default(),
191 "context must be opened via the host"
192 );
193
194 anyhow::Ok(
195 text_thread
196 .read(cx)
197 .serialize_ops(&TextThreadVersion::default(), cx),
198 )
199 })??;
200 let operations = operations.await;
201 Ok(proto::OpenContextResponse {
202 context: Some(proto::Context { operations }),
203 })
204 }
205
206 async fn handle_create_context(
207 this: Entity<Self>,
208 _: TypedEnvelope<proto::CreateContext>,
209 mut cx: AsyncApp,
210 ) -> Result<proto::CreateContextResponse> {
211 let (context_id, operations) = this.update(&mut cx, |this, cx| {
212 let project = this.project.upgrade().context("project not found")?;
213 anyhow::ensure!(
214 !project.read(cx).is_via_collab(),
215 "can only create contexts as the host"
216 );
217
218 let text_thread = this.create(cx);
219 let context_id = text_thread.read(cx).id().clone();
220
221 anyhow::Ok((
222 context_id,
223 text_thread
224 .read(cx)
225 .serialize_ops(&TextThreadVersion::default(), cx),
226 ))
227 })??;
228 let operations = operations.await;
229 Ok(proto::CreateContextResponse {
230 context_id: context_id.to_proto(),
231 context: Some(proto::Context { operations }),
232 })
233 }
234
235 async fn handle_update_context(
236 this: Entity<Self>,
237 envelope: TypedEnvelope<proto::UpdateContext>,
238 mut cx: AsyncApp,
239 ) -> Result<()> {
240 this.update(&mut cx, |this, cx| {
241 let context_id = TextThreadId::from_proto(envelope.payload.context_id);
242 if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
243 let operation_proto = envelope.payload.operation.context("invalid operation")?;
244 let operation = TextThreadOperation::from_proto(operation_proto)?;
245 text_thread.update(cx, |text_thread, cx| text_thread.apply_ops([operation], cx));
246 }
247 Ok(())
248 })?
249 }
250
    /// Handles `proto::SynchronizeContexts` from a guest (host only).
    ///
    /// For each thread version the guest reports that matches a loaded thread, this sends the
    /// operations serialized relative to the guest's version as individual
    /// `proto::UpdateContext` messages from a detached background task, and includes the host's
    /// own version of that thread in the response. It then re-advertises the host's threads.
    async fn handle_synchronize_contexts(
        this: Entity<Self>,
        envelope: TypedEnvelope<proto::SynchronizeContexts>,
        mut cx: AsyncApp,
    ) -> Result<proto::SynchronizeContextsResponse> {
        this.update(&mut cx, |this, cx| {
            let project = this.project.upgrade().context("project not found")?;
            anyhow::ensure!(
                !project.read(cx).is_via_collab(),
                "only the host can synchronize contexts"
            );

            // Host-side versions of the threads the guest asked about, returned in the response.
            let mut local_versions = Vec::new();
            for remote_version_proto in envelope.payload.contexts {
                let remote_version = TextThreadVersion::from_proto(&remote_version_proto);
                let context_id = TextThreadId::from_proto(remote_version_proto.context_id);
                // Threads the host does not have loaded are skipped.
                if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
                    let text_thread = text_thread.read(cx);
                    let operations = text_thread.serialize_ops(&remote_version, cx);
                    local_versions.push(text_thread.version(cx).to_proto(context_id.clone()));
                    let client = this.client.clone();
                    let project_id = envelope.payload.project_id;
                    // Send without delaying the response; the first send error stops the loop
                    // and is logged.
                    cx.background_spawn(async move {
                        let operations = operations.await;
                        for operation in operations {
                            client.send(proto::UpdateContext {
                                project_id,
                                context_id: context_id.to_proto(),
                                operation: Some(operation),
                            })?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                }
            }

            this.advertise_contexts(cx);

            anyhow::Ok(proto::SynchronizeContextsResponse {
                contexts: local_versions,
            })
        })?
    }
295
    /// Re-evaluates whether the project is shared and updates the store on transitions.
    ///
    /// On becoming shared: makes every live thread handle strong (dropping handles to released
    /// threads), subscribes to messages for the project's remote id, and advertises the host's
    /// threads. On becoming unshared: drops the client subscription. Does nothing when the
    /// sharing state is unchanged or the project has been dropped.
    fn handle_project_shared(&mut self, cx: &mut Context<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };

        let is_shared = project.read(cx).is_shared();
        // Store the new state and compare it with the previous one; only transitions matter.
        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
        if is_shared == was_shared {
            return;
        }

        if is_shared {
            // Hold strong handles to every still-live thread while shared; forget released ones.
            self.text_threads.retain_mut(|text_thread| {
                if let Some(strong_context) = text_thread.upgrade() {
                    *text_thread = TextThreadHandle::Strong(strong_context);
                    true
                } else {
                    false
                }
            });
            // NOTE(review): panics if a shared project has no remote id; this relies on
            // `is_shared()` implying `remote_id().is_some()`.
            let remote_id = project.read(cx).remote_id().unwrap();
            self.client_subscription = self
                .client
                .subscribe_to_entity(remote_id)
                .log_err()
                .map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async()));
            self.advertise_contexts(cx);
        } else {
            self.client_subscription = None;
        }
    }
327
328 fn handle_project_event(
329 &mut self,
330 _project: Entity<Project>,
331 event: &project::Event,
332 cx: &mut Context<Self>,
333 ) {
334 match event {
335 project::Event::RemoteIdChanged(_) => {
336 self.handle_project_shared(cx);
337 }
338 project::Event::Reshared => {
339 self.advertise_contexts(cx);
340 }
341 project::Event::HostReshared | project::Event::Rejoined => {
342 self.synchronize_contexts(cx);
343 }
344 project::Event::DisconnectedFromHost => {
345 self.text_threads.retain_mut(|text_thread| {
346 if let Some(strong_context) = text_thread.upgrade() {
347 *text_thread = TextThreadHandle::Weak(text_thread.downgrade());
348 strong_context.update(cx, |text_thread, cx| {
349 if text_thread.replica_id() != ReplicaId::default() {
350 text_thread.set_capability(language::Capability::ReadOnly, cx);
351 }
352 });
353 true
354 } else {
355 false
356 }
357 });
358 self.host_text_threads.clear();
359 cx.notify();
360 }
361 _ => {}
362 }
363 }
364
    /// Returns metadata for the saved text threads discovered on disk.
    ///
    /// As the name says, callers should not depend on the order; in practice
    /// [`Self::reload`] leaves the list sorted by modification time, newest first.
    pub fn unordered_text_threads(&self) -> impl Iterator<Item = &SavedTextThreadMetadata> {
        self.text_threads_metadata.iter()
    }

    /// Returns the text threads most recently advertised by the project host. The list is
    /// cleared when the store sees `project::Event::DisconnectedFromHost`.
    pub fn host_text_threads(&self) -> impl Iterator<Item = &RemoteTextThreadMetadata> {
        self.host_text_threads.iter()
    }
372
373 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<TextThread> {
374 let context = cx.new(|cx| {
375 TextThread::local(
376 self.languages.clone(),
377 Some(self.project.clone()),
378 self.prompt_builder.clone(),
379 self.slash_commands.clone(),
380 cx,
381 )
382 });
383 self.register_text_thread(&context, cx);
384 context
385 }
386
    /// Asks the host to create a new text thread and builds a local replica of it.
    ///
    /// Fails immediately if the project is gone or has no remote id. The replica uses this
    /// project's replica id and capability and has the host's operations applied. It is then
    /// registered (followed by a sync with the host), unless a thread with the same id was
    /// loaded in the meantime, in which case that existing thread is returned instead.
    pub fn create_remote(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<TextThread>>> {
        let Some(project) = self.project.upgrade() else {
            return Task::ready(Err(anyhow::anyhow!("project was dropped")));
        };
        let project = project.read(cx);
        let Some(project_id) = project.remote_id() else {
            return Task::ready(Err(anyhow::anyhow!("project was not remote")));
        };

        // Capture everything the replica needs before going async.
        let replica_id = project.replica_id();
        let capability = project.capability();
        let language_registry = self.languages.clone();
        let project = self.project.clone();

        let prompt_builder = self.prompt_builder.clone();
        let slash_commands = self.slash_commands.clone();
        let request = self.client.request(proto::CreateContext { project_id });
        cx.spawn(async move |this, cx| {
            let response = request.await?;
            let context_id = TextThreadId::from_proto(response.context_id);
            let context_proto = response.context.context("invalid context")?;
            let text_thread = cx.new(|cx| {
                TextThread::new(
                    context_id.clone(),
                    replica_id,
                    capability,
                    language_registry,
                    prompt_builder,
                    slash_commands,
                    Some(project),
                    cx,
                )
            })?;
            // Decode the host's operations on a background task; one malformed operation fails
            // the whole request.
            let operations = cx
                .background_spawn(async move {
                    context_proto
                        .operations
                        .into_iter()
                        .map(TextThreadOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                })
                .await?;
            text_thread.update(cx, |context, cx| context.apply_ops(operations, cx))?;
            this.update(cx, |this, cx| {
                // Another task may have loaded this thread while we were waiting; prefer it.
                if let Some(existing_context) = this.loaded_text_thread_for_id(&context_id, cx) {
                    existing_context
                } else {
                    this.register_text_thread(&text_thread, cx);
                    this.synchronize_contexts(cx);
                    text_thread
                }
            })
        })
    }
441
    /// Opens the saved text thread stored at `path`.
    ///
    /// Returns an already-loaded thread for the same path immediately. Otherwise the file is
    /// read and parsed as a [`SavedTextThread`] on a background task and deserialized into a new
    /// thread; if a thread for the same path got loaded while this was in flight, that one is
    /// returned and the freshly deserialized thread is not registered.
    pub fn open_local(
        &mut self,
        path: Arc<Path>,
        cx: &Context<Self>,
    ) -> Task<Result<Entity<TextThread>>> {
        if let Some(existing_context) = self.loaded_text_thread_for_path(&path, cx) {
            return Task::ready(Ok(existing_context));
        }

        let fs = self.fs.clone();
        let languages = self.languages.clone();
        let project = self.project.clone();
        // Read and parse the JSON on a background task.
        let load = cx.background_spawn({
            let path = path.clone();
            async move {
                let saved_context = fs.load(&path).await?;
                SavedTextThread::from_json(&saved_context)
            }
        });
        let prompt_builder = self.prompt_builder.clone();
        let slash_commands = self.slash_commands.clone();

        cx.spawn(async move |this, cx| {
            let saved_context = load.await?;
            let context = cx.new(|cx| {
                TextThread::deserialize(
                    saved_context,
                    path.clone(),
                    languages,
                    prompt_builder,
                    slash_commands,
                    Some(project),
                    cx,
                )
            })?;
            this.update(cx, |this, cx| {
                // Re-check: another open of the same path may have finished first.
                if let Some(existing_context) = this.loaded_text_thread_for_path(&path, cx) {
                    existing_context
                } else {
                    this.register_text_thread(&context, cx);
                    context
                }
            })
        })
    }
487
488 pub fn delete_local(&mut self, path: Arc<Path>, cx: &mut Context<Self>) -> Task<Result<()>> {
489 let fs = self.fs.clone();
490
491 cx.spawn(async move |this, cx| {
492 fs.remove_file(
493 &path,
494 RemoveOptions {
495 recursive: false,
496 ignore_if_not_exists: true,
497 },
498 )
499 .await?;
500
501 this.update(cx, |this, cx| {
502 this.text_threads.retain(|text_thread| {
503 text_thread
504 .upgrade()
505 .and_then(|text_thread| text_thread.read(cx).path())
506 != Some(&path)
507 });
508 this.text_threads_metadata
509 .retain(|text_thread| text_thread.path.as_ref() != path.as_ref());
510 })?;
511
512 Ok(())
513 })
514 }
515
516 fn loaded_text_thread_for_path(&self, path: &Path, cx: &App) -> Option<Entity<TextThread>> {
517 self.text_threads.iter().find_map(|text_thread| {
518 let text_thread = text_thread.upgrade()?;
519 if text_thread.read(cx).path().map(Arc::as_ref) == Some(path) {
520 Some(text_thread)
521 } else {
522 None
523 }
524 })
525 }
526
527 pub fn loaded_text_thread_for_id(
528 &self,
529 id: &TextThreadId,
530 cx: &App,
531 ) -> Option<Entity<TextThread>> {
532 self.text_threads.iter().find_map(|text_thread| {
533 let text_thread = text_thread.upgrade()?;
534 if text_thread.read(cx).id() == id {
535 Some(text_thread)
536 } else {
537 None
538 }
539 })
540 }
541
    /// Opens the host's text thread `text_thread_id` as a local replica.
    ///
    /// Fails immediately if the project is gone or has no remote id, and returns an
    /// already-loaded thread with that id without contacting the host. Otherwise it requests the
    /// thread's operations from the host, builds a replica with this project's replica id and
    /// capability, and applies the operations. The replica is registered (followed by a sync)
    /// unless a thread with the same id was loaded in the meantime, in which case that
    /// existing thread is returned.
    pub fn open_remote(
        &mut self,
        text_thread_id: TextThreadId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<TextThread>>> {
        let Some(project) = self.project.upgrade() else {
            return Task::ready(Err(anyhow::anyhow!("project was dropped")));
        };
        let project = project.read(cx);
        let Some(project_id) = project.remote_id() else {
            return Task::ready(Err(anyhow::anyhow!("project was not remote")));
        };

        if let Some(context) = self.loaded_text_thread_for_id(&text_thread_id, cx) {
            return Task::ready(Ok(context));
        }

        let replica_id = project.replica_id();
        let capability = project.capability();
        let language_registry = self.languages.clone();
        let project = self.project.clone();
        let request = self.client.request(proto::OpenContext {
            project_id,
            context_id: text_thread_id.to_proto(),
        });
        let prompt_builder = self.prompt_builder.clone();
        let slash_commands = self.slash_commands.clone();
        cx.spawn(async move |this, cx| {
            let response = request.await?;
            let context_proto = response.context.context("invalid context")?;
            let text_thread = cx.new(|cx| {
                TextThread::new(
                    text_thread_id.clone(),
                    replica_id,
                    capability,
                    language_registry,
                    prompt_builder,
                    slash_commands,
                    Some(project),
                    cx,
                )
            })?;
            // Decode the host's operations on a background task; one malformed operation fails
            // the whole open.
            let operations = cx
                .background_spawn(async move {
                    context_proto
                        .operations
                        .into_iter()
                        .map(TextThreadOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                })
                .await?;
            text_thread.update(cx, |context, cx| context.apply_ops(operations, cx))?;
            this.update(cx, |this, cx| {
                // Another open of the same id may have completed while this one was in flight.
                if let Some(existing_context) = this.loaded_text_thread_for_id(&text_thread_id, cx)
                {
                    existing_context
                } else {
                    this.register_text_thread(&text_thread, cx);
                    this.synchronize_contexts(cx);
                    text_thread
                }
            })
        })
    }
606
607 fn register_text_thread(&mut self, text_thread: &Entity<TextThread>, cx: &mut Context<Self>) {
608 let handle = if self.project_is_shared {
609 TextThreadHandle::Strong(text_thread.clone())
610 } else {
611 TextThreadHandle::Weak(text_thread.downgrade())
612 };
613 self.text_threads.push(handle);
614 self.advertise_contexts(cx);
615 cx.subscribe(text_thread, Self::handle_context_event)
616 .detach();
617 }
618
619 fn handle_context_event(
620 &mut self,
621 text_thread: Entity<TextThread>,
622 event: &TextThreadEvent,
623 cx: &mut Context<Self>,
624 ) {
625 let Some(project) = self.project.upgrade() else {
626 return;
627 };
628 let Some(project_id) = project.read(cx).remote_id() else {
629 return;
630 };
631
632 match event {
633 TextThreadEvent::SummaryChanged => {
634 self.advertise_contexts(cx);
635 }
636 TextThreadEvent::PathChanged { old_path, new_path } => {
637 if let Some(old_path) = old_path.as_ref() {
638 for metadata in &mut self.text_threads_metadata {
639 if &metadata.path == old_path {
640 metadata.path = new_path.clone();
641 break;
642 }
643 }
644 }
645 }
646 TextThreadEvent::Operation(operation) => {
647 let context_id = text_thread.read(cx).id().to_proto();
648 let operation = operation.to_proto();
649 self.client
650 .send(proto::UpdateContext {
651 project_id,
652 context_id,
653 operation: Some(operation),
654 })
655 .log_err();
656 }
657 _ => {}
658 }
659 }
660
661 fn advertise_contexts(&self, cx: &App) {
662 let Some(project) = self.project.upgrade() else {
663 return;
664 };
665 let Some(project_id) = project.read(cx).remote_id() else {
666 return;
667 };
668 // For now, only the host can advertise their open contexts.
669 if project.read(cx).is_via_collab() {
670 return;
671 }
672
673 let contexts = self
674 .text_threads
675 .iter()
676 .rev()
677 .filter_map(|text_thread| {
678 let text_thread = text_thread.upgrade()?.read(cx);
679 if text_thread.replica_id() == ReplicaId::default() {
680 Some(proto::ContextMetadata {
681 context_id: text_thread.id().to_proto(),
682 summary: text_thread
683 .summary()
684 .content()
685 .map(|summary| summary.text.clone()),
686 })
687 } else {
688 None
689 }
690 })
691 .collect();
692 self.client
693 .send(proto::AdvertiseContexts {
694 project_id,
695 contexts,
696 })
697 .ok();
698 }
699
    /// Exchanges thread versions with the host for this replica's non-host-owned threads.
    ///
    /// Sends the version of every loaded thread whose replica id is not the default one in a
    /// `proto::SynchronizeContexts` request. For each version in the host's reply that matches a
    /// loaded thread, the operations serialized relative to that version are sent back as
    /// individual `proto::UpdateContext` messages. Does nothing when the project is gone or has
    /// no remote id; errors are logged.
    fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };
        let Some(project_id) = project.read(cx).remote_id() else {
            return;
        };

        let text_threads = self
            .text_threads
            .iter()
            .filter_map(|text_thread| {
                let text_thread = text_thread.upgrade()?.read(cx);
                if text_thread.replica_id() != ReplicaId::default() {
                    Some(text_thread.version(cx).to_proto(text_thread.id().clone()))
                } else {
                    None
                }
            })
            .collect();

        let client = self.client.clone();
        let request = self.client.request(proto::SynchronizeContexts {
            project_id,
            contexts: text_threads,
        });
        cx.spawn(async move |this, cx| {
            let response = request.await?;

            // Kept in lockstep: `text_thread_ids[i]` identifies the thread that `operations[i]`
            // was serialized from.
            let mut text_thread_ids = Vec::new();
            let mut operations = Vec::new();
            this.read_with(cx, |this, cx| {
                for context_version_proto in response.contexts {
                    let text_thread_version = TextThreadVersion::from_proto(&context_version_proto);
                    let text_thread_id = TextThreadId::from_proto(context_version_proto.context_id);
                    if let Some(text_thread) = this.loaded_text_thread_for_id(&text_thread_id, cx) {
                        text_thread_ids.push(text_thread_id);
                        operations
                            .push(text_thread.read(cx).serialize_ops(&text_thread_version, cx));
                    }
                }
            })?;

            let operations = futures::future::join_all(operations).await;
            for (context_id, operations) in text_thread_ids.into_iter().zip(operations) {
                for operation in operations {
                    client.send(proto::UpdateContext {
                        project_id,
                        context_id: context_id.to_proto(),
                        operation: Some(operation),
                    })?;
                }
            }

            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
758
759 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedTextThreadMetadata>> {
760 let metadata = self.text_threads_metadata.clone();
761 let executor = cx.background_executor().clone();
762 cx.background_spawn(async move {
763 if query.is_empty() {
764 metadata
765 } else {
766 let candidates = metadata
767 .iter()
768 .enumerate()
769 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
770 .collect::<Vec<_>>();
771 let matches = fuzzy::match_strings(
772 &candidates,
773 &query,
774 false,
775 true,
776 100,
777 &Default::default(),
778 executor,
779 )
780 .await;
781
782 matches
783 .into_iter()
784 .map(|mat| metadata[mat.candidate_id].clone())
785 .collect()
786 }
787 })
788 }
789
790 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
791 let fs = self.fs.clone();
792 cx.spawn(async move |this, cx| {
793 if *ZED_STATELESS {
794 return Ok(());
795 }
796 fs.create_dir(text_threads_dir()).await?;
797
798 let mut paths = fs.read_dir(text_threads_dir()).await?;
799 let mut contexts = Vec::<SavedTextThreadMetadata>::new();
800 while let Some(path) = paths.next().await {
801 let path = path?;
802 if path.extension() != Some(OsStr::new("json")) {
803 continue;
804 }
805
806 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
807 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
808
809 let metadata = fs.metadata(&path).await?;
810 if let Some((file_name, metadata)) = path
811 .file_name()
812 .and_then(|name| name.to_str())
813 .zip(metadata)
814 {
815 // This is used to filter out contexts saved by the new assistant.
816 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
817 continue;
818 }
819
820 if let Some(title) = ASSISTANT_CONTEXT_REGEX
821 .replace(file_name, "")
822 .lines()
823 .next()
824 {
825 contexts.push(SavedTextThreadMetadata {
826 title: title.to_string().into(),
827 path: path.into(),
828 mtime: metadata.mtime.timestamp_for_user().into(),
829 });
830 }
831 }
832 }
833 contexts.sort_unstable_by_key(|text_thread| Reverse(text_thread.mtime));
834
835 this.update(cx, |this, cx| {
836 this.text_threads_metadata = contexts;
837 cx.notify();
838 })
839 })
840 }
841
    /// Subscribes to the project's context server store so status changes reach
    /// [`Self::handle_context_server_event`], then loads slash commands for servers that are
    /// already running. Does nothing if the project has been dropped.
    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };
        let context_server_store = project.read(cx).context_server_store();
        cx.subscribe(&context_server_store, Self::handle_context_server_event)
            .detach();

        // Check for any servers that were already running before the handler was registered
        for server in context_server_store.read(cx).running_servers() {
            self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
        }
    }
855
856 fn handle_context_server_event(
857 &mut self,
858 context_server_store: Entity<ContextServerStore>,
859 event: &project::context_server_store::Event,
860 cx: &mut Context<Self>,
861 ) {
862 match event {
863 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
864 match status {
865 ContextServerStatus::Running => {
866 self.load_context_server_slash_commands(
867 server_id.clone(),
868 context_server_store,
869 cx,
870 );
871 }
872 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
873 if let Some(slash_command_ids) =
874 self.context_server_slash_command_ids.remove(server_id)
875 {
876 self.slash_commands.remove(&slash_command_ids);
877 }
878 }
879 _ => {}
880 }
881 }
882 }
883 }
884
885 fn load_context_server_slash_commands(
886 &self,
887 server_id: ContextServerId,
888 context_server_store: Entity<ContextServerStore>,
889 cx: &mut Context<Self>,
890 ) {
891 let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
892 return;
893 };
894 let slash_command_working_set = self.slash_commands.clone();
895 cx.spawn(async move |this, cx| {
896 let Some(protocol) = server.client() else {
897 return;
898 };
899
900 if protocol.capable(context_server::protocol::ServerCapability::Prompts)
901 && let Some(response) = protocol
902 .request::<context_server::types::requests::PromptsList>(())
903 .await
904 .log_err()
905 {
906 let slash_command_ids = response
907 .prompts
908 .into_iter()
909 .filter(assistant_slash_commands::acceptable_prompt)
910 .map(|prompt| {
911 log::info!("registering context server command: {:?}", prompt.name);
912 slash_command_working_set.insert(Arc::new(
913 assistant_slash_commands::ContextServerSlashCommand::new(
914 context_server_store.clone(),
915 server.id(),
916 prompt,
917 ),
918 ))
919 })
920 .collect::<Vec<_>>();
921
922 this.update(cx, |this, _cx| {
923 this.context_server_slash_command_ids
924 .insert(server_id.clone(), slash_command_ids);
925 })
926 .log_err();
927 }
928 })
929 .detach();
930 }
931}