context_store.rs

  1use crate::slash_command::context_server_command;
  2use crate::{
  3    prompts::PromptBuilder, slash_command_working_set::SlashCommandWorkingSet, Context,
  4    ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, SavedContextMetadata,
  5};
  6use crate::{tools, SlashCommandId, ToolId, ToolWorkingSet};
  7use anyhow::{anyhow, Context as _, Result};
  8use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
  9use clock::ReplicaId;
 10use collections::HashMap;
 11use command_palette_hooks::CommandPaletteFilter;
 12use context_servers::manager::{ContextServerManager, ContextServerSettings};
 13use context_servers::CONTEXT_SERVERS_NAMESPACE;
 14use fs::Fs;
 15use futures::StreamExt;
 16use fuzzy::StringMatchCandidate;
 17use gpui::{
 18    AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
 19};
 20use language::LanguageRegistry;
 21use paths::contexts_dir;
 22use project::Project;
 23use regex::Regex;
 24use rpc::AnyProtoClient;
 25use settings::{Settings as _, SettingsStore};
 26use std::{
 27    cmp::Reverse,
 28    ffi::OsStr,
 29    mem,
 30    path::{Path, PathBuf},
 31    sync::Arc,
 32    time::Duration,
 33};
 34use util::{ResultExt, TryFutureExt};
 35
 36pub fn init(client: &AnyProtoClient) {
 37    client.add_model_message_handler(ContextStore::handle_advertise_contexts);
 38    client.add_model_request_handler(ContextStore::handle_open_context);
 39    client.add_model_request_handler(ContextStore::handle_create_context);
 40    client.add_model_message_handler(ContextStore::handle_update_context);
 41    client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
 42}
 43
/// Metadata for a context that the host of a shared project has advertised to this guest.
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Identifier of the context on the host.
    pub id: ContextId,
    /// The context's summary text, if the host had one when it advertised the context.
    pub summary: Option<String>,
}
 49
/// Tracks the assistant contexts belonging to a project.
///
/// This covers the contexts that are currently open, the contexts saved on disk, the contexts
/// advertised by a collaboration host, and the context servers managed on the project's behalf.
pub struct ContextStore {
    /// Open contexts. They are held strongly while the project is shared and weakly otherwise
    /// (see `register_context` / `handle_project_changed`).
    contexts: Vec<ContextHandle>,
    /// Saved contexts found in the contexts directory by `reload`, newest first.
    contexts_metadata: Vec<SavedContextMetadata>,
    /// Manager for the context servers configured in settings.
    context_server_manager: Model<ContextServerManager>,
    /// Slash commands registered per context server, keyed by server id
    /// (presumably so they can be unregistered later — TODO confirm).
    context_server_slash_command_ids: HashMap<String, Vec<SlashCommandId>>,
    /// Tools registered per context server, keyed by server id
    /// (presumably so they can be unregistered later — TODO confirm).
    context_server_tool_ids: HashMap<String, Vec<ToolId>>,
    /// Contexts most recently advertised by the host; cleared on disconnect.
    host_contexts: Vec<RemoteContextMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    tools: Arc<ToolWorkingSet>,
    telemetry: Arc<Telemetry>,
    /// Background task that reloads `contexts_metadata` whenever the contexts directory changes.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Model<Project>,
    /// Sharing state last observed by `handle_project_changed`.
    project_is_shared: bool,
    /// RPC subscription for the project's remote id; only set while the project is shared.
    client_subscription: Option<client::Subscription>,
    /// Keeps the project observation/event subscriptions alive.
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
 70
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// A context with this id was created on this host.
    ///
    /// In the code shown here, it is emitted when a collaborator sends a `CreateContext` request.
    ContextCreated(ContextId),
}

impl EventEmitter<ContextStoreEvent> for ContextStore {}
 76
