1use crate::{
2 SavedTextThread, SavedTextThreadMetadata, TextThread, TextThreadEvent, TextThreadId,
3 TextThreadOperation, TextThreadVersion,
4};
5use anyhow::{Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use client::{Client, TypedEnvelope, proto, telemetry::Telemetry};
8use clock::ReplicaId;
9use collections::HashMap;
10use context_server::ContextServerId;
11use fs::{Fs, RemoveOptions};
12use futures::StreamExt;
13use fuzzy::StringMatchCandidate;
14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
15use language::LanguageRegistry;
16use paths::text_threads_dir;
17use project::{
18 Project,
19 context_server_store::{ContextServerStatus, ContextServerStore},
20};
21use prompt_store::PromptBuilder;
22use regex::Regex;
23use rpc::AnyProtoClient;
24use std::sync::LazyLock;
25use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
26use util::{ResultExt, TryFutureExt};
27use zed_env_vars::ZED_STATELESS;
28
29pub(crate) fn init(client: &AnyProtoClient) {
30 client.add_entity_message_handler(TextThreadStore::handle_advertise_contexts);
31 client.add_entity_request_handler(TextThreadStore::handle_open_context);
32 client.add_entity_request_handler(TextThreadStore::handle_create_context);
33 client.add_entity_message_handler(TextThreadStore::handle_update_context);
34 client.add_entity_request_handler(TextThreadStore::handle_synchronize_contexts);
35}
36
/// Metadata describing a text thread that was advertised by the project host
/// (see `TextThreadStore::handle_advertise_contexts`).
#[derive(Clone)]
pub struct RemoteTextThreadMetadata {
    /// Identifier of the advertised text thread.
    pub id: TextThreadId,
    /// Summary text of the thread, if the host provided one.
    pub summary: Option<String>,
}
42
/// Tracks the text threads of a single project: the threads currently loaded
/// in memory, the metadata of threads saved on disk, and the threads
/// advertised by a remote host.
pub struct TextThreadStore {
    /// Handles to every registered thread; strong while the project is
    /// shared, weak otherwise (see `register_text_thread`).
    text_threads: Vec<TextThreadHandle>,
    /// Metadata of the saved threads found in `text_threads_dir()`, rebuilt
    /// by `reload`.
    text_threads_metadata: Vec<SavedTextThreadMetadata>,
    /// Slash commands registered on behalf of each context server, kept so
    /// they can be removed again when the server stops or errors.
    context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
    /// Threads most recently advertised by the host.
    host_text_threads: Vec<RemoteTextThreadMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    telemetry: Arc<Telemetry>,
    /// Task that calls `reload` for every event from the directory watcher.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Entity<Project>,
    /// Last observed value of `project.is_shared()`, updated by
    /// `handle_project_shared`.
    project_is_shared: bool,
    /// Client entity subscription, present only while the project is shared.
    client_subscription: Option<client::Subscription>,
    /// Holds the subscription to project events.
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
60
/// A handle to a registered text thread that either keeps the thread alive
/// (`Strong`) or merely observes it (`Weak`).
enum TextThreadHandle {
    /// Does not keep the thread alive; `upgrade` may return `None`.
    Weak(WeakEntity<TextThread>),
    /// Keeps the thread alive for as long as the handle exists.
    Strong(Entity<TextThread>),
}
65
66impl TextThreadHandle {
67 fn upgrade(&self) -> Option<Entity<TextThread>> {
68 match self {
69 TextThreadHandle::Weak(weak) => weak.upgrade(),
70 TextThreadHandle::Strong(strong) => Some(strong.clone()),
71 }
72 }
73
74 fn downgrade(&self) -> WeakEntity<TextThread> {
75 match self {
76 TextThreadHandle::Weak(weak) => weak.clone(),
77 TextThreadHandle::Strong(strong) => strong.downgrade(),
78 }
79 }
80}
81
82impl TextThreadStore {
83 pub fn new(
84 project: Entity<Project>,
85 prompt_builder: Arc<PromptBuilder>,
86 slash_commands: Arc<SlashCommandWorkingSet>,
87 cx: &mut App,
88 ) -> Task<Result<Entity<Self>>> {
89 let fs = project.read(cx).fs().clone();
90 let languages = project.read(cx).languages().clone();
91 let telemetry = project.read(cx).client().telemetry().clone();
92 cx.spawn(async move |cx| {
93 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
94 let (mut events, _) = fs.watch(text_threads_dir(), CONTEXT_WATCH_DURATION).await;
95
96 let this = cx.new(|cx: &mut Context<Self>| {
97 let mut this = Self {
98 text_threads: Vec::new(),
99 text_threads_metadata: Vec::new(),
100 context_server_slash_command_ids: HashMap::default(),
101 host_text_threads: Vec::new(),
102 fs,
103 languages,
104 slash_commands,
105 telemetry,
106 _watch_updates: cx.spawn(async move |this, cx| {
107 async move {
108 while events.next().await.is_some() {
109 this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
110 }
111 anyhow::Ok(())
112 }
113 .log_err()
114 .await
115 }),
116 client_subscription: None,
117 _project_subscriptions: vec![
118 cx.subscribe(&project, Self::handle_project_event),
119 ],
120 project_is_shared: false,
121 client: project.read(cx).client(),
122 project: project.clone(),
123 prompt_builder,
124 };
125 this.handle_project_shared(project.clone(), cx);
126 this.synchronize_contexts(cx);
127 this.register_context_server_handlers(cx);
128 this.reload(cx).detach_and_log_err(cx);
129 this
130 })?;
131
132 Ok(this)
133 })
134 }
135
/// Builds a store for tests: no directory watcher is started and no project
/// events are subscribed to.
#[cfg(any(test, feature = "test-support"))]
pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
    let (fs, languages, client) = {
        let project = project.read(cx);
        (
            project.fs().clone(),
            project.languages().clone(),
            project.client(),
        )
    };
    let telemetry = client.telemetry().clone();

    Self {
        text_threads: Vec::new(),
        text_threads_metadata: Vec::new(),
        context_server_slash_command_ids: HashMap::default(),
        host_text_threads: Vec::new(),
        fs,
        languages,
        slash_commands: Arc::default(),
        telemetry,
        _watch_updates: Task::ready(None),
        client,
        project,
        project_is_shared: false,
        client_subscription: None,
        _project_subscriptions: Vec::new(),
        prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
    }
}
156
157 async fn handle_advertise_contexts(
158 this: Entity<Self>,
159 envelope: TypedEnvelope<proto::AdvertiseContexts>,
160 mut cx: AsyncApp,
161 ) -> Result<()> {
162 this.update(&mut cx, |this, cx| {
163 this.host_text_threads = envelope
164 .payload
165 .contexts
166 .into_iter()
167 .map(|text_thread| RemoteTextThreadMetadata {
168 id: TextThreadId::from_proto(text_thread.context_id),
169 summary: text_thread.summary,
170 })
171 .collect();
172 cx.notify();
173 })
174 }
175
176 async fn handle_open_context(
177 this: Entity<Self>,
178 envelope: TypedEnvelope<proto::OpenContext>,
179 mut cx: AsyncApp,
180 ) -> Result<proto::OpenContextResponse> {
181 let context_id = TextThreadId::from_proto(envelope.payload.context_id);
182 let operations = this.update(&mut cx, |this, cx| {
183 anyhow::ensure!(
184 !this.project.read(cx).is_via_collab(),
185 "only the host contexts can be opened"
186 );
187
188 let text_thread = this
189 .loaded_text_thread_for_id(&context_id, cx)
190 .context("context not found")?;
191 anyhow::ensure!(
192 text_thread.read(cx).replica_id() == ReplicaId::default(),
193 "context must be opened via the host"
194 );
195
196 anyhow::Ok(
197 text_thread
198 .read(cx)
199 .serialize_ops(&TextThreadVersion::default(), cx),
200 )
201 })??;
202 let operations = operations.await;
203 Ok(proto::OpenContextResponse {
204 context: Some(proto::Context { operations }),
205 })
206 }
207
208 async fn handle_create_context(
209 this: Entity<Self>,
210 _: TypedEnvelope<proto::CreateContext>,
211 mut cx: AsyncApp,
212 ) -> Result<proto::CreateContextResponse> {
213 let (context_id, operations) = this.update(&mut cx, |this, cx| {
214 anyhow::ensure!(
215 !this.project.read(cx).is_via_collab(),
216 "can only create contexts as the host"
217 );
218
219 let text_thread = this.create(cx);
220 let context_id = text_thread.read(cx).id().clone();
221
222 anyhow::Ok((
223 context_id,
224 text_thread
225 .read(cx)
226 .serialize_ops(&TextThreadVersion::default(), cx),
227 ))
228 })??;
229 let operations = operations.await;
230 Ok(proto::CreateContextResponse {
231 context_id: context_id.to_proto(),
232 context: Some(proto::Context { operations }),
233 })
234 }
235
236 async fn handle_update_context(
237 this: Entity<Self>,
238 envelope: TypedEnvelope<proto::UpdateContext>,
239 mut cx: AsyncApp,
240 ) -> Result<()> {
241 this.update(&mut cx, |this, cx| {
242 let context_id = TextThreadId::from_proto(envelope.payload.context_id);
243 if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
244 let operation_proto = envelope.payload.operation.context("invalid operation")?;
245 let operation = TextThreadOperation::from_proto(operation_proto)?;
246 text_thread.update(cx, |text_thread, cx| text_thread.apply_ops([operation], cx));
247 }
248 Ok(())
249 })?
250 }
251
252 async fn handle_synchronize_contexts(
253 this: Entity<Self>,
254 envelope: TypedEnvelope<proto::SynchronizeContexts>,
255 mut cx: AsyncApp,
256 ) -> Result<proto::SynchronizeContextsResponse> {
257 this.update(&mut cx, |this, cx| {
258 anyhow::ensure!(
259 !this.project.read(cx).is_via_collab(),
260 "only the host can synchronize contexts"
261 );
262
263 let mut local_versions = Vec::new();
264 for remote_version_proto in envelope.payload.contexts {
265 let remote_version = TextThreadVersion::from_proto(&remote_version_proto);
266 let context_id = TextThreadId::from_proto(remote_version_proto.context_id);
267 if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
268 let text_thread = text_thread.read(cx);
269 let operations = text_thread.serialize_ops(&remote_version, cx);
270 local_versions.push(text_thread.version(cx).to_proto(context_id.clone()));
271 let client = this.client.clone();
272 let project_id = envelope.payload.project_id;
273 cx.background_spawn(async move {
274 let operations = operations.await;
275 for operation in operations {
276 client.send(proto::UpdateContext {
277 project_id,
278 context_id: context_id.to_proto(),
279 operation: Some(operation),
280 })?;
281 }
282 anyhow::Ok(())
283 })
284 .detach_and_log_err(cx);
285 }
286 }
287
288 this.advertise_contexts(cx);
289
290 anyhow::Ok(proto::SynchronizeContextsResponse {
291 contexts: local_versions,
292 })
293 })?
294 }
295
296 fn handle_project_shared(&mut self, _: Entity<Project>, cx: &mut Context<Self>) {
297 let is_shared = self.project.read(cx).is_shared();
298 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
299 if is_shared == was_shared {
300 return;
301 }
302
303 if is_shared {
304 self.text_threads.retain_mut(|text_thread| {
305 if let Some(strong_context) = text_thread.upgrade() {
306 *text_thread = TextThreadHandle::Strong(strong_context);
307 true
308 } else {
309 false
310 }
311 });
312 let remote_id = self.project.read(cx).remote_id().unwrap();
313 self.client_subscription = self
314 .client
315 .subscribe_to_entity(remote_id)
316 .log_err()
317 .map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async()));
318 self.advertise_contexts(cx);
319 } else {
320 self.client_subscription = None;
321 }
322 }
323
324 fn handle_project_event(
325 &mut self,
326 project: Entity<Project>,
327 event: &project::Event,
328 cx: &mut Context<Self>,
329 ) {
330 match event {
331 project::Event::RemoteIdChanged(_) => {
332 self.handle_project_shared(project, cx);
333 }
334 project::Event::Reshared => {
335 self.advertise_contexts(cx);
336 }
337 project::Event::HostReshared | project::Event::Rejoined => {
338 self.synchronize_contexts(cx);
339 }
340 project::Event::DisconnectedFromHost => {
341 self.text_threads.retain_mut(|text_thread| {
342 if let Some(strong_context) = text_thread.upgrade() {
343 *text_thread = TextThreadHandle::Weak(text_thread.downgrade());
344 strong_context.update(cx, |text_thread, cx| {
345 if text_thread.replica_id() != ReplicaId::default() {
346 text_thread.set_capability(language::Capability::ReadOnly, cx);
347 }
348 });
349 true
350 } else {
351 false
352 }
353 });
354 self.host_text_threads.clear();
355 cx.notify();
356 }
357 _ => {}
358 }
359 }
360
/// Iterates over the metadata of the saved text threads in the order it is
/// currently stored, without sorting it first.
pub fn unordered_text_threads(&self) -> impl Iterator<Item = &SavedTextThreadMetadata> {
    self.text_threads_metadata.iter()
}
364
/// Iterates over the text threads most recently advertised by the host.
pub fn host_text_threads(&self) -> impl Iterator<Item = &RemoteTextThreadMetadata> {
    self.host_text_threads.iter()
}
368
369 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<TextThread> {
370 let context = cx.new(|cx| {
371 TextThread::local(
372 self.languages.clone(),
373 Some(self.project.clone()),
374 Some(self.telemetry.clone()),
375 self.prompt_builder.clone(),
376 self.slash_commands.clone(),
377 cx,
378 )
379 });
380 self.register_text_thread(&context, cx);
381 context
382 }
383
384 pub fn create_remote(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<TextThread>>> {
385 let project = self.project.read(cx);
386 let Some(project_id) = project.remote_id() else {
387 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
388 };
389
390 let replica_id = project.replica_id();
391 let capability = project.capability();
392 let language_registry = self.languages.clone();
393 let project = self.project.clone();
394 let telemetry = self.telemetry.clone();
395 let prompt_builder = self.prompt_builder.clone();
396 let slash_commands = self.slash_commands.clone();
397 let request = self.client.request(proto::CreateContext { project_id });
398 cx.spawn(async move |this, cx| {
399 let response = request.await?;
400 let context_id = TextThreadId::from_proto(response.context_id);
401 let context_proto = response.context.context("invalid context")?;
402 let text_thread = cx.new(|cx| {
403 TextThread::new(
404 context_id.clone(),
405 replica_id,
406 capability,
407 language_registry,
408 prompt_builder,
409 slash_commands,
410 Some(project),
411 Some(telemetry),
412 cx,
413 )
414 })?;
415 let operations = cx
416 .background_spawn(async move {
417 context_proto
418 .operations
419 .into_iter()
420 .map(TextThreadOperation::from_proto)
421 .collect::<Result<Vec<_>>>()
422 })
423 .await?;
424 text_thread.update(cx, |context, cx| context.apply_ops(operations, cx))?;
425 this.update(cx, |this, cx| {
426 if let Some(existing_context) = this.loaded_text_thread_for_id(&context_id, cx) {
427 existing_context
428 } else {
429 this.register_text_thread(&text_thread, cx);
430 this.synchronize_contexts(cx);
431 text_thread
432 }
433 })
434 })
435 }
436
437 pub fn open_local(
438 &mut self,
439 path: Arc<Path>,
440 cx: &Context<Self>,
441 ) -> Task<Result<Entity<TextThread>>> {
442 if let Some(existing_context) = self.loaded_text_thread_for_path(&path, cx) {
443 return Task::ready(Ok(existing_context));
444 }
445
446 let fs = self.fs.clone();
447 let languages = self.languages.clone();
448 let project = self.project.clone();
449 let telemetry = self.telemetry.clone();
450 let load = cx.background_spawn({
451 let path = path.clone();
452 async move {
453 let saved_context = fs.load(&path).await?;
454 SavedTextThread::from_json(&saved_context)
455 }
456 });
457 let prompt_builder = self.prompt_builder.clone();
458 let slash_commands = self.slash_commands.clone();
459
460 cx.spawn(async move |this, cx| {
461 let saved_context = load.await?;
462 let context = cx.new(|cx| {
463 TextThread::deserialize(
464 saved_context,
465 path.clone(),
466 languages,
467 prompt_builder,
468 slash_commands,
469 Some(project),
470 Some(telemetry),
471 cx,
472 )
473 })?;
474 this.update(cx, |this, cx| {
475 if let Some(existing_context) = this.loaded_text_thread_for_path(&path, cx) {
476 existing_context
477 } else {
478 this.register_text_thread(&context, cx);
479 context
480 }
481 })
482 })
483 }
484
485 pub fn delete_local(&mut self, path: Arc<Path>, cx: &mut Context<Self>) -> Task<Result<()>> {
486 let fs = self.fs.clone();
487
488 cx.spawn(async move |this, cx| {
489 fs.remove_file(
490 &path,
491 RemoveOptions {
492 recursive: false,
493 ignore_if_not_exists: true,
494 },
495 )
496 .await?;
497
498 this.update(cx, |this, cx| {
499 this.text_threads.retain(|text_thread| {
500 text_thread
501 .upgrade()
502 .and_then(|text_thread| text_thread.read(cx).path())
503 != Some(&path)
504 });
505 this.text_threads_metadata
506 .retain(|text_thread| text_thread.path.as_ref() != path.as_ref());
507 })?;
508
509 Ok(())
510 })
511 }
512
513 fn loaded_text_thread_for_path(&self, path: &Path, cx: &App) -> Option<Entity<TextThread>> {
514 self.text_threads.iter().find_map(|text_thread| {
515 let text_thread = text_thread.upgrade()?;
516 if text_thread.read(cx).path().map(Arc::as_ref) == Some(path) {
517 Some(text_thread)
518 } else {
519 None
520 }
521 })
522 }
523
524 pub fn loaded_text_thread_for_id(
525 &self,
526 id: &TextThreadId,
527 cx: &App,
528 ) -> Option<Entity<TextThread>> {
529 self.text_threads.iter().find_map(|text_thread| {
530 let text_thread = text_thread.upgrade()?;
531 if text_thread.read(cx).id() == id {
532 Some(text_thread)
533 } else {
534 None
535 }
536 })
537 }
538
539 pub fn open_remote(
540 &mut self,
541 text_thread_id: TextThreadId,
542 cx: &mut Context<Self>,
543 ) -> Task<Result<Entity<TextThread>>> {
544 let project = self.project.read(cx);
545 let Some(project_id) = project.remote_id() else {
546 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
547 };
548
549 if let Some(context) = self.loaded_text_thread_for_id(&text_thread_id, cx) {
550 return Task::ready(Ok(context));
551 }
552
553 let replica_id = project.replica_id();
554 let capability = project.capability();
555 let language_registry = self.languages.clone();
556 let project = self.project.clone();
557 let telemetry = self.telemetry.clone();
558 let request = self.client.request(proto::OpenContext {
559 project_id,
560 context_id: text_thread_id.to_proto(),
561 });
562 let prompt_builder = self.prompt_builder.clone();
563 let slash_commands = self.slash_commands.clone();
564 cx.spawn(async move |this, cx| {
565 let response = request.await?;
566 let context_proto = response.context.context("invalid context")?;
567 let text_thread = cx.new(|cx| {
568 TextThread::new(
569 text_thread_id.clone(),
570 replica_id,
571 capability,
572 language_registry,
573 prompt_builder,
574 slash_commands,
575 Some(project),
576 Some(telemetry),
577 cx,
578 )
579 })?;
580 let operations = cx
581 .background_spawn(async move {
582 context_proto
583 .operations
584 .into_iter()
585 .map(TextThreadOperation::from_proto)
586 .collect::<Result<Vec<_>>>()
587 })
588 .await?;
589 text_thread.update(cx, |context, cx| context.apply_ops(operations, cx))?;
590 this.update(cx, |this, cx| {
591 if let Some(existing_context) = this.loaded_text_thread_for_id(&text_thread_id, cx)
592 {
593 existing_context
594 } else {
595 this.register_text_thread(&text_thread, cx);
596 this.synchronize_contexts(cx);
597 text_thread
598 }
599 })
600 })
601 }
602
603 fn register_text_thread(&mut self, text_thread: &Entity<TextThread>, cx: &mut Context<Self>) {
604 let handle = if self.project_is_shared {
605 TextThreadHandle::Strong(text_thread.clone())
606 } else {
607 TextThreadHandle::Weak(text_thread.downgrade())
608 };
609 self.text_threads.push(handle);
610 self.advertise_contexts(cx);
611 cx.subscribe(text_thread, Self::handle_context_event)
612 .detach();
613 }
614
615 fn handle_context_event(
616 &mut self,
617 text_thread: Entity<TextThread>,
618 event: &TextThreadEvent,
619 cx: &mut Context<Self>,
620 ) {
621 let Some(project_id) = self.project.read(cx).remote_id() else {
622 return;
623 };
624
625 match event {
626 TextThreadEvent::SummaryChanged => {
627 self.advertise_contexts(cx);
628 }
629 TextThreadEvent::PathChanged { old_path, new_path } => {
630 if let Some(old_path) = old_path.as_ref() {
631 for metadata in &mut self.text_threads_metadata {
632 if &metadata.path == old_path {
633 metadata.path = new_path.clone();
634 break;
635 }
636 }
637 }
638 }
639 TextThreadEvent::Operation(operation) => {
640 let context_id = text_thread.read(cx).id().to_proto();
641 let operation = operation.to_proto();
642 self.client
643 .send(proto::UpdateContext {
644 project_id,
645 context_id,
646 operation: Some(operation),
647 })
648 .log_err();
649 }
650 _ => {}
651 }
652 }
653
654 fn advertise_contexts(&self, cx: &App) {
655 let Some(project_id) = self.project.read(cx).remote_id() else {
656 return;
657 };
658
659 // For now, only the host can advertise their open contexts.
660 if self.project.read(cx).is_via_collab() {
661 return;
662 }
663
664 let contexts = self
665 .text_threads
666 .iter()
667 .rev()
668 .filter_map(|text_thread| {
669 let text_thread = text_thread.upgrade()?.read(cx);
670 if text_thread.replica_id() == ReplicaId::default() {
671 Some(proto::ContextMetadata {
672 context_id: text_thread.id().to_proto(),
673 summary: text_thread
674 .summary()
675 .content()
676 .map(|summary| summary.text.clone()),
677 })
678 } else {
679 None
680 }
681 })
682 .collect();
683 self.client
684 .send(proto::AdvertiseContexts {
685 project_id,
686 contexts,
687 })
688 .ok();
689 }
690
691 fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
692 let Some(project_id) = self.project.read(cx).remote_id() else {
693 return;
694 };
695
696 let text_threads = self
697 .text_threads
698 .iter()
699 .filter_map(|text_thread| {
700 let text_thread = text_thread.upgrade()?.read(cx);
701 if text_thread.replica_id() != ReplicaId::default() {
702 Some(text_thread.version(cx).to_proto(text_thread.id().clone()))
703 } else {
704 None
705 }
706 })
707 .collect();
708
709 let client = self.client.clone();
710 let request = self.client.request(proto::SynchronizeContexts {
711 project_id,
712 contexts: text_threads,
713 });
714 cx.spawn(async move |this, cx| {
715 let response = request.await?;
716
717 let mut text_thread_ids = Vec::new();
718 let mut operations = Vec::new();
719 this.read_with(cx, |this, cx| {
720 for context_version_proto in response.contexts {
721 let text_thread_version = TextThreadVersion::from_proto(&context_version_proto);
722 let text_thread_id = TextThreadId::from_proto(context_version_proto.context_id);
723 if let Some(text_thread) = this.loaded_text_thread_for_id(&text_thread_id, cx) {
724 text_thread_ids.push(text_thread_id);
725 operations
726 .push(text_thread.read(cx).serialize_ops(&text_thread_version, cx));
727 }
728 }
729 })?;
730
731 let operations = futures::future::join_all(operations).await;
732 for (context_id, operations) in text_thread_ids.into_iter().zip(operations) {
733 for operation in operations {
734 client.send(proto::UpdateContext {
735 project_id,
736 context_id: context_id.to_proto(),
737 operation: Some(operation),
738 })?;
739 }
740 }
741
742 anyhow::Ok(())
743 })
744 .detach_and_log_err(cx);
745 }
746
747 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedTextThreadMetadata>> {
748 let metadata = self.text_threads_metadata.clone();
749 let executor = cx.background_executor().clone();
750 cx.background_spawn(async move {
751 if query.is_empty() {
752 metadata
753 } else {
754 let candidates = metadata
755 .iter()
756 .enumerate()
757 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
758 .collect::<Vec<_>>();
759 let matches = fuzzy::match_strings(
760 &candidates,
761 &query,
762 false,
763 true,
764 100,
765 &Default::default(),
766 executor,
767 )
768 .await;
769
770 matches
771 .into_iter()
772 .map(|mat| metadata[mat.candidate_id].clone())
773 .collect()
774 }
775 })
776 }
777
778 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
779 let fs = self.fs.clone();
780 cx.spawn(async move |this, cx| {
781 if *ZED_STATELESS {
782 return Ok(());
783 }
784 fs.create_dir(text_threads_dir()).await?;
785
786 let mut paths = fs.read_dir(text_threads_dir()).await?;
787 let mut contexts = Vec::<SavedTextThreadMetadata>::new();
788 while let Some(path) = paths.next().await {
789 let path = path?;
790 if path.extension() != Some(OsStr::new("json")) {
791 continue;
792 }
793
794 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
795 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
796
797 let metadata = fs.metadata(&path).await?;
798 if let Some((file_name, metadata)) = path
799 .file_name()
800 .and_then(|name| name.to_str())
801 .zip(metadata)
802 {
803 // This is used to filter out contexts saved by the new assistant.
804 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
805 continue;
806 }
807
808 if let Some(title) = ASSISTANT_CONTEXT_REGEX
809 .replace(file_name, "")
810 .lines()
811 .next()
812 {
813 contexts.push(SavedTextThreadMetadata {
814 title: title.to_string().into(),
815 path: path.into(),
816 mtime: metadata.mtime.timestamp_for_user().into(),
817 });
818 }
819 }
820 }
821 contexts.sort_unstable_by_key(|text_thread| Reverse(text_thread.mtime));
822
823 this.update(cx, |this, cx| {
824 this.text_threads_metadata = contexts;
825 cx.notify();
826 })
827 })
828 }
829
830 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
831 let context_server_store = self.project.read(cx).context_server_store();
832 cx.subscribe(&context_server_store, Self::handle_context_server_event)
833 .detach();
834
835 // Check for any servers that were already running before the handler was registered
836 for server in context_server_store.read(cx).running_servers() {
837 self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
838 }
839 }
840
841 fn handle_context_server_event(
842 &mut self,
843 context_server_store: Entity<ContextServerStore>,
844 event: &project::context_server_store::Event,
845 cx: &mut Context<Self>,
846 ) {
847 match event {
848 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
849 match status {
850 ContextServerStatus::Running => {
851 self.load_context_server_slash_commands(
852 server_id.clone(),
853 context_server_store,
854 cx,
855 );
856 }
857 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
858 if let Some(slash_command_ids) =
859 self.context_server_slash_command_ids.remove(server_id)
860 {
861 self.slash_commands.remove(&slash_command_ids);
862 }
863 }
864 _ => {}
865 }
866 }
867 }
868 }
869
870 fn load_context_server_slash_commands(
871 &self,
872 server_id: ContextServerId,
873 context_server_store: Entity<ContextServerStore>,
874 cx: &mut Context<Self>,
875 ) {
876 let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
877 return;
878 };
879 let slash_command_working_set = self.slash_commands.clone();
880 cx.spawn(async move |this, cx| {
881 let Some(protocol) = server.client() else {
882 return;
883 };
884
885 if protocol.capable(context_server::protocol::ServerCapability::Prompts)
886 && let Some(response) = protocol
887 .request::<context_server::types::requests::PromptsList>(())
888 .await
889 .log_err()
890 {
891 let slash_command_ids = response
892 .prompts
893 .into_iter()
894 .filter(assistant_slash_commands::acceptable_prompt)
895 .map(|prompt| {
896 log::info!("registering context server command: {:?}", prompt.name);
897 slash_command_working_set.insert(Arc::new(
898 assistant_slash_commands::ContextServerSlashCommand::new(
899 context_server_store.clone(),
900 server.id(),
901 prompt,
902 ),
903 ))
904 })
905 .collect::<Vec<_>>();
906
907 this.update(cx, |this, _cx| {
908 this.context_server_slash_command_ids
909 .insert(server_id.clone(), slash_command_ids);
910 })
911 .log_err();
912 }
913 })
914 .detach();
915 }
916}