text_thread_store.rs

  1use crate::{
  2    SavedTextThread, SavedTextThreadMetadata, TextThread, TextThreadEvent, TextThreadId,
  3    TextThreadOperation, TextThreadVersion,
  4};
  5use anyhow::{Context as _, Result};
  6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
  7use client::{Client, TypedEnvelope, proto};
  8use clock::ReplicaId;
  9use collections::HashMap;
 10use context_server::ContextServerId;
 11use fs::{Fs, RemoveOptions};
 12use futures::StreamExt;
 13use fuzzy::StringMatchCandidate;
 14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
 15use language::LanguageRegistry;
 16use paths::text_threads_dir;
 17use project::{
 18    Project,
 19    context_server_store::{ContextServerStatus, ContextServerStore},
 20};
 21use prompt_store::PromptBuilder;
 22use regex::Regex;
 23use rpc::AnyProtoClient;
 24use std::sync::LazyLock;
 25use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
 26use util::{ResultExt, TryFutureExt};
 27use zed_env_vars::ZED_STATELESS;
 28
/// Registers the `TextThreadStore` handlers for the text-thread ("context") protocol
/// messages on `client`.
///
/// Open, create, and synchronize are registered as request handlers; advertise and
/// update are registered as message handlers.
pub(crate) fn init(client: &AnyProtoClient) {
    client.add_entity_message_handler(TextThreadStore::handle_advertise_contexts);
    client.add_entity_request_handler(TextThreadStore::handle_open_context);
    client.add_entity_request_handler(TextThreadStore::handle_create_context);
    client.add_entity_message_handler(TextThreadStore::handle_update_context);
    client.add_entity_request_handler(TextThreadStore::handle_synchronize_contexts);
}
 36
/// Metadata describing a text thread that was advertised over the wire.
///
/// Values are built from the entries of a `proto::AdvertiseContexts` message.
#[derive(Clone)]
pub struct RemoteTextThreadMetadata {
    /// Identifier of the advertised text thread.
    pub id: TextThreadId,
    /// Summary text that came with the advertisement, if any.
    pub summary: Option<String>,
}
 42
/// Keeps track of the text threads that belong to a project: the threads loaded in
/// memory, the metadata of threads saved on disk, and the threads advertised by a
/// remote host.
pub struct TextThreadStore {
    /// Handles to every registered text thread. New handles are strong when the
    /// project is shared and weak otherwise.
    text_threads: Vec<TextThreadHandle>,
    /// Metadata of saved text threads, replaced wholesale by `reload` and sorted
    /// with the most recent `mtime` first.
    text_threads_metadata: Vec<SavedTextThreadMetadata>,
    /// Slash command ids keyed by context server id.
    /// NOTE(review): the code that fills this map is outside this view — confirm usage.
    context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
    /// Text threads most recently advertised to this store; cleared when the
    /// project is disconnected from its host.
    host_text_threads: Vec<RemoteTextThreadMetadata>,
    /// File system used to watch, list, load, and delete saved text threads.
    fs: Arc<dyn Fs>,
    /// Language registry handed to every text thread this store creates.
    languages: Arc<LanguageRegistry>,
    /// Slash command working set handed to every text thread this store creates.
    slash_commands: Arc<SlashCommandWorkingSet>,
    /// Background task that calls `reload` for each file-system watch event.
    _watch_updates: Task<Option<()>>,
    /// Client used to send and request text-thread protocol messages.
    client: Arc<Client>,
    /// Weak handle to the owning project.
    project: WeakEntity<Project>,
    /// Last observed value of the project's `is_shared()` state.
    project_is_shared: bool,
    /// Entity subscription held while the project is shared; `None` otherwise.
    client_subscription: Option<client::Subscription>,
    /// Subscriptions to project events, kept alive for the lifetime of the store.
    _project_subscriptions: Vec<gpui::Subscription>,
    /// Prompt builder handed to every text thread this store creates.
    prompt_builder: Arc<PromptBuilder>,
}
 59