/// How the store holds on to an open context.
///
/// `register_context` and `handle_project_changed` use strong handles while the project is
/// shared. Otherwise they use weak handles, so a context can be released once nothing else
/// references it.
enum ContextHandle {
    /// A weak reference; the context may already have been released.
    Weak(WeakModel<Context>),
    /// A strong reference that keeps the context alive.
    Strong(Model<Context>),
}
 81
 82impl ContextHandle {
 83    fn upgrade(&self) -> Option<Model<Context>> {
 84        match self {
 85            ContextHandle::Weak(weak) => weak.upgrade(),
 86            ContextHandle::Strong(strong) => Some(strong.clone()),
 87        }
 88    }
 89
 90    fn downgrade(&self) -> WeakModel<Context> {
 91        match self {
 92            ContextHandle::Weak(weak) => weak.clone(),
 93            ContextHandle::Strong(strong) => strong.downgrade(),
 94        }
 95    }
 96}
 97
 98impl ContextStore {
    /// Creates a [`ContextStore`] for `project`.
    ///
    /// The returned task resolves once the store's model exists and an initial scan of the
    /// saved-contexts directory has been attempted.
    ///
    /// The store reuses the project's filesystem, language registry and telemetry. It also:
    /// - watches [`contexts_dir`] and reloads the saved-context list on every watcher event;
    /// - observes and subscribes to the project;
    /// - starts observing the global settings so that context servers follow configuration.
    ///
    /// # Errors
    /// Returns an error if creating the model on the async app context fails (for example, if
    /// the app has gone away). Failures of the initial reload and of the settings-watch setup
    /// are only logged.
    pub fn new(
        project: Model<Project>,
        prompt_builder: Arc<PromptBuilder>,
        slash_commands: Arc<SlashCommandWorkingSet>,
        tools: Arc<ToolWorkingSet>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let fs = project.read(cx).fs().clone();
        let languages = project.read(cx).languages().clone();
        let telemetry = project.read(cx).client().telemetry().clone();
        cx.spawn(|mut cx| async move {
            // Duration handed to `Fs::watch` for the contexts directory
            // (presumably the watcher's event latency — TODO confirm).
            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
            let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;

            let this = cx.new_model(|cx: &mut ModelContext<Self>| {
                let context_server_manager = cx.new_model(|_cx| ContextServerManager::new());
                let mut this = Self {
                    contexts: Vec::new(),
                    contexts_metadata: Vec::new(),
                    context_server_manager,
                    context_server_slash_command_ids: HashMap::default(),
                    context_server_tool_ids: HashMap::default(),
                    host_contexts: Vec::new(),
                    fs,
                    languages,
                    slash_commands,
                    tools,
                    telemetry,
                    // Reload the saved-context list on every watcher event. The loop stops when
                    // the event stream ends or when `this.update` fails because the store is gone.
                    _watch_updates: cx.spawn(|this, mut cx| {
                        async move {
                            while events.next().await.is_some() {
                                this.update(&mut cx, |this, cx| this.reload(cx))?
                                    .await
                                    .log_err();
                            }
                            anyhow::Ok(())
                        }
                        .log_err()
                    }),
                    client_subscription: None,
                    _project_subscriptions: vec![
                        cx.observe(&project, Self::handle_project_changed),
                        cx.subscribe(&project, Self::handle_project_event),
                    ],
                    project_is_shared: false,
                    client: project.read(cx).client(),
                    project: project.clone(),
                    prompt_builder,
                };
                // Apply the project's current sharing state now instead of waiting for the
                // first change notification.
                this.handle_project_changed(project, cx);
                this.synchronize_contexts(cx);
                this.register_context_server_handlers(cx);
                this
            })?;
            // Populate `contexts_metadata` before handing the store back.
            this.update(&mut cx, |this, cx| this.reload(cx))?
                .await
                .log_err();

            this.update(&mut cx, |this, cx| {
                this.watch_context_server_settings(cx);
            })
            .log_err();

            Ok(this)
        })
    }
165
166    fn watch_context_server_settings(&self, cx: &mut ModelContext<Self>) {
167        cx.observe_global::<SettingsStore>(move |this, cx| {
168            this.context_server_manager.update(cx, |manager, cx| {
169                let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
170                    settings::SettingsLocation {
171                        worktree_id: worktree.read(cx).id(),
172                        path: Path::new(""),
173                    }
174                });
175                let settings = ContextServerSettings::get(location, cx);
176
177                manager.maintain_servers(settings, cx);
178
179                let has_any_context_servers = !manager.servers().is_empty();
180                CommandPaletteFilter::update_global(cx, |filter, _cx| {
181                    if has_any_context_servers {
182                        filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
183                    } else {
184                        filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
185                    }
186                });
187            })
188        })
189        .detach();
190    }
191
192    async fn handle_advertise_contexts(
193        this: Model<Self>,
194        envelope: TypedEnvelope<proto::AdvertiseContexts>,
195        mut cx: AsyncAppContext,
196    ) -> Result<()> {
197        this.update(&mut cx, |this, cx| {
198            this.host_contexts = envelope
199                .payload
200                .contexts
201                .into_iter()
202                .map(|context| RemoteContextMetadata {
203                    id: ContextId::from_proto(context.context_id),
204                    summary: context.summary,
205                })
206                .collect();
207            cx.notify();
208        })
209    }
210
211    async fn handle_open_context(
212        this: Model<Self>,
213        envelope: TypedEnvelope<proto::OpenContext>,
214        mut cx: AsyncAppContext,
215    ) -> Result<proto::OpenContextResponse> {
216        let context_id = ContextId::from_proto(envelope.payload.context_id);
217        let operations = this.update(&mut cx, |this, cx| {
218            if this.project.read(cx).is_via_collab() {
219                return Err(anyhow!("only the host contexts can be opened"));
220            }
221
222            let context = this
223                .loaded_context_for_id(&context_id, cx)
224                .context("context not found")?;
225            if context.read(cx).replica_id() != ReplicaId::default() {
226                return Err(anyhow!("context must be opened via the host"));
227            }
228
229            anyhow::Ok(
230                context
231                    .read(cx)
232                    .serialize_ops(&ContextVersion::default(), cx),
233            )
234        })??;
235        let operations = operations.await;
236        Ok(proto::OpenContextResponse {
237            context: Some(proto::Context { operations }),
238        })
239    }
240
241    async fn handle_create_context(
242        this: Model<Self>,
243        _: TypedEnvelope<proto::CreateContext>,
244        mut cx: AsyncAppContext,
245    ) -> Result<proto::CreateContextResponse> {
246        let (context_id, operations) = this.update(&mut cx, |this, cx| {
247            if this.project.read(cx).is_via_collab() {
248                return Err(anyhow!("can only create contexts as the host"));
249            }
250
251            let context = this.create(cx);
252            let context_id = context.read(cx).id().clone();
253            cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
254
255            anyhow::Ok((
256                context_id,
257                context
258                    .read(cx)
259                    .serialize_ops(&ContextVersion::default(), cx),
260            ))
261        })??;
262        let operations = operations.await;
263        Ok(proto::CreateContextResponse {
264            context_id: context_id.to_proto(),
265            context: Some(proto::Context { operations }),
266        })
267    }
268
269    async fn handle_update_context(
270        this: Model<Self>,
271        envelope: TypedEnvelope<proto::UpdateContext>,
272        mut cx: AsyncAppContext,
273    ) -> Result<()> {
274        this.update(&mut cx, |this, cx| {
275            let context_id = ContextId::from_proto(envelope.payload.context_id);
276            if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
277                let operation_proto = envelope.payload.operation.context("invalid operation")?;
278                let operation = ContextOperation::from_proto(operation_proto)?;
279                context.update(cx, |context, cx| context.apply_ops([operation], cx));
280            }
281            Ok(())
282        })?
283    }
284
285    async fn handle_synchronize_contexts(
286        this: Model<Self>,
287        envelope: TypedEnvelope<proto::SynchronizeContexts>,
288        mut cx: AsyncAppContext,
289    ) -> Result<proto::SynchronizeContextsResponse> {
290        this.update(&mut cx, |this, cx| {
291            if this.project.read(cx).is_via_collab() {
292                return Err(anyhow!("only the host can synchronize contexts"));
293            }
294
295            let mut local_versions = Vec::new();
296            for remote_version_proto in envelope.payload.contexts {
297                let remote_version = ContextVersion::from_proto(&remote_version_proto);
298                let context_id = ContextId::from_proto(remote_version_proto.context_id);
299                if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
300                    let context = context.read(cx);
301                    let operations = context.serialize_ops(&remote_version, cx);
302                    local_versions.push(context.version(cx).to_proto(context_id.clone()));
303                    let client = this.client.clone();
304                    let project_id = envelope.payload.project_id;
305                    cx.background_executor()
306                        .spawn(async move {
307                            let operations = operations.await;
308                            for operation in operations {
309                                client.send(proto::UpdateContext {
310                                    project_id,
311                                    context_id: context_id.to_proto(),
312                                    operation: Some(operation),
313                                })?;
314                            }
315                            anyhow::Ok(())
316                        })
317                        .detach_and_log_err(cx);
318                }
319            }
320
321            this.advertise_contexts(cx);
322
323            anyhow::Ok(proto::SynchronizeContextsResponse {
324                contexts: local_versions,
325            })
326        })?
327    }
328
329    fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
330        let is_shared = self.project.read(cx).is_shared();
331        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
332        if is_shared == was_shared {
333            return;
334        }
335
336        if is_shared {
337            self.contexts.retain_mut(|context| {
338                if let Some(strong_context) = context.upgrade() {
339                    *context = ContextHandle::Strong(strong_context);
340                    true
341                } else {
342                    false
343                }
344            });
345            let remote_id = self.project.read(cx).remote_id().unwrap();
346            self.client_subscription = self
347                .client
348                .subscribe_to_entity(remote_id)
349                .log_err()
350                .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
351            self.advertise_contexts(cx);
352        } else {
353            self.client_subscription = None;
354        }
355    }
356
357    fn handle_project_event(
358        &mut self,
359        _: Model<Project>,
360        event: &project::Event,
361        cx: &mut ModelContext<Self>,
362    ) {
363        match event {
364            project::Event::Reshared => {
365                self.advertise_contexts(cx);
366            }
367            project::Event::HostReshared | project::Event::Rejoined => {
368                self.synchronize_contexts(cx);
369            }
370            project::Event::DisconnectedFromHost => {
371                self.contexts.retain_mut(|context| {
372                    if let Some(strong_context) = context.upgrade() {
373                        *context = ContextHandle::Weak(context.downgrade());
374                        strong_context.update(cx, |context, cx| {
375                            if context.replica_id() != ReplicaId::default() {
376                                context.set_capability(language::Capability::ReadOnly, cx);
377                            }
378                        });
379                        true
380                    } else {
381                        false
382                    }
383                });
384                self.host_contexts.clear();
385                cx.notify();
386            }
387            _ => {}
388        }
389    }
390
391    pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
392        let context = cx.new_model(|cx| {
393            Context::local(
394                self.languages.clone(),
395                Some(self.project.clone()),
396                Some(self.telemetry.clone()),
397                self.prompt_builder.clone(),
398                self.slash_commands.clone(),
399                self.tools.clone(),
400                cx,
401            )
402        });
403        self.register_context(&context, cx);
404        context
405    }
406
407    pub fn create_remote_context(
408        &mut self,
409        cx: &mut ModelContext<Self>,
410    ) -> Task<Result<Model<Context>>> {
411        let project = self.project.read(cx);
412        let Some(project_id) = project.remote_id() else {
413            return Task::ready(Err(anyhow!("project was not remote")));
414        };
415
416        let replica_id = project.replica_id();
417        let capability = project.capability();
418        let language_registry = self.languages.clone();
419        let project = self.project.clone();
420        let telemetry = self.telemetry.clone();
421        let prompt_builder = self.prompt_builder.clone();
422        let slash_commands = self.slash_commands.clone();
423        let tools = self.tools.clone();
424        let request = self.client.request(proto::CreateContext { project_id });
425        cx.spawn(|this, mut cx| async move {
426            let response = request.await?;
427            let context_id = ContextId::from_proto(response.context_id);
428            let context_proto = response.context.context("invalid context")?;
429            let context = cx.new_model(|cx| {
430                Context::new(
431                    context_id.clone(),
432                    replica_id,
433                    capability,
434                    language_registry,
435                    prompt_builder,
436                    slash_commands,
437                    tools,
438                    Some(project),
439                    Some(telemetry),
440                    cx,
441                )
442            })?;
443            let operations = cx
444                .background_executor()
445                .spawn(async move {
446                    context_proto
447                        .operations
448                        .into_iter()
449                        .map(ContextOperation::from_proto)
450                        .collect::<Result<Vec<_>>>()
451                })
452                .await?;
453            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
454            this.update(&mut cx, |this, cx| {
455                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
456                    existing_context
457                } else {
458                    this.register_context(&context, cx);
459                    this.synchronize_contexts(cx);
460                    context
461                }
462            })
463        })
464    }
465
466    pub fn open_local_context(
467        &mut self,
468        path: PathBuf,
469        cx: &ModelContext<Self>,
470    ) -> Task<Result<Model<Context>>> {
471        if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
472            return Task::ready(Ok(existing_context));
473        }
474
475        let fs = self.fs.clone();
476        let languages = self.languages.clone();
477        let project = self.project.clone();
478        let telemetry = self.telemetry.clone();
479        let load = cx.background_executor().spawn({
480            let path = path.clone();
481            async move {
482                let saved_context = fs.load(&path).await?;
483                SavedContext::from_json(&saved_context)
484            }
485        });
486        let prompt_builder = self.prompt_builder.clone();
487        let slash_commands = self.slash_commands.clone();
488        let tools = self.tools.clone();
489
490        cx.spawn(|this, mut cx| async move {
491            let saved_context = load.await?;
492            let context = cx.new_model(|cx| {
493                Context::deserialize(
494                    saved_context,
495                    path.clone(),
496                    languages,
497                    prompt_builder,
498                    slash_commands,
499                    tools,
500                    Some(project),
501                    Some(telemetry),
502                    cx,
503                )
504            })?;
505            this.update(&mut cx, |this, cx| {
506                if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
507                    existing_context
508                } else {
509                    this.register_context(&context, cx);
510                    context
511                }
512            })
513        })
514    }
515
516    fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
517        self.contexts.iter().find_map(|context| {
518            let context = context.upgrade()?;
519            if context.read(cx).path() == Some(path) {
520                Some(context)
521            } else {
522                None
523            }
524        })
525    }
526
527    pub(super) fn loaded_context_for_id(
528        &self,
529        id: &ContextId,
530        cx: &AppContext,
531    ) -> Option<Model<Context>> {
532        self.contexts.iter().find_map(|context| {
533            let context = context.upgrade()?;
534            if context.read(cx).id() == id {
535                Some(context)
536            } else {
537                None
538            }
539        })
540    }
541
542    pub fn open_remote_context(
543        &mut self,
544        context_id: ContextId,
545        cx: &mut ModelContext<Self>,
546    ) -> Task<Result<Model<Context>>> {
547        let project = self.project.read(cx);
548        let Some(project_id) = project.remote_id() else {
549            return Task::ready(Err(anyhow!("project was not remote")));
550        };
551
552        if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
553            return Task::ready(Ok(context));
554        }
555
556        let replica_id = project.replica_id();
557        let capability = project.capability();
558        let language_registry = self.languages.clone();
559        let project = self.project.clone();
560        let telemetry = self.telemetry.clone();
561        let request = self.client.request(proto::OpenContext {
562            project_id,
563            context_id: context_id.to_proto(),
564        });
565        let prompt_builder = self.prompt_builder.clone();
566        let slash_commands = self.slash_commands.clone();
567        let tools = self.tools.clone();
568        cx.spawn(|this, mut cx| async move {
569            let response = request.await?;
570            let context_proto = response.context.context("invalid context")?;
571            let context = cx.new_model(|cx| {
572                Context::new(
573                    context_id.clone(),
574                    replica_id,
575                    capability,
576                    language_registry,
577                    prompt_builder,
578                    slash_commands,
579                    tools,
580                    Some(project),
581                    Some(telemetry),
582                    cx,
583                )
584            })?;
585            let operations = cx
586                .background_executor()
587                .spawn(async move {
588                    context_proto
589                        .operations
590                        .into_iter()
591                        .map(ContextOperation::from_proto)
592                        .collect::<Result<Vec<_>>>()
593                })
594                .await?;
595            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
596            this.update(&mut cx, |this, cx| {
597                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
598                    existing_context
599                } else {
600                    this.register_context(&context, cx);
601                    this.synchronize_contexts(cx);
602                    context
603                }
604            })
605        })
606    }
607
608    fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
609        let handle = if self.project_is_shared {
610            ContextHandle::Strong(context.clone())
611        } else {
612            ContextHandle::Weak(context.downgrade())
613        };
614        self.contexts.push(handle);
615        self.advertise_contexts(cx);
616        cx.subscribe(context, Self::handle_context_event).detach();
617    }
618
619    fn handle_context_event(
620        &mut self,
621        context: Model<Context>,
622        event: &ContextEvent,
623        cx: &mut ModelContext<Self>,
624    ) {
625        let Some(project_id) = self.project.read(cx).remote_id() else {
626            return;
627        };
628
629        match event {
630            ContextEvent::SummaryChanged => {
631                self.advertise_contexts(cx);
632            }
633            ContextEvent::Operation(operation) => {
634                let context_id = context.read(cx).id().to_proto();
635                let operation = operation.to_proto();
636                self.client
637                    .send(proto::UpdateContext {
638                        project_id,
639                        context_id,
640                        operation: Some(operation),
641                    })
642                    .log_err();
643            }
644            _ => {}
645        }
646    }
647
648    fn advertise_contexts(&self, cx: &AppContext) {
649        let Some(project_id) = self.project.read(cx).remote_id() else {
650            return;
651        };
652
653        // For now, only the host can advertise their open contexts.
654        if self.project.read(cx).is_via_collab() {
655            return;
656        }
657
658        let contexts = self
659            .contexts
660            .iter()
661            .rev()
662            .filter_map(|context| {
663                let context = context.upgrade()?.read(cx);
664                if context.replica_id() == ReplicaId::default() {
665                    Some(proto::ContextMetadata {
666                        context_id: context.id().to_proto(),
667                        summary: context.summary().map(|summary| summary.text.clone()),
668                    })
669                } else {
670                    None
671                }
672            })
673            .collect();
674        self.client
675            .send(proto::AdvertiseContexts {
676                project_id,
677                contexts,
678            })
679            .ok();
680    }
681
682    fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
683        let Some(project_id) = self.project.read(cx).remote_id() else {
684            return;
685        };
686
687        let contexts = self
688            .contexts
689            .iter()
690            .filter_map(|context| {
691                let context = context.upgrade()?.read(cx);
692                if context.replica_id() != ReplicaId::default() {
693                    Some(context.version(cx).to_proto(context.id().clone()))
694                } else {
695                    None
696                }
697            })
698            .collect();
699
700        let client = self.client.clone();
701        let request = self.client.request(proto::SynchronizeContexts {
702            project_id,
703            contexts,
704        });
705        cx.spawn(|this, cx| async move {
706            let response = request.await?;
707
708            let mut context_ids = Vec::new();
709            let mut operations = Vec::new();
710            this.read_with(&cx, |this, cx| {
711                for context_version_proto in response.contexts {
712                    let context_version = ContextVersion::from_proto(&context_version_proto);
713                    let context_id = ContextId::from_proto(context_version_proto.context_id);
714                    if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
715                        context_ids.push(context_id);
716                        operations.push(context.read(cx).serialize_ops(&context_version, cx));
717                    }
718                }
719            })?;
720
721            let operations = futures::future::join_all(operations).await;
722            for (context_id, operations) in context_ids.into_iter().zip(operations) {
723                for operation in operations {
724                    client.send(proto::UpdateContext {
725                        project_id,
726                        context_id: context_id.to_proto(),
727                        operation: Some(operation),
728                    })?;
729                }
730            }
731
732            anyhow::Ok(())
733        })
734        .detach_and_log_err(cx);
735    }
736
737    pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
738        let metadata = self.contexts_metadata.clone();
739        let executor = cx.background_executor().clone();
740        cx.background_executor().spawn(async move {
741            if query.is_empty() {
742                metadata
743            } else {
744                let candidates = metadata
745                    .iter()
746                    .enumerate()
747                    .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
748                    .collect::<Vec<_>>();
749                let matches = fuzzy::match_strings(
750                    &candidates,
751                    &query,
752                    false,
753                    100,
754                    &Default::default(),
755                    executor,
756                )
757                .await;
758
759                matches
760                    .into_iter()
761                    .map(|mat| metadata[mat.candidate_id].clone())
762                    .collect()
763            }
764        })
765    }
766
767    pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
768        &self.host_contexts
769    }
770
771    fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
772        let fs = self.fs.clone();
773        cx.spawn(|this, mut cx| async move {
774            fs.create_dir(contexts_dir()).await?;
775
776            let mut paths = fs.read_dir(contexts_dir()).await?;
777            let mut contexts = Vec::<SavedContextMetadata>::new();
778            while let Some(path) = paths.next().await {
779                let path = path?;
780                if path.extension() != Some(OsStr::new("json")) {
781                    continue;
782                }
783
784                let pattern = r" - \d+.zed.json$";
785                let re = Regex::new(pattern).unwrap();
786
787                let metadata = fs.metadata(&path).await?;
788                if let Some((file_name, metadata)) = path
789                    .file_name()
790                    .and_then(|name| name.to_str())
791                    .zip(metadata)
792                {
793                    // This is used to filter out contexts saved by the new assistant.
794                    if !re.is_match(file_name) {
795                        continue;
796                    }
797
798                    if let Some(title) = re.replace(file_name, "").lines().next() {
799                        contexts.push(SavedContextMetadata {
800                            title: title.to_string(),
801                            path,
802                            mtime: metadata.mtime.into(),
803                        });
804                    }
805                }
806            }
807            contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
808
809            this.update(&mut cx, |this, cx| {
810                this.contexts_metadata = contexts;
811                cx.notify();
812            })
813        })
814    }
815
816    pub fn restart_context_servers(&mut self, cx: &mut ModelContext<Self>) {
817        cx.update_model(
818            &self.context_server_manager,
819            |context_server_manager, cx| {
820                for server in context_server_manager.servers() {
821                    context_server_manager
822                        .restart_server(&server.id, cx)
823                        .detach_and_log_err(cx);
824                }
825            },
826        );
827    }
828
829    fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
830        cx.subscribe(
831            &self.context_server_manager.clone(),
832            Self::handle_context_server_event,
833        )
834        .detach();
835    }
836
837    fn handle_context_server_event(
838        &mut self,
839        context_server_manager: Model<ContextServerManager>,
840        event: &context_servers::manager::Event,
841        cx: &mut ModelContext<Self>,
842    ) {
843        let slash_command_working_set = self.slash_commands.clone();
844        let tool_working_set = self.tools.clone();
845        match event {
846            context_servers::manager::Event::ServerStarted { server_id } => {
847                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
848                    let context_server_manager = context_server_manager.clone();
849                    cx.spawn({
850                        let server = server.clone();
851                        let server_id = server_id.clone();
852                        |this, mut cx| async move {
853                            let Some(protocol) = server.client.read().clone() else {
854                                return;
855                            };
856
857                            if protocol.capable(context_servers::protocol::ServerCapability::Prompts) {
858                                if let Some(prompts) = protocol.list_prompts().await.log_err() {
859                                    let slash_command_ids = prompts
860                                        .into_iter()
861                                        .filter(context_server_command::acceptable_prompt)
862                                        .map(|prompt| {
863                                            log::info!(
864                                                "registering context server command: {:?}",
865                                                prompt.name
866                                            );
867                                            slash_command_working_set.insert(Arc::new(
868                                                context_server_command::ContextServerSlashCommand::new(
869                                                    context_server_manager.clone(),
870                                                    &server,
871                                                    prompt,
872                                                ),
873                                            ))
874                                        })
875                                        .collect::<Vec<_>>();
876
877                                    this.update(&mut cx, |this, _cx| {
878                                        this.context_server_slash_command_ids
879                                            .insert(server_id.clone(), slash_command_ids);
880                                    })
881                                    .log_err();
882                                }
883                            }
884
885                            if protocol.capable(context_servers::protocol::ServerCapability::Tools) {
886                                if let Some(tools) = protocol.list_tools().await.log_err() {
887                                    let tool_ids = tools.tools.into_iter().map(|tool| {
888                                        log::info!("registering context server tool: {:?}", tool.name);
889                                        tool_working_set.insert(
890                                            Arc::new(tools::context_server_tool::ContextServerTool::new(
891                                                context_server_manager.clone(),
892                                                server.id.clone(),
893                                                tool,
894                                            )),
895                                        )
896
897                                    }).collect::<Vec<_>>();
898
899
900                                    this.update(&mut cx, |this, _cx| {
901                                        this.context_server_tool_ids
902                                            .insert(server_id, tool_ids);
903                                    })
904                                    .log_err();
905                                }
906                            }
907                        }
908                    })
909                    .detach();
910                }
911            }
912            context_servers::manager::Event::ServerStopped { server_id } => {
913                if let Some(slash_command_ids) =
914                    self.context_server_slash_command_ids.remove(server_id)
915                {
916                    slash_command_working_set.remove(&slash_command_ids);
917                }
918
919                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
920                    tool_working_set.remove(&tool_ids);
921                }
922            }
923        }
924    }
925}