1use crate::{
2 SavedTextThread, SavedTextThreadMetadata, TextThread, TextThreadEvent, TextThreadId,
3 TextThreadOperation, TextThreadVersion,
4};
5use anyhow::{Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use client::{Client, TypedEnvelope, proto, telemetry::Telemetry};
8use clock::ReplicaId;
9use collections::HashMap;
10use context_server::ContextServerId;
11use fs::{Fs, RemoveOptions};
12use futures::StreamExt;
13use fuzzy::StringMatchCandidate;
14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
15use language::LanguageRegistry;
16use paths::text_threads_dir;
17use project::{
18 Project,
19 context_server_store::{ContextServerStatus, ContextServerStore},
20};
21use prompt_store::PromptBuilder;
22use regex::Regex;
23use rpc::AnyProtoClient;
24use std::sync::LazyLock;
25use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
26use util::{ResultExt, TryFutureExt};
27use zed_env_vars::ZED_STATELESS;
28
29pub(crate) fn init(client: &AnyProtoClient) {
30 client.add_entity_message_handler(TextThreadStore::handle_advertise_contexts);
31 client.add_entity_request_handler(TextThreadStore::handle_open_context);
32 client.add_entity_request_handler(TextThreadStore::handle_create_context);
33 client.add_entity_message_handler(TextThreadStore::handle_update_context);
34 client.add_entity_request_handler(TextThreadStore::handle_synchronize_contexts);
35}
36
/// Metadata describing a text thread that was advertised over RPC.
///
/// Instances are built from the entries of an `AdvertiseContexts` message.
#[derive(Clone)]
pub struct RemoteTextThreadMetadata {
    /// Identifier of the advertised text thread.
    pub id: TextThreadId,
    /// Summary text sent along with the advertisement, if any.
    pub summary: Option<String>,
}
42
/// Tracks the text threads that belong to a project: the ones currently loaded in memory,
/// the metadata of the ones saved on disk, and the ones advertised over RPC.
pub struct TextThreadStore {
    /// Handles to every registered text thread. New handles are stored as strong while
    /// `project_is_shared` is set and as weak otherwise.
    text_threads: Vec<TextThreadHandle>,
    /// Metadata of the saved text threads found in `text_threads_dir()`, rebuilt by `reload`.
    text_threads_metadata: Vec<SavedTextThreadMetadata>,
    /// Slash commands registered on behalf of each context server, kept so they can be
    /// removed again when the server stops or errors.
    context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
    /// Text threads received through `AdvertiseContexts`; cleared on `DisconnectedFromHost`.
    host_text_threads: Vec<RemoteTextThreadMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    telemetry: Arc<Telemetry>,
    /// Background task that calls `reload` for every file-system event on the watched directory.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: WeakEntity<Project>,
    /// Last value of `Project::is_shared` observed by `handle_project_shared`.
    project_is_shared: bool,
    /// Entity subscription for the project's remote id; set while the project is shared.
    client_subscription: Option<client::Subscription>,
    /// Keeps the project event subscription alive for as long as the store exists.
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
60
/// A handle to a registered text thread that may or may not keep the thread alive.
///
/// The store registers `Strong` handles while the project is shared and `Weak` handles
/// otherwise; existing handles are converted when the sharing state changes.
enum TextThreadHandle {
    /// Does not keep the text thread alive; upgrading may fail.
    Weak(WeakEntity<TextThread>),
    /// Keeps the text thread alive.
    Strong(Entity<TextThread>),
}
65
66impl TextThreadHandle {
67 fn upgrade(&self) -> Option<Entity<TextThread>> {
68 match self {
69 TextThreadHandle::Weak(weak) => weak.upgrade(),
70 TextThreadHandle::Strong(strong) => Some(strong.clone()),
71 }
72 }
73
74 fn downgrade(&self) -> WeakEntity<TextThread> {
75 match self {
76 TextThreadHandle::Weak(weak) => weak.clone(),
77 TextThreadHandle::Strong(strong) => strong.downgrade(),
78 }
79 }
80}
81
82impl TextThreadStore {
83 pub fn new(
84 project: Entity<Project>,
85 prompt_builder: Arc<PromptBuilder>,
86 slash_commands: Arc<SlashCommandWorkingSet>,
87 cx: &mut App,
88 ) -> Task<Result<Entity<Self>>> {
89 let fs = project.read(cx).fs().clone();
90 let languages = project.read(cx).languages().clone();
91 let telemetry = project.read(cx).client().telemetry().clone();
92 cx.spawn(async move |cx| {
93 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
94 let (mut events, _) = fs.watch(text_threads_dir(), CONTEXT_WATCH_DURATION).await;
95
96 let this = cx.new(|cx: &mut Context<Self>| {
97 let mut this = Self {
98 text_threads: Vec::new(),
99 text_threads_metadata: Vec::new(),
100 context_server_slash_command_ids: HashMap::default(),
101 host_text_threads: Vec::new(),
102 fs,
103 languages,
104 slash_commands,
105 telemetry,
106 _watch_updates: cx.spawn(async move |this, cx| {
107 async move {
108 while events.next().await.is_some() {
109 this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
110 }
111 anyhow::Ok(())
112 }
113 .log_err()
114 .await
115 }),
116 client_subscription: None,
117 _project_subscriptions: vec![
118 cx.subscribe(&project, Self::handle_project_event),
119 ],
120 project_is_shared: false,
121 client: project.read(cx).client(),
122 project: project.downgrade(),
123 prompt_builder,
124 };
125 this.handle_project_shared(cx);
126 this.synchronize_contexts(cx);
127 this.register_context_server_handlers(cx);
128 this.reload(cx).detach_and_log_err(cx);
129 this
130 })?;
131
132 Ok(this)
133 })
134 }
135
136 #[cfg(any(test, feature = "test-support"))]
137 pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
138 Self {
139 text_threads: Default::default(),
140 text_threads_metadata: Default::default(),
141 context_server_slash_command_ids: Default::default(),
142 host_text_threads: Default::default(),
143 fs: project.read(cx).fs().clone(),
144 languages: project.read(cx).languages().clone(),
145 slash_commands: Arc::default(),
146 telemetry: project.read(cx).client().telemetry().clone(),
147 _watch_updates: Task::ready(None),
148 client: project.read(cx).client(),
149 project: project.downgrade(),
150 project_is_shared: false,
151 client_subscription: None,
152 _project_subscriptions: Default::default(),
153 prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
154 }
155 }
156
157 async fn handle_advertise_contexts(
158 this: Entity<Self>,
159 envelope: TypedEnvelope<proto::AdvertiseContexts>,
160 mut cx: AsyncApp,
161 ) -> Result<()> {
162 this.update(&mut cx, |this, cx| {
163 this.host_text_threads = envelope
164 .payload
165 .contexts
166 .into_iter()
167 .map(|text_thread| RemoteTextThreadMetadata {
168 id: TextThreadId::from_proto(text_thread.context_id),
169 summary: text_thread.summary,
170 })
171 .collect();
172 cx.notify();
173 })
174 }
175
176 async fn handle_open_context(
177 this: Entity<Self>,
178 envelope: TypedEnvelope<proto::OpenContext>,
179 mut cx: AsyncApp,
180 ) -> Result<proto::OpenContextResponse> {
181 let context_id = TextThreadId::from_proto(envelope.payload.context_id);
182 let operations = this.update(&mut cx, |this, cx| {
183 let project = this.project.upgrade().context("project not found")?;
184
185 anyhow::ensure!(
186 !project.read(cx).is_via_collab(),
187 "only the host contexts can be opened"
188 );
189
190 let text_thread = this
191 .loaded_text_thread_for_id(&context_id, cx)
192 .context("context not found")?;
193 anyhow::ensure!(
194 text_thread.read(cx).replica_id() == ReplicaId::default(),
195 "context must be opened via the host"
196 );
197
198 anyhow::Ok(
199 text_thread
200 .read(cx)
201 .serialize_ops(&TextThreadVersion::default(), cx),
202 )
203 })??;
204 let operations = operations.await;
205 Ok(proto::OpenContextResponse {
206 context: Some(proto::Context { operations }),
207 })
208 }
209
210 async fn handle_create_context(
211 this: Entity<Self>,
212 _: TypedEnvelope<proto::CreateContext>,
213 mut cx: AsyncApp,
214 ) -> Result<proto::CreateContextResponse> {
215 let (context_id, operations) = this.update(&mut cx, |this, cx| {
216 let project = this.project.upgrade().context("project not found")?;
217 anyhow::ensure!(
218 !project.read(cx).is_via_collab(),
219 "can only create contexts as the host"
220 );
221
222 let text_thread = this.create(cx);
223 let context_id = text_thread.read(cx).id().clone();
224
225 anyhow::Ok((
226 context_id,
227 text_thread
228 .read(cx)
229 .serialize_ops(&TextThreadVersion::default(), cx),
230 ))
231 })??;
232 let operations = operations.await;
233 Ok(proto::CreateContextResponse {
234 context_id: context_id.to_proto(),
235 context: Some(proto::Context { operations }),
236 })
237 }
238
239 async fn handle_update_context(
240 this: Entity<Self>,
241 envelope: TypedEnvelope<proto::UpdateContext>,
242 mut cx: AsyncApp,
243 ) -> Result<()> {
244 this.update(&mut cx, |this, cx| {
245 let context_id = TextThreadId::from_proto(envelope.payload.context_id);
246 if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
247 let operation_proto = envelope.payload.operation.context("invalid operation")?;
248 let operation = TextThreadOperation::from_proto(operation_proto)?;
249 text_thread.update(cx, |text_thread, cx| text_thread.apply_ops([operation], cx));
250 }
251 Ok(())
252 })?
253 }
254
255 async fn handle_synchronize_contexts(
256 this: Entity<Self>,
257 envelope: TypedEnvelope<proto::SynchronizeContexts>,
258 mut cx: AsyncApp,
259 ) -> Result<proto::SynchronizeContextsResponse> {
260 this.update(&mut cx, |this, cx| {
261 let project = this.project.upgrade().context("project not found")?;
262 anyhow::ensure!(
263 !project.read(cx).is_via_collab(),
264 "only the host can synchronize contexts"
265 );
266
267 let mut local_versions = Vec::new();
268 for remote_version_proto in envelope.payload.contexts {
269 let remote_version = TextThreadVersion::from_proto(&remote_version_proto);
270 let context_id = TextThreadId::from_proto(remote_version_proto.context_id);
271 if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
272 let text_thread = text_thread.read(cx);
273 let operations = text_thread.serialize_ops(&remote_version, cx);
274 local_versions.push(text_thread.version(cx).to_proto(context_id.clone()));
275 let client = this.client.clone();
276 let project_id = envelope.payload.project_id;
277 cx.background_spawn(async move {
278 let operations = operations.await;
279 for operation in operations {
280 client.send(proto::UpdateContext {
281 project_id,
282 context_id: context_id.to_proto(),
283 operation: Some(operation),
284 })?;
285 }
286 anyhow::Ok(())
287 })
288 .detach_and_log_err(cx);
289 }
290 }
291
292 this.advertise_contexts(cx);
293
294 anyhow::Ok(proto::SynchronizeContextsResponse {
295 contexts: local_versions,
296 })
297 })?
298 }
299
300 fn handle_project_shared(&mut self, cx: &mut Context<Self>) {
301 let Some(project) = self.project.upgrade() else {
302 return;
303 };
304
305 let is_shared = project.read(cx).is_shared();
306 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
307 if is_shared == was_shared {
308 return;
309 }
310
311 if is_shared {
312 self.text_threads.retain_mut(|text_thread| {
313 if let Some(strong_context) = text_thread.upgrade() {
314 *text_thread = TextThreadHandle::Strong(strong_context);
315 true
316 } else {
317 false
318 }
319 });
320 let remote_id = project.read(cx).remote_id().unwrap();
321 self.client_subscription = self
322 .client
323 .subscribe_to_entity(remote_id)
324 .log_err()
325 .map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async()));
326 self.advertise_contexts(cx);
327 } else {
328 self.client_subscription = None;
329 }
330 }
331
332 fn handle_project_event(
333 &mut self,
334 _project: Entity<Project>,
335 event: &project::Event,
336 cx: &mut Context<Self>,
337 ) {
338 match event {
339 project::Event::RemoteIdChanged(_) => {
340 self.handle_project_shared(cx);
341 }
342 project::Event::Reshared => {
343 self.advertise_contexts(cx);
344 }
345 project::Event::HostReshared | project::Event::Rejoined => {
346 self.synchronize_contexts(cx);
347 }
348 project::Event::DisconnectedFromHost => {
349 self.text_threads.retain_mut(|text_thread| {
350 if let Some(strong_context) = text_thread.upgrade() {
351 *text_thread = TextThreadHandle::Weak(text_thread.downgrade());
352 strong_context.update(cx, |text_thread, cx| {
353 if text_thread.replica_id() != ReplicaId::default() {
354 text_thread.set_capability(language::Capability::ReadOnly, cx);
355 }
356 });
357 true
358 } else {
359 false
360 }
361 });
362 self.host_text_threads.clear();
363 cx.notify();
364 }
365 _ => {}
366 }
367 }
368
369 pub fn unordered_text_threads(&self) -> impl Iterator<Item = &SavedTextThreadMetadata> {
370 self.text_threads_metadata.iter()
371 }
372
373 pub fn host_text_threads(&self) -> impl Iterator<Item = &RemoteTextThreadMetadata> {
374 self.host_text_threads.iter()
375 }
376
377 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<TextThread> {
378 let context = cx.new(|cx| {
379 TextThread::local(
380 self.languages.clone(),
381 Some(self.project.clone()),
382 Some(self.telemetry.clone()),
383 self.prompt_builder.clone(),
384 self.slash_commands.clone(),
385 cx,
386 )
387 });
388 self.register_text_thread(&context, cx);
389 context
390 }
391
392 pub fn create_remote(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<TextThread>>> {
393 let Some(project) = self.project.upgrade() else {
394 return Task::ready(Err(anyhow::anyhow!("project was dropped")));
395 };
396 let project = project.read(cx);
397 let Some(project_id) = project.remote_id() else {
398 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
399 };
400
401 let replica_id = project.replica_id();
402 let capability = project.capability();
403 let language_registry = self.languages.clone();
404 let project = self.project.clone();
405 let telemetry = self.telemetry.clone();
406 let prompt_builder = self.prompt_builder.clone();
407 let slash_commands = self.slash_commands.clone();
408 let request = self.client.request(proto::CreateContext { project_id });
409 cx.spawn(async move |this, cx| {
410 let response = request.await?;
411 let context_id = TextThreadId::from_proto(response.context_id);
412 let context_proto = response.context.context("invalid context")?;
413 let text_thread = cx.new(|cx| {
414 TextThread::new(
415 context_id.clone(),
416 replica_id,
417 capability,
418 language_registry,
419 prompt_builder,
420 slash_commands,
421 Some(project),
422 Some(telemetry),
423 cx,
424 )
425 })?;
426 let operations = cx
427 .background_spawn(async move {
428 context_proto
429 .operations
430 .into_iter()
431 .map(TextThreadOperation::from_proto)
432 .collect::<Result<Vec<_>>>()
433 })
434 .await?;
435 text_thread.update(cx, |context, cx| context.apply_ops(operations, cx))?;
436 this.update(cx, |this, cx| {
437 if let Some(existing_context) = this.loaded_text_thread_for_id(&context_id, cx) {
438 existing_context
439 } else {
440 this.register_text_thread(&text_thread, cx);
441 this.synchronize_contexts(cx);
442 text_thread
443 }
444 })
445 })
446 }
447
448 pub fn open_local(
449 &mut self,
450 path: Arc<Path>,
451 cx: &Context<Self>,
452 ) -> Task<Result<Entity<TextThread>>> {
453 if let Some(existing_context) = self.loaded_text_thread_for_path(&path, cx) {
454 return Task::ready(Ok(existing_context));
455 }
456
457 let fs = self.fs.clone();
458 let languages = self.languages.clone();
459 let project = self.project.clone();
460 let telemetry = self.telemetry.clone();
461 let load = cx.background_spawn({
462 let path = path.clone();
463 async move {
464 let saved_context = fs.load(&path).await?;
465 SavedTextThread::from_json(&saved_context)
466 }
467 });
468 let prompt_builder = self.prompt_builder.clone();
469 let slash_commands = self.slash_commands.clone();
470
471 cx.spawn(async move |this, cx| {
472 let saved_context = load.await?;
473 let context = cx.new(|cx| {
474 TextThread::deserialize(
475 saved_context,
476 path.clone(),
477 languages,
478 prompt_builder,
479 slash_commands,
480 Some(project),
481 Some(telemetry),
482 cx,
483 )
484 })?;
485 this.update(cx, |this, cx| {
486 if let Some(existing_context) = this.loaded_text_thread_for_path(&path, cx) {
487 existing_context
488 } else {
489 this.register_text_thread(&context, cx);
490 context
491 }
492 })
493 })
494 }
495
496 pub fn delete_local(&mut self, path: Arc<Path>, cx: &mut Context<Self>) -> Task<Result<()>> {
497 let fs = self.fs.clone();
498
499 cx.spawn(async move |this, cx| {
500 fs.remove_file(
501 &path,
502 RemoveOptions {
503 recursive: false,
504 ignore_if_not_exists: true,
505 },
506 )
507 .await?;
508
509 this.update(cx, |this, cx| {
510 this.text_threads.retain(|text_thread| {
511 text_thread
512 .upgrade()
513 .and_then(|text_thread| text_thread.read(cx).path())
514 != Some(&path)
515 });
516 this.text_threads_metadata
517 .retain(|text_thread| text_thread.path.as_ref() != path.as_ref());
518 })?;
519
520 Ok(())
521 })
522 }
523
524 fn loaded_text_thread_for_path(&self, path: &Path, cx: &App) -> Option<Entity<TextThread>> {
525 self.text_threads.iter().find_map(|text_thread| {
526 let text_thread = text_thread.upgrade()?;
527 if text_thread.read(cx).path().map(Arc::as_ref) == Some(path) {
528 Some(text_thread)
529 } else {
530 None
531 }
532 })
533 }
534
535 pub fn loaded_text_thread_for_id(
536 &self,
537 id: &TextThreadId,
538 cx: &App,
539 ) -> Option<Entity<TextThread>> {
540 self.text_threads.iter().find_map(|text_thread| {
541 let text_thread = text_thread.upgrade()?;
542 if text_thread.read(cx).id() == id {
543 Some(text_thread)
544 } else {
545 None
546 }
547 })
548 }
549
550 pub fn open_remote(
551 &mut self,
552 text_thread_id: TextThreadId,
553 cx: &mut Context<Self>,
554 ) -> Task<Result<Entity<TextThread>>> {
555 let Some(project) = self.project.upgrade() else {
556 return Task::ready(Err(anyhow::anyhow!("project was dropped")));
557 };
558 let project = project.read(cx);
559 let Some(project_id) = project.remote_id() else {
560 return Task::ready(Err(anyhow::anyhow!("project was not remote")));
561 };
562
563 if let Some(context) = self.loaded_text_thread_for_id(&text_thread_id, cx) {
564 return Task::ready(Ok(context));
565 }
566
567 let replica_id = project.replica_id();
568 let capability = project.capability();
569 let language_registry = self.languages.clone();
570 let project = self.project.clone();
571 let telemetry = self.telemetry.clone();
572 let request = self.client.request(proto::OpenContext {
573 project_id,
574 context_id: text_thread_id.to_proto(),
575 });
576 let prompt_builder = self.prompt_builder.clone();
577 let slash_commands = self.slash_commands.clone();
578 cx.spawn(async move |this, cx| {
579 let response = request.await?;
580 let context_proto = response.context.context("invalid context")?;
581 let text_thread = cx.new(|cx| {
582 TextThread::new(
583 text_thread_id.clone(),
584 replica_id,
585 capability,
586 language_registry,
587 prompt_builder,
588 slash_commands,
589 Some(project),
590 Some(telemetry),
591 cx,
592 )
593 })?;
594 let operations = cx
595 .background_spawn(async move {
596 context_proto
597 .operations
598 .into_iter()
599 .map(TextThreadOperation::from_proto)
600 .collect::<Result<Vec<_>>>()
601 })
602 .await?;
603 text_thread.update(cx, |context, cx| context.apply_ops(operations, cx))?;
604 this.update(cx, |this, cx| {
605 if let Some(existing_context) = this.loaded_text_thread_for_id(&text_thread_id, cx)
606 {
607 existing_context
608 } else {
609 this.register_text_thread(&text_thread, cx);
610 this.synchronize_contexts(cx);
611 text_thread
612 }
613 })
614 })
615 }
616
617 fn register_text_thread(&mut self, text_thread: &Entity<TextThread>, cx: &mut Context<Self>) {
618 let handle = if self.project_is_shared {
619 TextThreadHandle::Strong(text_thread.clone())
620 } else {
621 TextThreadHandle::Weak(text_thread.downgrade())
622 };
623 self.text_threads.push(handle);
624 self.advertise_contexts(cx);
625 cx.subscribe(text_thread, Self::handle_context_event)
626 .detach();
627 }
628
629 fn handle_context_event(
630 &mut self,
631 text_thread: Entity<TextThread>,
632 event: &TextThreadEvent,
633 cx: &mut Context<Self>,
634 ) {
635 let Some(project) = self.project.upgrade() else {
636 return;
637 };
638 let Some(project_id) = project.read(cx).remote_id() else {
639 return;
640 };
641
642 match event {
643 TextThreadEvent::SummaryChanged => {
644 self.advertise_contexts(cx);
645 }
646 TextThreadEvent::PathChanged { old_path, new_path } => {
647 if let Some(old_path) = old_path.as_ref() {
648 for metadata in &mut self.text_threads_metadata {
649 if &metadata.path == old_path {
650 metadata.path = new_path.clone();
651 break;
652 }
653 }
654 }
655 }
656 TextThreadEvent::Operation(operation) => {
657 let context_id = text_thread.read(cx).id().to_proto();
658 let operation = operation.to_proto();
659 self.client
660 .send(proto::UpdateContext {
661 project_id,
662 context_id,
663 operation: Some(operation),
664 })
665 .log_err();
666 }
667 _ => {}
668 }
669 }
670
671 fn advertise_contexts(&self, cx: &App) {
672 let Some(project) = self.project.upgrade() else {
673 return;
674 };
675 let Some(project_id) = project.read(cx).remote_id() else {
676 return;
677 };
678 // For now, only the host can advertise their open contexts.
679 if project.read(cx).is_via_collab() {
680 return;
681 }
682
683 let contexts = self
684 .text_threads
685 .iter()
686 .rev()
687 .filter_map(|text_thread| {
688 let text_thread = text_thread.upgrade()?.read(cx);
689 if text_thread.replica_id() == ReplicaId::default() {
690 Some(proto::ContextMetadata {
691 context_id: text_thread.id().to_proto(),
692 summary: text_thread
693 .summary()
694 .content()
695 .map(|summary| summary.text.clone()),
696 })
697 } else {
698 None
699 }
700 })
701 .collect();
702 self.client
703 .send(proto::AdvertiseContexts {
704 project_id,
705 contexts,
706 })
707 .ok();
708 }
709
710 fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
711 let Some(project) = self.project.upgrade() else {
712 return;
713 };
714 let Some(project_id) = project.read(cx).remote_id() else {
715 return;
716 };
717
718 let text_threads = self
719 .text_threads
720 .iter()
721 .filter_map(|text_thread| {
722 let text_thread = text_thread.upgrade()?.read(cx);
723 if text_thread.replica_id() != ReplicaId::default() {
724 Some(text_thread.version(cx).to_proto(text_thread.id().clone()))
725 } else {
726 None
727 }
728 })
729 .collect();
730
731 let client = self.client.clone();
732 let request = self.client.request(proto::SynchronizeContexts {
733 project_id,
734 contexts: text_threads,
735 });
736 cx.spawn(async move |this, cx| {
737 let response = request.await?;
738
739 let mut text_thread_ids = Vec::new();
740 let mut operations = Vec::new();
741 this.read_with(cx, |this, cx| {
742 for context_version_proto in response.contexts {
743 let text_thread_version = TextThreadVersion::from_proto(&context_version_proto);
744 let text_thread_id = TextThreadId::from_proto(context_version_proto.context_id);
745 if let Some(text_thread) = this.loaded_text_thread_for_id(&text_thread_id, cx) {
746 text_thread_ids.push(text_thread_id);
747 operations
748 .push(text_thread.read(cx).serialize_ops(&text_thread_version, cx));
749 }
750 }
751 })?;
752
753 let operations = futures::future::join_all(operations).await;
754 for (context_id, operations) in text_thread_ids.into_iter().zip(operations) {
755 for operation in operations {
756 client.send(proto::UpdateContext {
757 project_id,
758 context_id: context_id.to_proto(),
759 operation: Some(operation),
760 })?;
761 }
762 }
763
764 anyhow::Ok(())
765 })
766 .detach_and_log_err(cx);
767 }
768
769 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedTextThreadMetadata>> {
770 let metadata = self.text_threads_metadata.clone();
771 let executor = cx.background_executor().clone();
772 cx.background_spawn(async move {
773 if query.is_empty() {
774 metadata
775 } else {
776 let candidates = metadata
777 .iter()
778 .enumerate()
779 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
780 .collect::<Vec<_>>();
781 let matches = fuzzy::match_strings(
782 &candidates,
783 &query,
784 false,
785 true,
786 100,
787 &Default::default(),
788 executor,
789 )
790 .await;
791
792 matches
793 .into_iter()
794 .map(|mat| metadata[mat.candidate_id].clone())
795 .collect()
796 }
797 })
798 }
799
800 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
801 let fs = self.fs.clone();
802 cx.spawn(async move |this, cx| {
803 if *ZED_STATELESS {
804 return Ok(());
805 }
806 fs.create_dir(text_threads_dir()).await?;
807
808 let mut paths = fs.read_dir(text_threads_dir()).await?;
809 let mut contexts = Vec::<SavedTextThreadMetadata>::new();
810 while let Some(path) = paths.next().await {
811 let path = path?;
812 if path.extension() != Some(OsStr::new("json")) {
813 continue;
814 }
815
816 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
817 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
818
819 let metadata = fs.metadata(&path).await?;
820 if let Some((file_name, metadata)) = path
821 .file_name()
822 .and_then(|name| name.to_str())
823 .zip(metadata)
824 {
825 // This is used to filter out contexts saved by the new assistant.
826 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
827 continue;
828 }
829
830 if let Some(title) = ASSISTANT_CONTEXT_REGEX
831 .replace(file_name, "")
832 .lines()
833 .next()
834 {
835 contexts.push(SavedTextThreadMetadata {
836 title: title.to_string().into(),
837 path: path.into(),
838 mtime: metadata.mtime.timestamp_for_user().into(),
839 });
840 }
841 }
842 }
843 contexts.sort_unstable_by_key(|text_thread| Reverse(text_thread.mtime));
844
845 this.update(cx, |this, cx| {
846 this.text_threads_metadata = contexts;
847 cx.notify();
848 })
849 })
850 }
851
852 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
853 let Some(project) = self.project.upgrade() else {
854 return;
855 };
856 let context_server_store = project.read(cx).context_server_store();
857 cx.subscribe(&context_server_store, Self::handle_context_server_event)
858 .detach();
859
860 // Check for any servers that were already running before the handler was registered
861 for server in context_server_store.read(cx).running_servers() {
862 self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
863 }
864 }
865
866 fn handle_context_server_event(
867 &mut self,
868 context_server_store: Entity<ContextServerStore>,
869 event: &project::context_server_store::Event,
870 cx: &mut Context<Self>,
871 ) {
872 match event {
873 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
874 match status {
875 ContextServerStatus::Running => {
876 self.load_context_server_slash_commands(
877 server_id.clone(),
878 context_server_store,
879 cx,
880 );
881 }
882 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
883 if let Some(slash_command_ids) =
884 self.context_server_slash_command_ids.remove(server_id)
885 {
886 self.slash_commands.remove(&slash_command_ids);
887 }
888 }
889 _ => {}
890 }
891 }
892 }
893 }
894
895 fn load_context_server_slash_commands(
896 &self,
897 server_id: ContextServerId,
898 context_server_store: Entity<ContextServerStore>,
899 cx: &mut Context<Self>,
900 ) {
901 let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
902 return;
903 };
904 let slash_command_working_set = self.slash_commands.clone();
905 cx.spawn(async move |this, cx| {
906 let Some(protocol) = server.client() else {
907 return;
908 };
909
910 if protocol.capable(context_server::protocol::ServerCapability::Prompts)
911 && let Some(response) = protocol
912 .request::<context_server::types::requests::PromptsList>(())
913 .await
914 .log_err()
915 {
916 let slash_command_ids = response
917 .prompts
918 .into_iter()
919 .filter(assistant_slash_commands::acceptable_prompt)
920 .map(|prompt| {
921 log::info!("registering context server command: {:?}", prompt.name);
922 slash_command_working_set.insert(Arc::new(
923 assistant_slash_commands::ContextServerSlashCommand::new(
924 context_server_store.clone(),
925 server.id(),
926 prompt,
927 ),
928 ))
929 })
930 .collect::<Vec<_>>();
931
932 this.update(cx, |this, _cx| {
933 this.context_server_slash_command_ids
934 .insert(server_id.clone(), slash_command_ids);
935 })
936 .log_err();
937 }
938 })
939 .detach();
940 }
941}