/// A handle to a registered text thread that is either weak or strong.
enum TextThreadHandle {
    /// Does not keep the text thread alive; upgrading may fail.
    Weak(WeakEntity<TextThread>),
    /// Keeps the text thread alive for as long as the handle exists.
    Strong(Entity<TextThread>),
}
 64
 65impl TextThreadHandle {
 66    fn upgrade(&self) -> Option<Entity<TextThread>> {
 67        match self {
 68            TextThreadHandle::Weak(weak) => weak.upgrade(),
 69            TextThreadHandle::Strong(strong) => Some(strong.clone()),
 70        }
 71    }
 72
 73    fn downgrade(&self) -> WeakEntity<TextThread> {
 74        match self {
 75            TextThreadHandle::Weak(weak) => weak.clone(),
 76            TextThreadHandle::Strong(strong) => strong.downgrade(),
 77        }
 78    }
 79}
 80
 81impl TextThreadStore {
    /// Creates a `TextThreadStore` for `project`.
    ///
    /// The returned task first starts watching the text threads directory and then
    /// builds the store entity. While the entity is being built, the store picks up
    /// the project's current shared state, requests a synchronization, registers the
    /// context server handlers, and starts an initial `reload` of saved thread
    /// metadata.
    pub fn new(
        project: Entity<Project>,
        prompt_builder: Arc<PromptBuilder>,
        slash_commands: Arc<SlashCommandWorkingSet>,
        cx: &mut App,
    ) -> Task<Result<Entity<Self>>> {
        let fs = project.read(cx).fs().clone();
        let languages = project.read(cx).languages().clone();
        cx.spawn(async move |cx| {
            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
            let (mut events, _) = fs.watch(text_threads_dir(), CONTEXT_WATCH_DURATION).await;

            let this = cx.new(|cx: &mut Context<Self>| {
                let mut this = Self {
                    text_threads: Vec::new(),
                    text_threads_metadata: Vec::new(),
                    context_server_slash_command_ids: HashMap::default(),
                    host_text_threads: Vec::new(),
                    fs,
                    languages,
                    slash_commands,
                    // Reload the saved-thread metadata for every watch event. A failed
                    // reload is logged and the loop continues; failing to reach the
                    // store ends the loop.
                    _watch_updates: cx.spawn(async move |this, cx| {
                        async move {
                            while events.next().await.is_some() {
                                this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
                            }
                            anyhow::Ok(())
                        }
                        .log_err()
                        .await
                    }),
                    client_subscription: None,
                    _project_subscriptions: vec![
                        cx.subscribe(&project, Self::handle_project_event),
                    ],
                    // Starts out `false`; `handle_project_shared` below reconciles it
                    // with the project's actual state.
                    project_is_shared: false,
                    client: project.read(cx).client(),
                    project: project.downgrade(),
                    prompt_builder,
                };
                this.handle_project_shared(cx);
                this.synchronize_contexts(cx);
                this.register_context_server_handlers(cx);
                // The initial reload runs detached; its errors are only logged.
                this.reload(cx).detach_and_log_err(cx);
                this
            })?;

            Ok(this)
        })
    }
132
133    #[cfg(any(test, feature = "test-support"))]
134    pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
135        Self {
136            text_threads: Default::default(),
137            text_threads_metadata: Default::default(),
138            context_server_slash_command_ids: Default::default(),
139            host_text_threads: Default::default(),
140            fs: project.read(cx).fs().clone(),
141            languages: project.read(cx).languages().clone(),
142            slash_commands: Arc::default(),
143            _watch_updates: Task::ready(None),
144            client: project.read(cx).client(),
145            project: project.downgrade(),
146            project_is_shared: false,
147            client_subscription: None,
148            _project_subscriptions: Default::default(),
149            prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
150        }
151    }
152
153    async fn handle_advertise_contexts(
154        this: Entity<Self>,
155        envelope: TypedEnvelope<proto::AdvertiseContexts>,
156        mut cx: AsyncApp,
157    ) -> Result<()> {
158        this.update(&mut cx, |this, cx| {
159            this.host_text_threads = envelope
160                .payload
161                .contexts
162                .into_iter()
163                .map(|text_thread| RemoteTextThreadMetadata {
164                    id: TextThreadId::from_proto(text_thread.context_id),
165                    summary: text_thread.summary,
166                })
167                .collect();
168            cx.notify();
169        })
170    }
171
172    async fn handle_open_context(
173        this: Entity<Self>,
174        envelope: TypedEnvelope<proto::OpenContext>,
175        mut cx: AsyncApp,
176    ) -> Result<proto::OpenContextResponse> {
177        let context_id = TextThreadId::from_proto(envelope.payload.context_id);
178        let operations = this.update(&mut cx, |this, cx| {
179            let project = this.project.upgrade().context("project not found")?;
180
181            anyhow::ensure!(
182                !project.read(cx).is_via_collab(),
183                "only the host contexts can be opened"
184            );
185
186            let text_thread = this
187                .loaded_text_thread_for_id(&context_id, cx)
188                .context("context not found")?;
189            anyhow::ensure!(
190                text_thread.read(cx).replica_id() == ReplicaId::default(),
191                "context must be opened via the host"
192            );
193
194            anyhow::Ok(
195                text_thread
196                    .read(cx)
197                    .serialize_ops(&TextThreadVersion::default(), cx),
198            )
199        })??;
200        let operations = operations.await;
201        Ok(proto::OpenContextResponse {
202            context: Some(proto::Context { operations }),
203        })
204    }
205
206    async fn handle_create_context(
207        this: Entity<Self>,
208        _: TypedEnvelope<proto::CreateContext>,
209        mut cx: AsyncApp,
210    ) -> Result<proto::CreateContextResponse> {
211        let (context_id, operations) = this.update(&mut cx, |this, cx| {
212            let project = this.project.upgrade().context("project not found")?;
213            anyhow::ensure!(
214                !project.read(cx).is_via_collab(),
215                "can only create contexts as the host"
216            );
217
218            let text_thread = this.create(cx);
219            let context_id = text_thread.read(cx).id().clone();
220
221            anyhow::Ok((
222                context_id,
223                text_thread
224                    .read(cx)
225                    .serialize_ops(&TextThreadVersion::default(), cx),
226            ))
227        })??;
228        let operations = operations.await;
229        Ok(proto::CreateContextResponse {
230            context_id: context_id.to_proto(),
231            context: Some(proto::Context { operations }),
232        })
233    }
234
235    async fn handle_update_context(
236        this: Entity<Self>,
237        envelope: TypedEnvelope<proto::UpdateContext>,
238        mut cx: AsyncApp,
239    ) -> Result<()> {
240        this.update(&mut cx, |this, cx| {
241            let context_id = TextThreadId::from_proto(envelope.payload.context_id);
242            if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
243                let operation_proto = envelope.payload.operation.context("invalid operation")?;
244                let operation = TextThreadOperation::from_proto(operation_proto)?;
245                text_thread.update(cx, |text_thread, cx| text_thread.apply_ops([operation], cx));
246            }
247            Ok(())
248        })?
249    }
250
251    async fn handle_synchronize_contexts(
252        this: Entity<Self>,
253        envelope: TypedEnvelope<proto::SynchronizeContexts>,
254        mut cx: AsyncApp,
255    ) -> Result<proto::SynchronizeContextsResponse> {
256        this.update(&mut cx, |this, cx| {
257            let project = this.project.upgrade().context("project not found")?;
258            anyhow::ensure!(
259                !project.read(cx).is_via_collab(),
260                "only the host can synchronize contexts"
261            );
262
263            let mut local_versions = Vec::new();
264            for remote_version_proto in envelope.payload.contexts {
265                let remote_version = TextThreadVersion::from_proto(&remote_version_proto);
266                let context_id = TextThreadId::from_proto(remote_version_proto.context_id);
267                if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
268                    let text_thread = text_thread.read(cx);
269                    let operations = text_thread.serialize_ops(&remote_version, cx);
270                    local_versions.push(text_thread.version(cx).to_proto(context_id.clone()));
271                    let client = this.client.clone();
272                    let project_id = envelope.payload.project_id;
273                    cx.background_spawn(async move {
274                        let operations = operations.await;
275                        for operation in operations {
276                            client.send(proto::UpdateContext {
277                                project_id,
278                                context_id: context_id.to_proto(),
279                                operation: Some(operation),
280                            })?;
281                        }
282                        anyhow::Ok(())
283                    })
284                    .detach_and_log_err(cx);
285                }
286            }
287
288            this.advertise_contexts(cx);
289
290            anyhow::Ok(proto::SynchronizeContextsResponse {
291                contexts: local_versions,
292            })
293        })?
294    }
295
296    fn handle_project_shared(&mut self, cx: &mut Context<Self>) {
297        let Some(project) = self.project.upgrade() else {
298            return;
299        };
300
301        let is_shared = project.read(cx).is_shared();
302        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
303        if is_shared == was_shared {
304            return;
305        }
306
307        if is_shared {
308            self.text_threads.retain_mut(|text_thread| {
309                if let Some(strong_context) = text_thread.upgrade() {
310                    *text_thread = TextThreadHandle::Strong(strong_context);
311                    true
312                } else {
313                    false
314                }
315            });
316            let remote_id = project.read(cx).remote_id().unwrap();
317            self.client_subscription = self
318                .client
319                .subscribe_to_entity(remote_id)
320                .log_err()
321                .map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async()));
322            self.advertise_contexts(cx);
323        } else {
324            self.client_subscription = None;
325        }
326    }
327
328    fn handle_project_event(
329        &mut self,
330        _project: Entity<Project>,
331        event: &project::Event,
332        cx: &mut Context<Self>,
333    ) {
334        match event {
335            project::Event::RemoteIdChanged(_) => {
336                self.handle_project_shared(cx);
337            }
338            project::Event::Reshared => {
339                self.advertise_contexts(cx);
340            }
341            project::Event::HostReshared | project::Event::Rejoined => {
342                self.synchronize_contexts(cx);
343            }
344            project::Event::DisconnectedFromHost => {
345                self.text_threads.retain_mut(|text_thread| {
346                    if let Some(strong_context) = text_thread.upgrade() {
347                        *text_thread = TextThreadHandle::Weak(text_thread.downgrade());
348                        strong_context.update(cx, |text_thread, cx| {
349                            if text_thread.replica_id() != ReplicaId::default() {
350                                text_thread.set_capability(language::Capability::ReadOnly, cx);
351                            }
352                        });
353                        true
354                    } else {
355                        false
356                    }
357                });
358                self.host_text_threads.clear();
359                cx.notify();
360            }
361            _ => {}
362        }
363    }
364
    /// Iterates over the metadata of the saved text threads, in the order in which
    /// the store currently holds them.
    pub fn unordered_text_threads(&self) -> impl Iterator<Item = &SavedTextThreadMetadata> {
        self.text_threads_metadata.iter()
    }
