text_thread_store.rs

  1use crate::{
  2    SavedTextThread, SavedTextThreadMetadata, TextThread, TextThreadEvent, TextThreadId,
  3    TextThreadOperation, TextThreadVersion,
  4};
  5use anyhow::{Context as _, Result};
  6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
  7use client::{Client, TypedEnvelope, proto};
  8use clock::ReplicaId;
  9use collections::HashMap;
 10use context_server::ContextServerId;
 11use fs::{Fs, RemoveOptions};
 12use futures::StreamExt;
 13use fuzzy::StringMatchCandidate;
 14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
 15use language::LanguageRegistry;
 16use paths::text_threads_dir;
 17use project::{
 18    Project,
 19    context_server_store::{ContextServerStatus, ContextServerStore},
 20};
 21use prompt_store::PromptBuilder;
 22use regex::Regex;
 23use rpc::AnyProtoClient;
 24use std::sync::LazyLock;
 25use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
 26use util::{ResultExt, TryFutureExt};
 27use zed_env_vars::ZED_STATELESS;
 28
 29pub(crate) fn init(client: &AnyProtoClient) {
 30    client.add_entity_message_handler(TextThreadStore::handle_advertise_contexts);
 31    client.add_entity_request_handler(TextThreadStore::handle_open_context);
 32    client.add_entity_request_handler(TextThreadStore::handle_create_context);
 33    client.add_entity_message_handler(TextThreadStore::handle_update_context);
 34    client.add_entity_request_handler(TextThreadStore::handle_synchronize_contexts);
 35}
 36
 37#[derive(Clone)]
 38pub struct RemoteTextThreadMetadata {
 39    pub id: TextThreadId,
 40    pub summary: Option<String>,
 41}
 42
 43pub struct TextThreadStore {
 44    text_threads: Vec<TextThreadHandle>,
 45    text_threads_metadata: Vec<SavedTextThreadMetadata>,
 46    context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
 47    host_text_threads: Vec<RemoteTextThreadMetadata>,
 48    fs: Arc<dyn Fs>,
 49    languages: Arc<LanguageRegistry>,
 50    slash_commands: Arc<SlashCommandWorkingSet>,
 51    _watch_updates: Task<Option<()>>,
 52    client: Arc<Client>,
 53    project: WeakEntity<Project>,
 54    project_is_shared: bool,
 55    client_subscription: Option<client::Subscription>,
 56    _project_subscriptions: Vec<gpui::Subscription>,
 57    prompt_builder: Arc<PromptBuilder>,
 58}
 59
 60enum TextThreadHandle {
 61    Weak(WeakEntity<TextThread>),
 62    Strong(Entity<TextThread>),
 63}
 64
 65impl TextThreadHandle {
 66    fn upgrade(&self) -> Option<Entity<TextThread>> {
 67        match self {
 68            TextThreadHandle::Weak(weak) => weak.upgrade(),
 69            TextThreadHandle::Strong(strong) => Some(strong.clone()),
 70        }
 71    }
 72
 73    fn downgrade(&self) -> WeakEntity<TextThread> {
 74        match self {
 75            TextThreadHandle::Weak(weak) => weak.clone(),
 76            TextThreadHandle::Strong(strong) => strong.downgrade(),
 77        }
 78    }
 79}
 80
 81impl TextThreadStore {
 82    pub fn new(
 83        project: Entity<Project>,
 84        prompt_builder: Arc<PromptBuilder>,
 85        slash_commands: Arc<SlashCommandWorkingSet>,
 86        cx: &mut App,
 87    ) -> Task<Result<Entity<Self>>> {
 88        let fs = project.read(cx).fs().clone();
 89        let languages = project.read(cx).languages().clone();
 90        cx.spawn(async move |cx| {
 91            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
 92            let (mut events, _) = fs.watch(text_threads_dir(), CONTEXT_WATCH_DURATION).await;
 93
 94            let this = cx.new(|cx: &mut Context<Self>| {
 95                let mut this = Self {
 96                    text_threads: Vec::new(),
 97                    text_threads_metadata: Vec::new(),
 98                    context_server_slash_command_ids: HashMap::default(),
 99                    host_text_threads: Vec::new(),
100                    fs,
101                    languages,
102                    slash_commands,
103                    _watch_updates: cx.spawn(async move |this, cx| {
104                        async move {
105                            while events.next().await.is_some() {
106                                this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
107                            }
108                            anyhow::Ok(())
109                        }
110                        .log_err()
111                        .await
112                    }),
113                    client_subscription: None,
114                    _project_subscriptions: vec![
115                        cx.subscribe(&project, Self::handle_project_event),
116                    ],
117                    project_is_shared: false,
118                    client: project.read(cx).client(),
119                    project: project.downgrade(),
120                    prompt_builder,
121                };
122                this.handle_project_shared(cx);
123                this.synchronize_contexts(cx);
124                this.register_context_server_handlers(cx);
125                this.reload(cx).detach_and_log_err(cx);
126                this
127            });
128
129            Ok(this)
130        })
131    }
132
133    #[cfg(any(test, feature = "test-support"))]
134    pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
135        Self {
136            text_threads: Default::default(),
137            text_threads_metadata: Default::default(),
138            context_server_slash_command_ids: Default::default(),
139            host_text_threads: Default::default(),
140            fs: project.read(cx).fs().clone(),
141            languages: project.read(cx).languages().clone(),
142            slash_commands: Arc::default(),
143            _watch_updates: Task::ready(None),
144            client: project.read(cx).client(),
145            project: project.downgrade(),
146            project_is_shared: false,
147            client_subscription: None,
148            _project_subscriptions: Default::default(),
149            prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
150        }
151    }
152
153    async fn handle_advertise_contexts(
154        this: Entity<Self>,
155        envelope: TypedEnvelope<proto::AdvertiseContexts>,
156        mut cx: AsyncApp,
157    ) -> Result<()> {
158        this.update(&mut cx, |this, cx| {
159            this.host_text_threads = envelope
160                .payload
161                .contexts
162                .into_iter()
163                .map(|text_thread| RemoteTextThreadMetadata {
164                    id: TextThreadId::from_proto(text_thread.context_id),
165                    summary: text_thread.summary,
166                })
167                .collect();
168            cx.notify();
169        });
170        Ok(())
171    }
172
173    async fn handle_open_context(
174        this: Entity<Self>,
175        envelope: TypedEnvelope<proto::OpenContext>,
176        mut cx: AsyncApp,
177    ) -> Result<proto::OpenContextResponse> {
178        let context_id = TextThreadId::from_proto(envelope.payload.context_id);
179        let operations = this.update(&mut cx, |this, cx| {
180            let project = this.project.upgrade().context("project not found")?;
181
182            anyhow::ensure!(
183                !project.read(cx).is_via_collab(),
184                "only the host contexts can be opened"
185            );
186
187            let text_thread = this
188                .loaded_text_thread_for_id(&context_id, cx)
189                .context("context not found")?;
190            anyhow::ensure!(
191                text_thread.read(cx).replica_id() == ReplicaId::default(),
192                "context must be opened via the host"
193            );
194
195            anyhow::Ok(
196                text_thread
197                    .read(cx)
198                    .serialize_ops(&TextThreadVersion::default(), cx),
199            )
200        })?;
201        let operations = operations.await;
202        Ok(proto::OpenContextResponse {
203            context: Some(proto::Context { operations }),
204        })
205    }
206
207    async fn handle_create_context(
208        this: Entity<Self>,
209        _: TypedEnvelope<proto::CreateContext>,
210        mut cx: AsyncApp,
211    ) -> Result<proto::CreateContextResponse> {
212        let (context_id, operations) = this.update(&mut cx, |this, cx| {
213            let project = this.project.upgrade().context("project not found")?;
214            anyhow::ensure!(
215                !project.read(cx).is_via_collab(),
216                "can only create contexts as the host"
217            );
218
219            let text_thread = this.create(cx);
220            let context_id = text_thread.read(cx).id().clone();
221
222            anyhow::Ok((
223                context_id,
224                text_thread
225                    .read(cx)
226                    .serialize_ops(&TextThreadVersion::default(), cx),
227            ))
228        })?;
229        let operations = operations.await;
230        Ok(proto::CreateContextResponse {
231            context_id: context_id.to_proto(),
232            context: Some(proto::Context { operations }),
233        })
234    }
235
236    async fn handle_update_context(
237        this: Entity<Self>,
238        envelope: TypedEnvelope<proto::UpdateContext>,
239        mut cx: AsyncApp,
240    ) -> Result<()> {
241        this.update(&mut cx, |this, cx| {
242            let context_id = TextThreadId::from_proto(envelope.payload.context_id);
243            if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
244                let operation_proto = envelope.payload.operation.context("invalid operation")?;
245                let operation = TextThreadOperation::from_proto(operation_proto)?;
246                text_thread.update(cx, |text_thread, cx| text_thread.apply_ops([operation], cx));
247            }
248            Ok(())
249        })
250    }
251
252    async fn handle_synchronize_contexts(
253        this: Entity<Self>,
254        envelope: TypedEnvelope<proto::SynchronizeContexts>,
255        mut cx: AsyncApp,
256    ) -> Result<proto::SynchronizeContextsResponse> {
257        this.update(&mut cx, |this, cx| {
258            let project = this.project.upgrade().context("project not found")?;
259            anyhow::ensure!(
260                !project.read(cx).is_via_collab(),
261                "only the host can synchronize contexts"
262            );
263
264            let mut local_versions = Vec::new();
265            for remote_version_proto in envelope.payload.contexts {
266                let remote_version = TextThreadVersion::from_proto(&remote_version_proto);
267                let context_id = TextThreadId::from_proto(remote_version_proto.context_id);
268                if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
269                    let text_thread = text_thread.read(cx);
270                    let operations = text_thread.serialize_ops(&remote_version, cx);
271                    local_versions.push(text_thread.version(cx).to_proto(context_id.clone()));
272                    let client = this.client.clone();
273                    let project_id = envelope.payload.project_id;
274                    cx.background_spawn(async move {
275                        let operations = operations.await;
276                        for operation in operations {
277                            client.send(proto::UpdateContext {
278                                project_id,
279                                context_id: context_id.to_proto(),
280                                operation: Some(operation),
281                            })?;
282                        }
283                        anyhow::Ok(())
284                    })
285                    .detach_and_log_err(cx);
286                }
287            }
288
289            this.advertise_contexts(cx);
290
291            anyhow::Ok(proto::SynchronizeContextsResponse {
292                contexts: local_versions,
293            })
294        })
295    }
296
297    fn handle_project_shared(&mut self, cx: &mut Context<Self>) {
298        let Some(project) = self.project.upgrade() else {
299            return;
300        };
301
302        let is_shared = project.read(cx).is_shared();
303        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
304        if is_shared == was_shared {
305            return;
306        }
307
308        if is_shared {
309            self.text_threads.retain_mut(|text_thread| {
310                if let Some(strong_context) = text_thread.upgrade() {
311                    *text_thread = TextThreadHandle::Strong(strong_context);
312                    true
313                } else {
314                    false
315                }
316            });
317            let remote_id = project.read(cx).remote_id().unwrap();
318            self.client_subscription = self
319                .client
320                .subscribe_to_entity(remote_id)
321                .log_err()
322                .map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async()));
323            self.advertise_contexts(cx);
324        } else {
325            self.client_subscription = None;
326        }
327    }
328
329    fn handle_project_event(
330        &mut self,
331        _project: Entity<Project>,
332        event: &project::Event,
333        cx: &mut Context<Self>,
334    ) {
335        match event {
336            project::Event::RemoteIdChanged(_) => {
337                self.handle_project_shared(cx);
338            }
339            project::Event::Reshared => {
340                self.advertise_contexts(cx);
341            }
342            project::Event::HostReshared | project::Event::Rejoined => {
343                self.synchronize_contexts(cx);
344            }
345            project::Event::DisconnectedFromHost => {
346                self.text_threads.retain_mut(|text_thread| {
347                    if let Some(strong_context) = text_thread.upgrade() {
348                        *text_thread = TextThreadHandle::Weak(text_thread.downgrade());
349                        strong_context.update(cx, |text_thread, cx| {
350                            if text_thread.replica_id() != ReplicaId::default() {
351                                text_thread.set_capability(language::Capability::ReadOnly, cx);
352                            }
353                        });
354                        true
355                    } else {
356                        false
357                    }
358                });
359                self.host_text_threads.clear();
360                cx.notify();
361            }
362            _ => {}
363        }
364    }
365
366    pub fn unordered_text_threads(&self) -> impl Iterator<Item = &SavedTextThreadMetadata> {
367        self.text_threads_metadata.iter()
368    }
369
370    pub fn host_text_threads(&self) -> impl Iterator<Item = &RemoteTextThreadMetadata> {
371        self.host_text_threads.iter()
372    }
373
374    pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<TextThread> {
375        let context = cx.new(|cx| {
376            TextThread::local(
377                self.languages.clone(),
378                Some(self.project.clone()),
379                self.prompt_builder.clone(),
380                self.slash_commands.clone(),
381                cx,
382            )
383        });
384        self.register_text_thread(&context, cx);
385        context
386    }
387
388    pub fn create_remote(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<TextThread>>> {
389        let Some(project) = self.project.upgrade() else {
390            return Task::ready(Err(anyhow::anyhow!("project was dropped")));
391        };
392        let project = project.read(cx);
393        let Some(project_id) = project.remote_id() else {
394            return Task::ready(Err(anyhow::anyhow!("project was not remote")));
395        };
396
397        let replica_id = project.replica_id();
398        let capability = project.capability();
399        let language_registry = self.languages.clone();
400        let project = self.project.clone();
401
402        let prompt_builder = self.prompt_builder.clone();
403        let slash_commands = self.slash_commands.clone();
404        let request = self.client.request(proto::CreateContext { project_id });
405        cx.spawn(async move |this, cx| {
406            let response = request.await?;
407            let context_id = TextThreadId::from_proto(response.context_id);
408            let context_proto = response.context.context("invalid context")?;
409            let text_thread = cx.new(|cx| {
410                TextThread::new(
411                    context_id.clone(),
412                    replica_id,
413                    capability,
414                    language_registry,
415                    prompt_builder,
416                    slash_commands,
417                    Some(project),
418                    cx,
419                )
420            });
421            let operations = cx
422                .background_spawn(async move {
423                    context_proto
424                        .operations
425                        .into_iter()
426                        .map(TextThreadOperation::from_proto)
427                        .collect::<Result<Vec<_>>>()
428                })
429                .await?;
430            text_thread.update(cx, |context, cx| context.apply_ops(operations, cx));
431            this.update(cx, |this, cx| {
432                if let Some(existing_context) = this.loaded_text_thread_for_id(&context_id, cx) {
433                    existing_context
434                } else {
435                    this.register_text_thread(&text_thread, cx);
436                    this.synchronize_contexts(cx);
437                    text_thread
438                }
439            })
440        })
441    }
442
443    pub fn open_local(
444        &mut self,
445        path: Arc<Path>,
446        cx: &Context<Self>,
447    ) -> Task<Result<Entity<TextThread>>> {
448        if let Some(existing_context) = self.loaded_text_thread_for_path(&path, cx) {
449            return Task::ready(Ok(existing_context));
450        }
451
452        let fs = self.fs.clone();
453        let languages = self.languages.clone();
454        let project = self.project.clone();
455        let load = cx.background_spawn({
456            let path = path.clone();
457            async move {
458                let saved_context = fs.load(&path).await?;
459                SavedTextThread::from_json(&saved_context)
460            }
461        });
462        let prompt_builder = self.prompt_builder.clone();
463        let slash_commands = self.slash_commands.clone();
464
465        cx.spawn(async move |this, cx| {
466            let saved_context = load.await?;
467            let context = cx.new(|cx| {
468                TextThread::deserialize(
469                    saved_context,
470                    path.clone(),
471                    languages,
472                    prompt_builder,
473                    slash_commands,
474                    Some(project),
475                    cx,
476                )
477            });
478            this.update(cx, |this, cx| {
479                if let Some(existing_context) = this.loaded_text_thread_for_path(&path, cx) {
480                    existing_context
481                } else {
482                    this.register_text_thread(&context, cx);
483                    context
484                }
485            })
486        })
487    }
488
489    pub fn delete_local(&mut self, path: Arc<Path>, cx: &mut Context<Self>) -> Task<Result<()>> {
490        let fs = self.fs.clone();
491
492        cx.spawn(async move |this, cx| {
493            fs.remove_file(
494                &path,
495                RemoveOptions {
496                    recursive: false,
497                    ignore_if_not_exists: true,
498                },
499            )
500            .await?;
501
502            this.update(cx, |this, cx| {
503                this.text_threads.retain(|text_thread| {
504                    text_thread
505                        .upgrade()
506                        .and_then(|text_thread| text_thread.read(cx).path())
507                        != Some(&path)
508                });
509                this.text_threads_metadata
510                    .retain(|text_thread| text_thread.path.as_ref() != path.as_ref());
511            })?;
512
513            Ok(())
514        })
515    }
516
517    fn loaded_text_thread_for_path(&self, path: &Path, cx: &App) -> Option<Entity<TextThread>> {
518        self.text_threads.iter().find_map(|text_thread| {
519            let text_thread = text_thread.upgrade()?;
520            if text_thread.read(cx).path().map(Arc::as_ref) == Some(path) {
521                Some(text_thread)
522            } else {
523                None
524            }
525        })
526    }
527
528    pub fn loaded_text_thread_for_id(
529        &self,
530        id: &TextThreadId,
531        cx: &App,
532    ) -> Option<Entity<TextThread>> {
533        self.text_threads.iter().find_map(|text_thread| {
534            let text_thread = text_thread.upgrade()?;
535            if text_thread.read(cx).id() == id {
536                Some(text_thread)
537            } else {
538                None
539            }
540        })
541    }
542
543    pub fn open_remote(
544        &mut self,
545        text_thread_id: TextThreadId,
546        cx: &mut Context<Self>,
547    ) -> Task<Result<Entity<TextThread>>> {
548        let Some(project) = self.project.upgrade() else {
549            return Task::ready(Err(anyhow::anyhow!("project was dropped")));
550        };
551        let project = project.read(cx);
552        let Some(project_id) = project.remote_id() else {
553            return Task::ready(Err(anyhow::anyhow!("project was not remote")));
554        };
555
556        if let Some(context) = self.loaded_text_thread_for_id(&text_thread_id, cx) {
557            return Task::ready(Ok(context));
558        }
559
560        let replica_id = project.replica_id();
561        let capability = project.capability();
562        let language_registry = self.languages.clone();
563        let project = self.project.clone();
564        let request = self.client.request(proto::OpenContext {
565            project_id,
566            context_id: text_thread_id.to_proto(),
567        });
568        let prompt_builder = self.prompt_builder.clone();
569        let slash_commands = self.slash_commands.clone();
570        cx.spawn(async move |this, cx| {
571            let response = request.await?;
572            let context_proto = response.context.context("invalid context")?;
573            let text_thread = cx.new(|cx| {
574                TextThread::new(
575                    text_thread_id.clone(),
576                    replica_id,
577                    capability,
578                    language_registry,
579                    prompt_builder,
580                    slash_commands,
581                    Some(project),
582                    cx,
583                )
584            });
585            let operations = cx
586                .background_spawn(async move {
587                    context_proto
588                        .operations
589                        .into_iter()
590                        .map(TextThreadOperation::from_proto)
591                        .collect::<Result<Vec<_>>>()
592                })
593                .await?;
594            text_thread.update(cx, |context, cx| context.apply_ops(operations, cx));
595            this.update(cx, |this, cx| {
596                if let Some(existing_context) = this.loaded_text_thread_for_id(&text_thread_id, cx)
597                {
598                    existing_context
599                } else {
600                    this.register_text_thread(&text_thread, cx);
601                    this.synchronize_contexts(cx);
602                    text_thread
603                }
604            })
605        })
606    }
607
608    fn register_text_thread(&mut self, text_thread: &Entity<TextThread>, cx: &mut Context<Self>) {
609        let handle = if self.project_is_shared {
610            TextThreadHandle::Strong(text_thread.clone())
611        } else {
612            TextThreadHandle::Weak(text_thread.downgrade())
613        };
614        self.text_threads.push(handle);
615        self.advertise_contexts(cx);
616        cx.subscribe(text_thread, Self::handle_context_event)
617            .detach();
618    }
619
620    fn handle_context_event(
621        &mut self,
622        text_thread: Entity<TextThread>,
623        event: &TextThreadEvent,
624        cx: &mut Context<Self>,
625    ) {
626        let Some(project) = self.project.upgrade() else {
627            return;
628        };
629        let Some(project_id) = project.read(cx).remote_id() else {
630            return;
631        };
632
633        match event {
634            TextThreadEvent::SummaryChanged => {
635                self.advertise_contexts(cx);
636            }
637            TextThreadEvent::PathChanged { old_path, new_path } => {
638                if let Some(old_path) = old_path.as_ref() {
639                    for metadata in &mut self.text_threads_metadata {
640                        if &metadata.path == old_path {
641                            metadata.path = new_path.clone();
642                            break;
643                        }
644                    }
645                }
646            }
647            TextThreadEvent::Operation(operation) => {
648                let context_id = text_thread.read(cx).id().to_proto();
649                let operation = operation.to_proto();
650                self.client
651                    .send(proto::UpdateContext {
652                        project_id,
653                        context_id,
654                        operation: Some(operation),
655                    })
656                    .log_err();
657            }
658            _ => {}
659        }
660    }
661
662    fn advertise_contexts(&self, cx: &App) {
663        let Some(project) = self.project.upgrade() else {
664            return;
665        };
666        let Some(project_id) = project.read(cx).remote_id() else {
667            return;
668        };
669        // For now, only the host can advertise their open contexts.
670        if project.read(cx).is_via_collab() {
671            return;
672        }
673
674        let contexts = self
675            .text_threads
676            .iter()
677            .rev()
678            .filter_map(|text_thread| {
679                let text_thread = text_thread.upgrade()?.read(cx);
680                if text_thread.replica_id() == ReplicaId::default() {
681                    Some(proto::ContextMetadata {
682                        context_id: text_thread.id().to_proto(),
683                        summary: text_thread
684                            .summary()
685                            .content()
686                            .map(|summary| summary.text.clone()),
687                    })
688                } else {
689                    None
690                }
691            })
692            .collect();
693        self.client
694            .send(proto::AdvertiseContexts {
695                project_id,
696                contexts,
697            })
698            .ok();
699    }
700
701    fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
702        let Some(project) = self.project.upgrade() else {
703            return;
704        };
705        let Some(project_id) = project.read(cx).remote_id() else {
706            return;
707        };
708
709        let text_threads = self
710            .text_threads
711            .iter()
712            .filter_map(|text_thread| {
713                let text_thread = text_thread.upgrade()?.read(cx);
714                if text_thread.replica_id() != ReplicaId::default() {
715                    Some(text_thread.version(cx).to_proto(text_thread.id().clone()))
716                } else {
717                    None
718                }
719            })
720            .collect();
721
722        let client = self.client.clone();
723        let request = self.client.request(proto::SynchronizeContexts {
724            project_id,
725            contexts: text_threads,
726        });
727        cx.spawn(async move |this, cx| {
728            let response = request.await?;
729
730            let mut text_thread_ids = Vec::new();
731            let mut operations = Vec::new();
732            this.read_with(cx, |this, cx| {
733                for context_version_proto in response.contexts {
734                    let text_thread_version = TextThreadVersion::from_proto(&context_version_proto);
735                    let text_thread_id = TextThreadId::from_proto(context_version_proto.context_id);
736                    if let Some(text_thread) = this.loaded_text_thread_for_id(&text_thread_id, cx) {
737                        text_thread_ids.push(text_thread_id);
738                        operations
739                            .push(text_thread.read(cx).serialize_ops(&text_thread_version, cx));
740                    }
741                }
742            })?;
743
744            let operations = futures::future::join_all(operations).await;
745            for (context_id, operations) in text_thread_ids.into_iter().zip(operations) {
746                for operation in operations {
747                    client.send(proto::UpdateContext {
748                        project_id,
749                        context_id: context_id.to_proto(),
750                        operation: Some(operation),
751                    })?;
752                }
753            }
754
755            anyhow::Ok(())
756        })
757        .detach_and_log_err(cx);
758    }
759
760    pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedTextThreadMetadata>> {
761        let metadata = self.text_threads_metadata.clone();
762        let executor = cx.background_executor().clone();
763        cx.background_spawn(async move {
764            if query.is_empty() {
765                metadata
766            } else {
767                let candidates = metadata
768                    .iter()
769                    .enumerate()
770                    .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
771                    .collect::<Vec<_>>();
772                let matches = fuzzy::match_strings(
773                    &candidates,
774                    &query,
775                    false,
776                    true,
777                    100,
778                    &Default::default(),
779                    executor,
780                )
781                .await;
782
783                matches
784                    .into_iter()
785                    .map(|mat| metadata[mat.candidate_id].clone())
786                    .collect()
787            }
788        })
789    }
790
791    fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
792        let fs = self.fs.clone();
793        cx.spawn(async move |this, cx| {
794            if *ZED_STATELESS {
795                return Ok(());
796            }
797            fs.create_dir(text_threads_dir()).await?;
798
799            let mut paths = fs.read_dir(text_threads_dir()).await?;
800            let mut contexts = Vec::<SavedTextThreadMetadata>::new();
801            while let Some(path) = paths.next().await {
802                let path = path?;
803                if path.extension() != Some(OsStr::new("json")) {
804                    continue;
805                }
806
807                static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
808                    LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
809
810                let metadata = fs.metadata(&path).await?;
811                if let Some((file_name, metadata)) = path
812                    .file_name()
813                    .and_then(|name| name.to_str())
814                    .zip(metadata)
815                {
816                    // This is used to filter out contexts saved by the new assistant.
817                    if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
818                        continue;
819                    }
820
821                    if let Some(title) = ASSISTANT_CONTEXT_REGEX
822                        .replace(file_name, "")
823                        .lines()
824                        .next()
825                    {
826                        contexts.push(SavedTextThreadMetadata {
827                            title: title.to_string().into(),
828                            path: path.into(),
829                            mtime: metadata.mtime.timestamp_for_user().into(),
830                        });
831                    }
832                }
833            }
834            contexts.sort_unstable_by_key(|text_thread| Reverse(text_thread.mtime));
835
836            this.update(cx, |this, cx| {
837                this.text_threads_metadata = contexts;
838                cx.notify();
839            })
840        })
841    }
842
843    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
844        let Some(project) = self.project.upgrade() else {
845            return;
846        };
847        let context_server_store = project.read(cx).context_server_store();
848        cx.subscribe(&context_server_store, Self::handle_context_server_event)
849            .detach();
850
851        // Check for any servers that were already running before the handler was registered
852        for server in context_server_store.read(cx).running_servers() {
853            self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
854        }
855    }
856
857    fn handle_context_server_event(
858        &mut self,
859        context_server_store: Entity<ContextServerStore>,
860        event: &project::context_server_store::Event,
861        cx: &mut Context<Self>,
862    ) {
863        match event {
864            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
865                match status {
866                    ContextServerStatus::Running => {
867                        self.load_context_server_slash_commands(
868                            server_id.clone(),
869                            context_server_store,
870                            cx,
871                        );
872                    }
873                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
874                        if let Some(slash_command_ids) =
875                            self.context_server_slash_command_ids.remove(server_id)
876                        {
877                            self.slash_commands.remove(&slash_command_ids);
878                        }
879                    }
880                    _ => {}
881                }
882            }
883        }
884    }
885
886    fn load_context_server_slash_commands(
887        &self,
888        server_id: ContextServerId,
889        context_server_store: Entity<ContextServerStore>,
890        cx: &mut Context<Self>,
891    ) {
892        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
893            return;
894        };
895        let slash_command_working_set = self.slash_commands.clone();
896        cx.spawn(async move |this, cx| {
897            let Some(protocol) = server.client() else {
898                return;
899            };
900
901            if protocol.capable(context_server::protocol::ServerCapability::Prompts)
902                && let Some(response) = protocol
903                    .request::<context_server::types::requests::PromptsList>(())
904                    .await
905                    .log_err()
906            {
907                let slash_command_ids = response
908                    .prompts
909                    .into_iter()
910                    .filter(assistant_slash_commands::acceptable_prompt)
911                    .map(|prompt| {
912                        log::info!("registering context server command: {:?}", prompt.name);
913                        slash_command_working_set.insert(Arc::new(
914                            assistant_slash_commands::ContextServerSlashCommand::new(
915                                context_server_store.clone(),
916                                server.id(),
917                                prompt,
918                            ),
919                        ))
920                    })
921                    .collect::<Vec<_>>();
922
923                this.update(cx, |this, _cx| {
924                    this.context_server_slash_command_ids
925                        .insert(server_id.clone(), slash_command_ids);
926                })
927                .log_err();
928            }
929        })
930        .detach();
931    }
932}