context_store.rs

  1use crate::{
  2    AssistantContext, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
  3    SavedContextMetadata,
  4};
  5use anyhow::{anyhow, Context as _, Result};
  6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
  7use assistant_tool::{ToolId, ToolWorkingSet};
  8use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
  9use clock::ReplicaId;
 10use collections::HashMap;
 11use context_server::manager::ContextServerManager;
 12use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 13use fs::Fs;
 14use futures::StreamExt;
 15use fuzzy::StringMatchCandidate;
 16use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
 17use language::LanguageRegistry;
 18use paths::contexts_dir;
 19use project::Project;
 20use prompt_library::PromptBuilder;
 21use regex::Regex;
 22use rpc::AnyProtoClient;
 23use std::sync::LazyLock;
 24use std::{
 25    cmp::Reverse,
 26    ffi::OsStr,
 27    mem,
 28    path::{Path, PathBuf},
 29    sync::Arc,
 30    time::Duration,
 31};
 32use util::{ResultExt, TryFutureExt};
 33
 34pub(crate) fn init(client: &AnyProtoClient) {
 35    client.add_model_message_handler(ContextStore::handle_advertise_contexts);
 36    client.add_model_request_handler(ContextStore::handle_open_context);
 37    client.add_model_request_handler(ContextStore::handle_create_context);
 38    client.add_model_message_handler(ContextStore::handle_update_context);
 39    client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
 40}
 41
 42#[derive(Clone)]
 43pub struct RemoteContextMetadata {
 44    pub id: ContextId,
 45    pub summary: Option<String>,
 46}
 47
 48pub struct ContextStore {
 49    contexts: Vec<ContextHandle>,
 50    contexts_metadata: Vec<SavedContextMetadata>,
 51    context_server_manager: Entity<ContextServerManager>,
 52    context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
 53    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 54    host_contexts: Vec<RemoteContextMetadata>,
 55    fs: Arc<dyn Fs>,
 56    languages: Arc<LanguageRegistry>,
 57    slash_commands: Arc<SlashCommandWorkingSet>,
 58    tools: Arc<ToolWorkingSet>,
 59    telemetry: Arc<Telemetry>,
 60    _watch_updates: Task<Option<()>>,
 61    client: Arc<Client>,
 62    project: Entity<Project>,
 63    project_is_shared: bool,
 64    client_subscription: Option<client::Subscription>,
 65    _project_subscriptions: Vec<gpui::Subscription>,
 66    prompt_builder: Arc<PromptBuilder>,
 67}
 68
 69pub enum ContextStoreEvent {
 70    ContextCreated(ContextId),
 71}
 72
 73impl EventEmitter<ContextStoreEvent> for ContextStore {}
 74
 75enum ContextHandle {
 76    Weak(WeakEntity<AssistantContext>),
 77    Strong(Entity<AssistantContext>),
 78}
 79
 80impl ContextHandle {
 81    fn upgrade(&self) -> Option<Entity<AssistantContext>> {
 82        match self {
 83            ContextHandle::Weak(weak) => weak.upgrade(),
 84            ContextHandle::Strong(strong) => Some(strong.clone()),
 85        }
 86    }
 87
 88    fn downgrade(&self) -> WeakEntity<AssistantContext> {
 89        match self {
 90            ContextHandle::Weak(weak) => weak.clone(),
 91            ContextHandle::Strong(strong) => strong.downgrade(),
 92        }
 93    }
 94}
 95
 96impl ContextStore {
 97    pub fn new(
 98        project: Entity<Project>,
 99        prompt_builder: Arc<PromptBuilder>,
100        slash_commands: Arc<SlashCommandWorkingSet>,
101        tools: Arc<ToolWorkingSet>,
102        cx: &mut App,
103    ) -> Task<Result<Entity<Self>>> {
104        let fs = project.read(cx).fs().clone();
105        let languages = project.read(cx).languages().clone();
106        let telemetry = project.read(cx).client().telemetry().clone();
107        cx.spawn(|mut cx| async move {
108            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
109            let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
110
111            let this = cx.new(|cx: &mut Context<Self>| {
112                let context_server_factory_registry =
113                    ContextServerFactoryRegistry::default_global(cx);
114                let context_server_manager = cx.new(|cx| {
115                    ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
116                });
117                let mut this = Self {
118                    contexts: Vec::new(),
119                    contexts_metadata: Vec::new(),
120                    context_server_manager,
121                    context_server_slash_command_ids: HashMap::default(),
122                    context_server_tool_ids: HashMap::default(),
123                    host_contexts: Vec::new(),
124                    fs,
125                    languages,
126                    slash_commands,
127                    tools,
128                    telemetry,
129                    _watch_updates: cx.spawn(|this, mut cx| {
130                        async move {
131                            while events.next().await.is_some() {
132                                this.update(&mut cx, |this, cx| this.reload(cx))?
133                                    .await
134                                    .log_err();
135                            }
136                            anyhow::Ok(())
137                        }
138                        .log_err()
139                    }),
140                    client_subscription: None,
141                    _project_subscriptions: vec![
142                        cx.observe(&project, Self::handle_project_changed),
143                        cx.subscribe(&project, Self::handle_project_event),
144                    ],
145                    project_is_shared: false,
146                    client: project.read(cx).client(),
147                    project: project.clone(),
148                    prompt_builder,
149                };
150                this.handle_project_changed(project.clone(), cx);
151                this.synchronize_contexts(cx);
152                this.register_context_server_handlers(cx);
153                this
154            })?;
155            this.update(&mut cx, |this, cx| this.reload(cx))?
156                .await
157                .log_err();
158
159            Ok(this)
160        })
161    }
162
163    async fn handle_advertise_contexts(
164        this: Entity<Self>,
165        envelope: TypedEnvelope<proto::AdvertiseContexts>,
166        mut cx: AsyncApp,
167    ) -> Result<()> {
168        this.update(&mut cx, |this, cx| {
169            this.host_contexts = envelope
170                .payload
171                .contexts
172                .into_iter()
173                .map(|context| RemoteContextMetadata {
174                    id: ContextId::from_proto(context.context_id),
175                    summary: context.summary,
176                })
177                .collect();
178            cx.notify();
179        })
180    }
181
182    async fn handle_open_context(
183        this: Entity<Self>,
184        envelope: TypedEnvelope<proto::OpenContext>,
185        mut cx: AsyncApp,
186    ) -> Result<proto::OpenContextResponse> {
187        let context_id = ContextId::from_proto(envelope.payload.context_id);
188        let operations = this.update(&mut cx, |this, cx| {
189            if this.project.read(cx).is_via_collab() {
190                return Err(anyhow!("only the host contexts can be opened"));
191            }
192
193            let context = this
194                .loaded_context_for_id(&context_id, cx)
195                .context("context not found")?;
196            if context.read(cx).replica_id() != ReplicaId::default() {
197                return Err(anyhow!("context must be opened via the host"));
198            }
199
200            anyhow::Ok(
201                context
202                    .read(cx)
203                    .serialize_ops(&ContextVersion::default(), cx),
204            )
205        })??;
206        let operations = operations.await;
207        Ok(proto::OpenContextResponse {
208            context: Some(proto::Context { operations }),
209        })
210    }
211
212    async fn handle_create_context(
213        this: Entity<Self>,
214        _: TypedEnvelope<proto::CreateContext>,
215        mut cx: AsyncApp,
216    ) -> Result<proto::CreateContextResponse> {
217        let (context_id, operations) = this.update(&mut cx, |this, cx| {
218            if this.project.read(cx).is_via_collab() {
219                return Err(anyhow!("can only create contexts as the host"));
220            }
221
222            let context = this.create(cx);
223            let context_id = context.read(cx).id().clone();
224            cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
225
226            anyhow::Ok((
227                context_id,
228                context
229                    .read(cx)
230                    .serialize_ops(&ContextVersion::default(), cx),
231            ))
232        })??;
233        let operations = operations.await;
234        Ok(proto::CreateContextResponse {
235            context_id: context_id.to_proto(),
236            context: Some(proto::Context { operations }),
237        })
238    }
239
240    async fn handle_update_context(
241        this: Entity<Self>,
242        envelope: TypedEnvelope<proto::UpdateContext>,
243        mut cx: AsyncApp,
244    ) -> Result<()> {
245        this.update(&mut cx, |this, cx| {
246            let context_id = ContextId::from_proto(envelope.payload.context_id);
247            if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
248                let operation_proto = envelope.payload.operation.context("invalid operation")?;
249                let operation = ContextOperation::from_proto(operation_proto)?;
250                context.update(cx, |context, cx| context.apply_ops([operation], cx));
251            }
252            Ok(())
253        })?
254    }
255
256    async fn handle_synchronize_contexts(
257        this: Entity<Self>,
258        envelope: TypedEnvelope<proto::SynchronizeContexts>,
259        mut cx: AsyncApp,
260    ) -> Result<proto::SynchronizeContextsResponse> {
261        this.update(&mut cx, |this, cx| {
262            if this.project.read(cx).is_via_collab() {
263                return Err(anyhow!("only the host can synchronize contexts"));
264            }
265
266            let mut local_versions = Vec::new();
267            for remote_version_proto in envelope.payload.contexts {
268                let remote_version = ContextVersion::from_proto(&remote_version_proto);
269                let context_id = ContextId::from_proto(remote_version_proto.context_id);
270                if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
271                    let context = context.read(cx);
272                    let operations = context.serialize_ops(&remote_version, cx);
273                    local_versions.push(context.version(cx).to_proto(context_id.clone()));
274                    let client = this.client.clone();
275                    let project_id = envelope.payload.project_id;
276                    cx.background_executor()
277                        .spawn(async move {
278                            let operations = operations.await;
279                            for operation in operations {
280                                client.send(proto::UpdateContext {
281                                    project_id,
282                                    context_id: context_id.to_proto(),
283                                    operation: Some(operation),
284                                })?;
285                            }
286                            anyhow::Ok(())
287                        })
288                        .detach_and_log_err(cx);
289                }
290            }
291
292            this.advertise_contexts(cx);
293
294            anyhow::Ok(proto::SynchronizeContextsResponse {
295                contexts: local_versions,
296            })
297        })?
298    }
299
300    fn handle_project_changed(&mut self, _: Entity<Project>, cx: &mut Context<Self>) {
301        let is_shared = self.project.read(cx).is_shared();
302        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
303        if is_shared == was_shared {
304            return;
305        }
306
307        if is_shared {
308            self.contexts.retain_mut(|context| {
309                if let Some(strong_context) = context.upgrade() {
310                    *context = ContextHandle::Strong(strong_context);
311                    true
312                } else {
313                    false
314                }
315            });
316            let remote_id = self.project.read(cx).remote_id().unwrap();
317            self.client_subscription = self
318                .client
319                .subscribe_to_entity(remote_id)
320                .log_err()
321                .map(|subscription| subscription.set_model(&cx.entity(), &mut cx.to_async()));
322            self.advertise_contexts(cx);
323        } else {
324            self.client_subscription = None;
325        }
326    }
327
328    fn handle_project_event(
329        &mut self,
330        _: Entity<Project>,
331        event: &project::Event,
332        cx: &mut Context<Self>,
333    ) {
334        match event {
335            project::Event::Reshared => {
336                self.advertise_contexts(cx);
337            }
338            project::Event::HostReshared | project::Event::Rejoined => {
339                self.synchronize_contexts(cx);
340            }
341            project::Event::DisconnectedFromHost => {
342                self.contexts.retain_mut(|context| {
343                    if let Some(strong_context) = context.upgrade() {
344                        *context = ContextHandle::Weak(context.downgrade());
345                        strong_context.update(cx, |context, cx| {
346                            if context.replica_id() != ReplicaId::default() {
347                                context.set_capability(language::Capability::ReadOnly, cx);
348                            }
349                        });
350                        true
351                    } else {
352                        false
353                    }
354                });
355                self.host_contexts.clear();
356                cx.notify();
357            }
358            _ => {}
359        }
360    }
361
362    pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<AssistantContext> {
363        let context = cx.new(|cx| {
364            AssistantContext::local(
365                self.languages.clone(),
366                Some(self.project.clone()),
367                Some(self.telemetry.clone()),
368                self.prompt_builder.clone(),
369                self.slash_commands.clone(),
370                self.tools.clone(),
371                cx,
372            )
373        });
374        self.register_context(&context, cx);
375        context
376    }
377
378    pub fn create_remote_context(
379        &mut self,
380        cx: &mut Context<Self>,
381    ) -> Task<Result<Entity<AssistantContext>>> {
382        let project = self.project.read(cx);
383        let Some(project_id) = project.remote_id() else {
384            return Task::ready(Err(anyhow!("project was not remote")));
385        };
386
387        let replica_id = project.replica_id();
388        let capability = project.capability();
389        let language_registry = self.languages.clone();
390        let project = self.project.clone();
391        let telemetry = self.telemetry.clone();
392        let prompt_builder = self.prompt_builder.clone();
393        let slash_commands = self.slash_commands.clone();
394        let tools = self.tools.clone();
395        let request = self.client.request(proto::CreateContext { project_id });
396        cx.spawn(|this, mut cx| async move {
397            let response = request.await?;
398            let context_id = ContextId::from_proto(response.context_id);
399            let context_proto = response.context.context("invalid context")?;
400            let context = cx.new(|cx| {
401                AssistantContext::new(
402                    context_id.clone(),
403                    replica_id,
404                    capability,
405                    language_registry,
406                    prompt_builder,
407                    slash_commands,
408                    tools,
409                    Some(project),
410                    Some(telemetry),
411                    cx,
412                )
413            })?;
414            let operations = cx
415                .background_executor()
416                .spawn(async move {
417                    context_proto
418                        .operations
419                        .into_iter()
420                        .map(ContextOperation::from_proto)
421                        .collect::<Result<Vec<_>>>()
422                })
423                .await?;
424            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
425            this.update(&mut cx, |this, cx| {
426                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
427                    existing_context
428                } else {
429                    this.register_context(&context, cx);
430                    this.synchronize_contexts(cx);
431                    context
432                }
433            })
434        })
435    }
436
437    pub fn open_local_context(
438        &mut self,
439        path: PathBuf,
440        cx: &Context<Self>,
441    ) -> Task<Result<Entity<AssistantContext>>> {
442        if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
443            return Task::ready(Ok(existing_context));
444        }
445
446        let fs = self.fs.clone();
447        let languages = self.languages.clone();
448        let project = self.project.clone();
449        let telemetry = self.telemetry.clone();
450        let load = cx.background_executor().spawn({
451            let path = path.clone();
452            async move {
453                let saved_context = fs.load(&path).await?;
454                SavedContext::from_json(&saved_context)
455            }
456        });
457        let prompt_builder = self.prompt_builder.clone();
458        let slash_commands = self.slash_commands.clone();
459        let tools = self.tools.clone();
460
461        cx.spawn(|this, mut cx| async move {
462            let saved_context = load.await?;
463            let context = cx.new(|cx| {
464                AssistantContext::deserialize(
465                    saved_context,
466                    path.clone(),
467                    languages,
468                    prompt_builder,
469                    slash_commands,
470                    tools,
471                    Some(project),
472                    Some(telemetry),
473                    cx,
474                )
475            })?;
476            this.update(&mut cx, |this, cx| {
477                if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
478                    existing_context
479                } else {
480                    this.register_context(&context, cx);
481                    context
482                }
483            })
484        })
485    }
486
487    fn loaded_context_for_path(&self, path: &Path, cx: &App) -> Option<Entity<AssistantContext>> {
488        self.contexts.iter().find_map(|context| {
489            let context = context.upgrade()?;
490            if context.read(cx).path() == Some(path) {
491                Some(context)
492            } else {
493                None
494            }
495        })
496    }
497
498    pub fn loaded_context_for_id(
499        &self,
500        id: &ContextId,
501        cx: &App,
502    ) -> Option<Entity<AssistantContext>> {
503        self.contexts.iter().find_map(|context| {
504            let context = context.upgrade()?;
505            if context.read(cx).id() == id {
506                Some(context)
507            } else {
508                None
509            }
510        })
511    }
512
513    pub fn open_remote_context(
514        &mut self,
515        context_id: ContextId,
516        cx: &mut Context<Self>,
517    ) -> Task<Result<Entity<AssistantContext>>> {
518        let project = self.project.read(cx);
519        let Some(project_id) = project.remote_id() else {
520            return Task::ready(Err(anyhow!("project was not remote")));
521        };
522
523        if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
524            return Task::ready(Ok(context));
525        }
526
527        let replica_id = project.replica_id();
528        let capability = project.capability();
529        let language_registry = self.languages.clone();
530        let project = self.project.clone();
531        let telemetry = self.telemetry.clone();
532        let request = self.client.request(proto::OpenContext {
533            project_id,
534            context_id: context_id.to_proto(),
535        });
536        let prompt_builder = self.prompt_builder.clone();
537        let slash_commands = self.slash_commands.clone();
538        let tools = self.tools.clone();
539        cx.spawn(|this, mut cx| async move {
540            let response = request.await?;
541            let context_proto = response.context.context("invalid context")?;
542            let context = cx.new(|cx| {
543                AssistantContext::new(
544                    context_id.clone(),
545                    replica_id,
546                    capability,
547                    language_registry,
548                    prompt_builder,
549                    slash_commands,
550                    tools,
551                    Some(project),
552                    Some(telemetry),
553                    cx,
554                )
555            })?;
556            let operations = cx
557                .background_executor()
558                .spawn(async move {
559                    context_proto
560                        .operations
561                        .into_iter()
562                        .map(ContextOperation::from_proto)
563                        .collect::<Result<Vec<_>>>()
564                })
565                .await?;
566            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
567            this.update(&mut cx, |this, cx| {
568                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
569                    existing_context
570                } else {
571                    this.register_context(&context, cx);
572                    this.synchronize_contexts(cx);
573                    context
574                }
575            })
576        })
577    }
578
579    fn register_context(&mut self, context: &Entity<AssistantContext>, cx: &mut Context<Self>) {
580        let handle = if self.project_is_shared {
581            ContextHandle::Strong(context.clone())
582        } else {
583            ContextHandle::Weak(context.downgrade())
584        };
585        self.contexts.push(handle);
586        self.advertise_contexts(cx);
587        cx.subscribe(context, Self::handle_context_event).detach();
588    }
589
590    fn handle_context_event(
591        &mut self,
592        context: Entity<AssistantContext>,
593        event: &ContextEvent,
594        cx: &mut Context<Self>,
595    ) {
596        let Some(project_id) = self.project.read(cx).remote_id() else {
597            return;
598        };
599
600        match event {
601            ContextEvent::SummaryChanged => {
602                self.advertise_contexts(cx);
603            }
604            ContextEvent::Operation(operation) => {
605                let context_id = context.read(cx).id().to_proto();
606                let operation = operation.to_proto();
607                self.client
608                    .send(proto::UpdateContext {
609                        project_id,
610                        context_id,
611                        operation: Some(operation),
612                    })
613                    .log_err();
614            }
615            _ => {}
616        }
617    }
618
619    fn advertise_contexts(&self, cx: &App) {
620        let Some(project_id) = self.project.read(cx).remote_id() else {
621            return;
622        };
623
624        // For now, only the host can advertise their open contexts.
625        if self.project.read(cx).is_via_collab() {
626            return;
627        }
628
629        let contexts = self
630            .contexts
631            .iter()
632            .rev()
633            .filter_map(|context| {
634                let context = context.upgrade()?.read(cx);
635                if context.replica_id() == ReplicaId::default() {
636                    Some(proto::ContextMetadata {
637                        context_id: context.id().to_proto(),
638                        summary: context.summary().map(|summary| summary.text.clone()),
639                    })
640                } else {
641                    None
642                }
643            })
644            .collect();
645        self.client
646            .send(proto::AdvertiseContexts {
647                project_id,
648                contexts,
649            })
650            .ok();
651    }
652
653    fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
654        let Some(project_id) = self.project.read(cx).remote_id() else {
655            return;
656        };
657
658        let contexts = self
659            .contexts
660            .iter()
661            .filter_map(|context| {
662                let context = context.upgrade()?.read(cx);
663                if context.replica_id() != ReplicaId::default() {
664                    Some(context.version(cx).to_proto(context.id().clone()))
665                } else {
666                    None
667                }
668            })
669            .collect();
670
671        let client = self.client.clone();
672        let request = self.client.request(proto::SynchronizeContexts {
673            project_id,
674            contexts,
675        });
676        cx.spawn(|this, cx| async move {
677            let response = request.await?;
678
679            let mut context_ids = Vec::new();
680            let mut operations = Vec::new();
681            this.read_with(&cx, |this, cx| {
682                for context_version_proto in response.contexts {
683                    let context_version = ContextVersion::from_proto(&context_version_proto);
684                    let context_id = ContextId::from_proto(context_version_proto.context_id);
685                    if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
686                        context_ids.push(context_id);
687                        operations.push(context.read(cx).serialize_ops(&context_version, cx));
688                    }
689                }
690            })?;
691
692            let operations = futures::future::join_all(operations).await;
693            for (context_id, operations) in context_ids.into_iter().zip(operations) {
694                for operation in operations {
695                    client.send(proto::UpdateContext {
696                        project_id,
697                        context_id: context_id.to_proto(),
698                        operation: Some(operation),
699                    })?;
700                }
701            }
702
703            anyhow::Ok(())
704        })
705        .detach_and_log_err(cx);
706    }
707
708    pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedContextMetadata>> {
709        let metadata = self.contexts_metadata.clone();
710        let executor = cx.background_executor().clone();
711        cx.background_executor().spawn(async move {
712            if query.is_empty() {
713                metadata
714            } else {
715                let candidates = metadata
716                    .iter()
717                    .enumerate()
718                    .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
719                    .collect::<Vec<_>>();
720                let matches = fuzzy::match_strings(
721                    &candidates,
722                    &query,
723                    false,
724                    100,
725                    &Default::default(),
726                    executor,
727                )
728                .await;
729
730                matches
731                    .into_iter()
732                    .map(|mat| metadata[mat.candidate_id].clone())
733                    .collect()
734            }
735        })
736    }
737
738    pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
739        &self.host_contexts
740    }
741
742    fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
743        let fs = self.fs.clone();
744        cx.spawn(|this, mut cx| async move {
745            fs.create_dir(contexts_dir()).await?;
746
747            let mut paths = fs.read_dir(contexts_dir()).await?;
748            let mut contexts = Vec::<SavedContextMetadata>::new();
749            while let Some(path) = paths.next().await {
750                let path = path?;
751                if path.extension() != Some(OsStr::new("json")) {
752                    continue;
753                }
754
755                static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
756                    LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
757
758                let metadata = fs.metadata(&path).await?;
759                if let Some((file_name, metadata)) = path
760                    .file_name()
761                    .and_then(|name| name.to_str())
762                    .zip(metadata)
763                {
764                    // This is used to filter out contexts saved by the new assistant.
765                    if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
766                        continue;
767                    }
768
769                    if let Some(title) = ASSISTANT_CONTEXT_REGEX
770                        .replace(file_name, "")
771                        .lines()
772                        .next()
773                    {
774                        contexts.push(SavedContextMetadata {
775                            title: title.to_string(),
776                            path,
777                            mtime: metadata.mtime.timestamp_for_user().into(),
778                        });
779                    }
780                }
781            }
782            contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
783
784            this.update(&mut cx, |this, cx| {
785                this.contexts_metadata = contexts;
786                cx.notify();
787            })
788        })
789    }
790
791    pub fn restart_context_servers(&mut self, cx: &mut Context<Self>) {
792        cx.update_entity(
793            &self.context_server_manager,
794            |context_server_manager, cx| {
795                for server in context_server_manager.servers() {
796                    context_server_manager
797                        .restart_server(&server.id(), cx)
798                        .detach_and_log_err(cx);
799                }
800            },
801        );
802    }
803
804    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
805        cx.subscribe(
806            &self.context_server_manager.clone(),
807            Self::handle_context_server_event,
808        )
809        .detach();
810    }
811
812    fn handle_context_server_event(
813        &mut self,
814        context_server_manager: Entity<ContextServerManager>,
815        event: &context_server::manager::Event,
816        cx: &mut Context<Self>,
817    ) {
818        let slash_command_working_set = self.slash_commands.clone();
819        let tool_working_set = self.tools.clone();
820        match event {
821            context_server::manager::Event::ServerStarted { server_id } => {
822                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
823                    let context_server_manager = context_server_manager.clone();
824                    cx.spawn({
825                        let server = server.clone();
826                        let server_id = server_id.clone();
827                        |this, mut cx| async move {
828                            let Some(protocol) = server.client() else {
829                                return;
830                            };
831
832                            if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
833                                if let Some(prompts) = protocol.list_prompts().await.log_err() {
834                                    let slash_command_ids = prompts
835                                        .into_iter()
836                                        .filter(assistant_slash_commands::acceptable_prompt)
837                                        .map(|prompt| {
838                                            log::info!(
839                                                "registering context server command: {:?}",
840                                                prompt.name
841                                            );
842                                            slash_command_working_set.insert(Arc::new(
843                                                assistant_slash_commands::ContextServerSlashCommand::new(
844                                                    context_server_manager.clone(),
845                                                    &server,
846                                                    prompt,
847                                                ),
848                                            ))
849                                        })
850                                        .collect::<Vec<_>>();
851
852                                    this.update(&mut cx, |this, _cx| {
853                                        this.context_server_slash_command_ids
854                                            .insert(server_id.clone(), slash_command_ids);
855                                    })
856                                    .log_err();
857                                }
858                            }
859
860                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
861                                if let Some(tools) = protocol.list_tools().await.log_err() {
862                                    let tool_ids = tools.tools.into_iter().map(|tool| {
863                                        log::info!("registering context server tool: {:?}", tool.name);
864                                        tool_working_set.insert(
865                                            Arc::new(ContextServerTool::new(
866                                                context_server_manager.clone(),
867                                                server.id(),
868                                                tool,
869                                            )),
870                                        )
871
872                                    }).collect::<Vec<_>>();
873
874
875                                    this.update(&mut cx, |this, _cx| {
876                                        this.context_server_tool_ids
877                                            .insert(server_id, tool_ids);
878                                    })
879                                    .log_err();
880                                }
881                            }
882                        }
883                    })
884                    .detach();
885                }
886            }
887            context_server::manager::Event::ServerStopped { server_id } => {
888                if let Some(slash_command_ids) =
889                    self.context_server_slash_command_ids.remove(server_id)
890                {
891                    slash_command_working_set.remove(&slash_command_ids);
892                }
893
894                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
895                    tool_working_set.remove(&tool_ids);
896                }
897            }
898        }
899    }
900}