368
    /// Iterates over the text threads most recently received through an
    /// `AdvertiseContexts` message.
    pub fn host_text_threads(&self) -> impl Iterator<Item = &RemoteTextThreadMetadata> {
        self.host_text_threads.iter()
    }
372
373    pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<TextThread> {
374        let context = cx.new(|cx| {
375            TextThread::local(
376                self.languages.clone(),
377                Some(self.project.clone()),
378                self.prompt_builder.clone(),
379                self.slash_commands.clone(),
380                cx,
381            )
382        });
383        self.register_text_thread(&context, cx);
384        context
385    }
386
387    pub fn create_remote(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<TextThread>>> {
388        let Some(project) = self.project.upgrade() else {
389            return Task::ready(Err(anyhow::anyhow!("project was dropped")));
390        };
391        let project = project.read(cx);
392        let Some(project_id) = project.remote_id() else {
393            return Task::ready(Err(anyhow::anyhow!("project was not remote")));
394        };
395
396        let replica_id = project.replica_id();
397        let capability = project.capability();
398        let language_registry = self.languages.clone();
399        let project = self.project.clone();
400
401        let prompt_builder = self.prompt_builder.clone();
402        let slash_commands = self.slash_commands.clone();
403        let request = self.client.request(proto::CreateContext { project_id });
404        cx.spawn(async move |this, cx| {
405            let response = request.await?;
406            let context_id = TextThreadId::from_proto(response.context_id);
407            let context_proto = response.context.context("invalid context")?;
408            let text_thread = cx.new(|cx| {
409                TextThread::new(
410                    context_id.clone(),
411                    replica_id,
412                    capability,
413                    language_registry,
414                    prompt_builder,
415                    slash_commands,
416                    Some(project),
417                    cx,
418                )
419            })?;
420            let operations = cx
421                .background_spawn(async move {
422                    context_proto
423                        .operations
424                        .into_iter()
425                        .map(TextThreadOperation::from_proto)
426                        .collect::<Result<Vec<_>>>()
427                })
428                .await?;
429            text_thread.update(cx, |context, cx| context.apply_ops(operations, cx))?;
430            this.update(cx, |this, cx| {
431                if let Some(existing_context) = this.loaded_text_thread_for_id(&context_id, cx) {
432                    existing_context
433                } else {
434                    this.register_text_thread(&text_thread, cx);
435                    this.synchronize_contexts(cx);
436                    text_thread
437                }
438            })
439        })
440    }
441
442    pub fn open_local(
443        &mut self,
444        path: Arc<Path>,
445        cx: &Context<Self>,
446    ) -> Task<Result<Entity<TextThread>>> {
447        if let Some(existing_context) = self.loaded_text_thread_for_path(&path, cx) {
448            return Task::ready(Ok(existing_context));
449        }
450
451        let fs = self.fs.clone();
452        let languages = self.languages.clone();
453        let project = self.project.clone();
454        let load = cx.background_spawn({
455            let path = path.clone();
456            async move {
457                let saved_context = fs.load(&path).await?;
458                SavedTextThread::from_json(&saved_context)
459            }
460        });
461        let prompt_builder = self.prompt_builder.clone();
462        let slash_commands = self.slash_commands.clone();
463
464        cx.spawn(async move |this, cx| {
465            let saved_context = load.await?;
466            let context = cx.new(|cx| {
467                TextThread::deserialize(
468                    saved_context,
469                    path.clone(),
470                    languages,
471                    prompt_builder,
472                    slash_commands,
473                    Some(project),
474                    cx,
475                )
476            })?;
477            this.update(cx, |this, cx| {
478                if let Some(existing_context) = this.loaded_text_thread_for_path(&path, cx) {
479                    existing_context
480                } else {
481                    this.register_text_thread(&context, cx);
482                    context
483                }
484            })
485        })
486    }
487
488    pub fn delete_local(&mut self, path: Arc<Path>, cx: &mut Context<Self>) -> Task<Result<()>> {
489        let fs = self.fs.clone();
490
491        cx.spawn(async move |this, cx| {
492            fs.remove_file(
493                &path,
494                RemoveOptions {
495                    recursive: false,
496                    ignore_if_not_exists: true,
497                },
498            )
499            .await?;
500
501            this.update(cx, |this, cx| {
502                this.text_threads.retain(|text_thread| {
503                    text_thread
504                        .upgrade()
505                        .and_then(|text_thread| text_thread.read(cx).path())
506                        != Some(&path)
507                });
508                this.text_threads_metadata
509                    .retain(|text_thread| text_thread.path.as_ref() != path.as_ref());
510            })?;
511
512            Ok(())
513        })
514    }
515
516    fn loaded_text_thread_for_path(&self, path: &Path, cx: &App) -> Option<Entity<TextThread>> {
517        self.text_threads.iter().find_map(|text_thread| {
518            let text_thread = text_thread.upgrade()?;
519            if text_thread.read(cx).path().map(Arc::as_ref) == Some(path) {
520                Some(text_thread)
521            } else {
522                None
523            }
524        })
525    }
526
527    pub fn loaded_text_thread_for_id(
528        &self,
529        id: &TextThreadId,
530        cx: &App,
531    ) -> Option<Entity<TextThread>> {
532        self.text_threads.iter().find_map(|text_thread| {
533            let text_thread = text_thread.upgrade()?;
534            if text_thread.read(cx).id() == id {
535                Some(text_thread)
536            } else {
537                None
538            }
539        })
540    }
541
542    pub fn open_remote(
543        &mut self,
544        text_thread_id: TextThreadId,
545        cx: &mut Context<Self>,
546    ) -> Task<Result<Entity<TextThread>>> {
547        let Some(project) = self.project.upgrade() else {
548            return Task::ready(Err(anyhow::anyhow!("project was dropped")));
549        };
550        let project = project.read(cx);
551        let Some(project_id) = project.remote_id() else {
552            return Task::ready(Err(anyhow::anyhow!("project was not remote")));
553        };
554
555        if let Some(context) = self.loaded_text_thread_for_id(&text_thread_id, cx) {
556            return Task::ready(Ok(context));
557        }
558
559        let replica_id = project.replica_id();
560        let capability = project.capability();
561        let language_registry = self.languages.clone();
562        let project = self.project.clone();
563        let request = self.client.request(proto::OpenContext {
564            project_id,
565            context_id: text_thread_id.to_proto(),
566        });
567        let prompt_builder = self.prompt_builder.clone();
568        let slash_commands = self.slash_commands.clone();
569        cx.spawn(async move |this, cx| {
570            let response = request.await?;
571            let context_proto = response.context.context("invalid context")?;
572            let text_thread = cx.new(|cx| {
573                TextThread::new(
574                    text_thread_id.clone(),
575                    replica_id,
576                    capability,
577                    language_registry,
578                    prompt_builder,
579                    slash_commands,
580                    Some(project),
581                    cx,
582                )
583            })?;
584            let operations = cx
585                .background_spawn(async move {
586                    context_proto
587                        .operations
588                        .into_iter()
589                        .map(TextThreadOperation::from_proto)
590                        .collect::<Result<Vec<_>>>()
591                })
592                .await?;
593            text_thread.update(cx, |context, cx| context.apply_ops(operations, cx))?;
594            this.update(cx, |this, cx| {
595                if let Some(existing_context) = this.loaded_text_thread_for_id(&text_thread_id, cx)
596                {
597                    existing_context
598                } else {
599                    this.register_text_thread(&text_thread, cx);
600                    this.synchronize_contexts(cx);
601                    text_thread
602                }
603            })
604        })
605    }
606
607    fn register_text_thread(&mut self, text_thread: &Entity<TextThread>, cx: &mut Context<Self>) {
608        let handle = if self.project_is_shared {
609            TextThreadHandle::Strong(text_thread.clone())
610        } else {
611            TextThreadHandle::Weak(text_thread.downgrade())
612        };
613        self.text_threads.push(handle);
614        self.advertise_contexts(cx);
615        cx.subscribe(text_thread, Self::handle_context_event)
616            .detach();
617    }
618
619    fn handle_context_event(
620        &mut self,
621        text_thread: Entity<TextThread>,
622        event: &TextThreadEvent,
623        cx: &mut Context<Self>,
624    ) {
625        let Some(project) = self.project.upgrade() else {
626            return;
627        };
628        let Some(project_id) = project.read(cx).remote_id() else {
629            return;
630        };
631
632        match event {
633            TextThreadEvent::SummaryChanged => {
634                self.advertise_contexts(cx);
635            }
636            TextThreadEvent::PathChanged { old_path, new_path } => {
637                if let Some(old_path) = old_path.as_ref() {
638                    for metadata in &mut self.text_threads_metadata {
639                        if &metadata.path == old_path {
640                            metadata.path = new_path.clone();
641                            break;
642                        }
643                    }
644                }
645            }
646            TextThreadEvent::Operation(operation) => {
647                let context_id = text_thread.read(cx).id().to_proto();
648                let operation = operation.to_proto();
649                self.client
650                    .send(proto::UpdateContext {
651                        project_id,
652                        context_id,
653                        operation: Some(operation),
654                    })
655                    .log_err();
656            }
657            _ => {}
658        }
659    }
660
661    fn advertise_contexts(&self, cx: &App) {
662        let Some(project) = self.project.upgrade() else {
663            return;
664        };
665        let Some(project_id) = project.read(cx).remote_id() else {
666            return;
667        };
668        // For now, only the host can advertise their open contexts.
669        if project.read(cx).is_via_collab() {
670            return;
671        }
672
673        let contexts = self
674            .text_threads
675            .iter()
676            .rev()
677            .filter_map(|text_thread| {
678                let text_thread = text_thread.upgrade()?.read(cx);
679                if text_thread.replica_id() == ReplicaId::default() {
680                    Some(proto::ContextMetadata {
681                        context_id: text_thread.id().to_proto(),
682                        summary: text_thread
683                            .summary()
684                            .content()
685                            .map(|summary| summary.text.clone()),
686                    })
687                } else {
688                    None
689                }
690            })
691            .collect();
692        self.client
693            .send(proto::AdvertiseContexts {
694                project_id,
695                contexts,
696            })
697            .ok();
698    }
699
700    fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
701        let Some(project) = self.project.upgrade() else {
702            return;
703        };
704        let Some(project_id) = project.read(cx).remote_id() else {
705            return;
706        };
707
708        let text_threads = self
709            .text_threads
710            .iter()
711            .filter_map(|text_thread| {
712                let text_thread = text_thread.upgrade()?.read(cx);
713                if text_thread.replica_id() != ReplicaId::default() {
714                    Some(text_thread.version(cx).to_proto(text_thread.id().clone()))
715                } else {
716                    None
717                }
718            })
719            .collect();
720
721        let client = self.client.clone();
722        let request = self.client.request(proto::SynchronizeContexts {
723            project_id,
724            contexts: text_threads,
725        });
726        cx.spawn(async move |this, cx| {
727            let response = request.await?;
728
729            let mut text_thread_ids = Vec::new();
730            let mut operations = Vec::new();
731            this.read_with(cx, |this, cx| {
732                for context_version_proto in response.contexts {
733                    let text_thread_version = TextThreadVersion::from_proto(&context_version_proto);
734                    let text_thread_id = TextThreadId::from_proto(context_version_proto.context_id);
735                    if let Some(text_thread) = this.loaded_text_thread_for_id(&text_thread_id, cx) {
736                        text_thread_ids.push(text_thread_id);
737                        operations
738                            .push(text_thread.read(cx).serialize_ops(&text_thread_version, cx));
739                    }
740                }
741            })?;
742
743            let operations = futures::future::join_all(operations).await;
744            for (context_id, operations) in text_thread_ids.into_iter().zip(operations) {
745                for operation in operations {
746                    client.send(proto::UpdateContext {
747                        project_id,
748                        context_id: context_id.to_proto(),
749                        operation: Some(operation),
750                    })?;
751                }
752            }
753
754            anyhow::Ok(())
755        })
756        .detach_and_log_err(cx);
757    }
758
759    pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedTextThreadMetadata>> {
760        let metadata = self.text_threads_metadata.clone();
761        let executor = cx.background_executor().clone();
762        cx.background_spawn(async move {
763            if query.is_empty() {
764                metadata
765            } else {
766                let candidates = metadata
767                    .iter()
768                    .enumerate()
769                    .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
770                    .collect::<Vec<_>>();
771                let matches = fuzzy::match_strings(
772                    &candidates,
773                    &query,
774                    false,
775                    true,
776                    100,
777                    &Default::default(),
778                    executor,
779                )
780                .await;
781
782                matches
783                    .into_iter()
784                    .map(|mat| metadata[mat.candidate_id].clone())
785                    .collect()
786            }
787        })
788    }
789
790    fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
791        let fs = self.fs.clone();
792        cx.spawn(async move |this, cx| {
793            if *ZED_STATELESS {
794                return Ok(());
795            }
796            fs.create_dir(text_threads_dir()).await?;
797
798            let mut paths = fs.read_dir(text_threads_dir()).await?;
799            let mut contexts = Vec::<SavedTextThreadMetadata>::new();
800            while let Some(path) = paths.next().await {
801                let path = path?;
802                if path.extension() != Some(OsStr::new("json")) {
803                    continue;
804                }
805
806                static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
807                    LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
808
809                let metadata = fs.metadata(&path).await?;
810                if let Some((file_name, metadata)) = path
811                    .file_name()
812                    .and_then(|name| name.to_str())
813                    .zip(metadata)
814                {
815                    // This is used to filter out contexts saved by the new assistant.
816                    if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
817                        continue;
818                    }
819
820                    if let Some(title) = ASSISTANT_CONTEXT_REGEX
821                        .replace(file_name, "")
822                        .lines()
823                        .next()
824                    {
825                        contexts.push(SavedTextThreadMetadata {
826                            title: title.to_string().into(),
827                            path: path.into(),
828                            mtime: metadata.mtime.timestamp_for_user().into(),
829                        });
830                    }
831                }
832            }
833            contexts.sort_unstable_by_key(|text_thread| Reverse(text_thread.mtime));
834
835            this.update(cx, |this, cx| {
836                this.text_threads_metadata = contexts;
837                cx.notify();
838            })
839        })
840    }
841
842    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
843        let Some(project) = self.project.upgrade() else {
844            return;
845        };
846        let context_server_store = project.read(cx).context_server_store();
847        cx.subscribe(&context_server_store, Self::handle_context_server_event)
848            .detach();
849
850        // Check for any servers that were already running before the handler was registered
851        for server in context_server_store.read(cx).running_servers() {
852            self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
853        }
854    }
855
856    fn handle_context_server_event(
857        &mut self,
858        context_server_store: Entity<ContextServerStore>,
859        event: &project::context_server_store::Event,
860        cx: &mut Context<Self>,
861    ) {
862        match event {
863            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
864                match status {
865                    ContextServerStatus::Running => {
866                        self.load_context_server_slash_commands(
867                            server_id.clone(),
868                            context_server_store,
869                            cx,
870                        );
871                    }
872                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
873                        if let Some(slash_command_ids) =
874                            self.context_server_slash_command_ids.remove(server_id)
875                        {
876                            self.slash_commands.remove(&slash_command_ids);
877                        }
878                    }
879                    _ => {}
880                }
881            }
882        }
883    }
884
885    fn load_context_server_slash_commands(
886        &self,
887        server_id: ContextServerId,
888        context_server_store: Entity<ContextServerStore>,
889        cx: &mut Context<Self>,
890    ) {
891        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
892            return;
893        };
894        let slash_command_working_set = self.slash_commands.clone();
895        cx.spawn(async move |this, cx| {
896            let Some(protocol) = server.client() else {
897                return;
898            };
899
900            if protocol.capable(context_server::protocol::ServerCapability::Prompts)
901                && let Some(response) = protocol
902                    .request::<context_server::types::requests::PromptsList>(())
903                    .await
904                    .log_err()
905            {
906                let slash_command_ids = response
907                    .prompts
908                    .into_iter()
909                    .filter(assistant_slash_commands::acceptable_prompt)
910                    .map(|prompt| {
911                        log::info!("registering context server command: {:?}", prompt.name);
912                        slash_command_working_set.insert(Arc::new(
913                            assistant_slash_commands::ContextServerSlashCommand::new(
914                                context_server_store.clone(),
915                                server.id(),
916                                prompt,
917                            ),
918                        ))
919                    })
920                    .collect::<Vec<_>>();
921
922                this.update(cx, |this, _cx| {
923                    this.context_server_slash_command_ids
924                        .insert(server_id.clone(), slash_command_ids);
925                })
926                .log_err();
927            }
928        })
929        .detach();
930    }
931}