context_store.rs

  1use crate::slash_command::context_server_command;
  2use crate::SlashCommandId;
  3use crate::{
  4    prompts::PromptBuilder, slash_command_working_set::SlashCommandWorkingSet, Context,
  5    ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, SavedContextMetadata,
  6};
  7use anyhow::{anyhow, Context as _, Result};
  8use assistant_tool::{ToolId, ToolWorkingSet};
  9use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
 10use clock::ReplicaId;
 11use collections::HashMap;
 12use context_server::manager::ContextServerManager;
 13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 14use fs::Fs;
 15use futures::StreamExt;
 16use fuzzy::StringMatchCandidate;
 17use gpui::{
 18    AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
 19};
 20use language::LanguageRegistry;
 21use paths::contexts_dir;
 22use project::Project;
 23use regex::Regex;
 24use rpc::AnyProtoClient;
 25use std::sync::LazyLock;
 26use std::{
 27    cmp::Reverse,
 28    ffi::OsStr,
 29    mem,
 30    path::{Path, PathBuf},
 31    sync::Arc,
 32    time::Duration,
 33};
 34use util::{ResultExt, TryFutureExt};
 35
/// Registers the RPC handlers through which collaborators exchange assistant
/// contexts with a `ContextStore`.
///
/// The message handlers (`AdvertiseContexts`, `UpdateContext`) return no
/// response payload; the request handlers (`OpenContext`, `CreateContext`,
/// `SynchronizeContexts`) each produce a `proto::*Response` for the caller.
pub fn init(client: &AnyProtoClient) {
    client.add_model_message_handler(ContextStore::handle_advertise_contexts);
    client.add_model_request_handler(ContextStore::handle_open_context);
    client.add_model_request_handler(ContextStore::handle_create_context);
    client.add_model_message_handler(ContextStore::handle_update_context);
    client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
}
 43
/// Metadata about a context that the project host advertised, as received in a
/// `proto::AdvertiseContexts` message.
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Identifier of the context on the host.
    pub id: ContextId,
    /// The context's summary text, if the host reported one.
    pub summary: Option<String>,
}
 49
/// Owns the assistant contexts of a project: saved contexts on disk, contexts
/// currently loaded in memory, and their collaboration plumbing (advertising,
/// opening, creating and synchronizing contexts with the host or guests).
pub struct ContextStore {
    /// Contexts registered with the store. Held strongly while the project is
    /// shared and weakly otherwise (see `register_context`).
    contexts: Vec<ContextHandle>,
    /// Metadata of contexts saved in `contexts_dir()`, refreshed by `reload`
    /// and queried by `search`.
    contexts_metadata: Vec<SavedContextMetadata>,
    /// Manager of the context servers; its events are routed to
    /// `handle_context_server_event`.
    context_server_manager: Model<ContextServerManager>,
    /// Slash commands registered on behalf of context servers.
    /// NOTE(review): presumably keyed by server id — confirm in the handler.
    context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
    /// Tools registered on behalf of context servers.
    /// NOTE(review): presumably keyed by server id — confirm in the handler.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Contexts the host advertised to us; cleared when disconnected from it.
    host_contexts: Vec<RemoteContextMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    tools: Arc<ToolWorkingSet>,
    telemetry: Arc<Telemetry>,
    /// Task that calls `reload` for every event from the contexts-dir watcher.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Model<Project>,
    /// Sharing state last observed by `handle_project_changed`.
    project_is_shared: bool,
    /// Subscription to the project's remote entity, set while it is shared.
    client_subscription: Option<client::Subscription>,
    /// Observation and event subscriptions on `project`.
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
 70
/// Events emitted by `ContextStore`.
pub enum ContextStoreEvent {
    /// Emitted by `handle_create_context` after a context was created in
    /// response to a remote `CreateContext` request; carries the new id.
    ContextCreated(ContextId),
}

impl EventEmitter<ContextStoreEvent> for ContextStore {}
 76
 77enum ContextHandle {
 78    Weak(WeakModel<Context>),
 79    Strong(Model<Context>),
 80}
 81
 82impl ContextHandle {
 83    fn upgrade(&self) -> Option<Model<Context>> {
 84        match self {
 85            ContextHandle::Weak(weak) => weak.upgrade(),
 86            ContextHandle::Strong(strong) => Some(strong.clone()),
 87        }
 88    }
 89
 90    fn downgrade(&self) -> WeakModel<Context> {
 91        match self {
 92            ContextHandle::Weak(weak) => weak.clone(),
 93            ContextHandle::Strong(strong) => strong.downgrade(),
 94        }
 95    }
 96}
 97
 98impl ContextStore {
    /// Creates a `ContextStore` for `project` and performs the initial load of
    /// saved-context metadata.
    ///
    /// The returned task first starts watching `contexts_dir()`, then builds the
    /// store model, and finally runs `reload` once before resolving to the new
    /// model. A failure of that first `reload` is only logged. The task errors
    /// if the model cannot be created or updated through the async context.
    pub fn new(
        project: Model<Project>,
        prompt_builder: Arc<PromptBuilder>,
        slash_commands: Arc<SlashCommandWorkingSet>,
        tools: Arc<ToolWorkingSet>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let fs = project.read(cx).fs().clone();
        let languages = project.read(cx).languages().clone();
        let telemetry = project.read(cx).client().telemetry().clone();
        cx.spawn(|mut cx| async move {
            // Passed to `fs.watch` together with the directory; presumably the
            // latency/debounce window for change events — TODO confirm.
            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
            let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;

            let this = cx.new_model(|cx: &mut ModelContext<Self>| {
                let context_server_factory_registry =
                    ContextServerFactoryRegistry::default_global(cx);
                let context_server_manager = cx.new_model(|cx| {
                    ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
                });
                let mut this = Self {
                    contexts: Vec::new(),
                    contexts_metadata: Vec::new(),
                    context_server_manager,
                    context_server_slash_command_ids: HashMap::default(),
                    context_server_tool_ids: HashMap::default(),
                    host_contexts: Vec::new(),
                    fs,
                    languages,
                    slash_commands,
                    tools,
                    telemetry,
                    // Re-read the saved-context metadata for every watcher event.
                    // The loop ends when the event stream ends or when
                    // `this.update` fails (e.g. the store was released).
                    _watch_updates: cx.spawn(|this, mut cx| {
                        async move {
                            while events.next().await.is_some() {
                                this.update(&mut cx, |this, cx| this.reload(cx))?
                                    .await
                                    .log_err();
                            }
                            anyhow::Ok(())
                        }
                        .log_err()
                    }),
                    client_subscription: None,
                    // React to sharing-state changes and collaboration events.
                    _project_subscriptions: vec![
                        cx.observe(&project, Self::handle_project_changed),
                        cx.subscribe(&project, Self::handle_project_event),
                    ],
                    // Starts as `false` so the `handle_project_changed` call
                    // below treats an already-shared project as a transition.
                    project_is_shared: false,
                    client: project.read(cx).client(),
                    project: project.clone(),
                    prompt_builder,
                };
                this.handle_project_changed(project.clone(), cx);
                this.synchronize_contexts(cx);
                this.register_context_server_handlers(cx);
                this
            })?;
            // Populate `contexts_metadata` before handing the store out.
            this.update(&mut cx, |this, cx| this.reload(cx))?
                .await
                .log_err();

            Ok(this)
        })
    }
164
165    async fn handle_advertise_contexts(
166        this: Model<Self>,
167        envelope: TypedEnvelope<proto::AdvertiseContexts>,
168        mut cx: AsyncAppContext,
169    ) -> Result<()> {
170        this.update(&mut cx, |this, cx| {
171            this.host_contexts = envelope
172                .payload
173                .contexts
174                .into_iter()
175                .map(|context| RemoteContextMetadata {
176                    id: ContextId::from_proto(context.context_id),
177                    summary: context.summary,
178                })
179                .collect();
180            cx.notify();
181        })
182    }
183
184    async fn handle_open_context(
185        this: Model<Self>,
186        envelope: TypedEnvelope<proto::OpenContext>,
187        mut cx: AsyncAppContext,
188    ) -> Result<proto::OpenContextResponse> {
189        let context_id = ContextId::from_proto(envelope.payload.context_id);
190        let operations = this.update(&mut cx, |this, cx| {
191            if this.project.read(cx).is_via_collab() {
192                return Err(anyhow!("only the host contexts can be opened"));
193            }
194
195            let context = this
196                .loaded_context_for_id(&context_id, cx)
197                .context("context not found")?;
198            if context.read(cx).replica_id() != ReplicaId::default() {
199                return Err(anyhow!("context must be opened via the host"));
200            }
201
202            anyhow::Ok(
203                context
204                    .read(cx)
205                    .serialize_ops(&ContextVersion::default(), cx),
206            )
207        })??;
208        let operations = operations.await;
209        Ok(proto::OpenContextResponse {
210            context: Some(proto::Context { operations }),
211        })
212    }
213
214    async fn handle_create_context(
215        this: Model<Self>,
216        _: TypedEnvelope<proto::CreateContext>,
217        mut cx: AsyncAppContext,
218    ) -> Result<proto::CreateContextResponse> {
219        let (context_id, operations) = this.update(&mut cx, |this, cx| {
220            if this.project.read(cx).is_via_collab() {
221                return Err(anyhow!("can only create contexts as the host"));
222            }
223
224            let context = this.create(cx);
225            let context_id = context.read(cx).id().clone();
226            cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
227
228            anyhow::Ok((
229                context_id,
230                context
231                    .read(cx)
232                    .serialize_ops(&ContextVersion::default(), cx),
233            ))
234        })??;
235        let operations = operations.await;
236        Ok(proto::CreateContextResponse {
237            context_id: context_id.to_proto(),
238            context: Some(proto::Context { operations }),
239        })
240    }
241
242    async fn handle_update_context(
243        this: Model<Self>,
244        envelope: TypedEnvelope<proto::UpdateContext>,
245        mut cx: AsyncAppContext,
246    ) -> Result<()> {
247        this.update(&mut cx, |this, cx| {
248            let context_id = ContextId::from_proto(envelope.payload.context_id);
249            if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
250                let operation_proto = envelope.payload.operation.context("invalid operation")?;
251                let operation = ContextOperation::from_proto(operation_proto)?;
252                context.update(cx, |context, cx| context.apply_ops([operation], cx));
253            }
254            Ok(())
255        })?
256    }
257
258    async fn handle_synchronize_contexts(
259        this: Model<Self>,
260        envelope: TypedEnvelope<proto::SynchronizeContexts>,
261        mut cx: AsyncAppContext,
262    ) -> Result<proto::SynchronizeContextsResponse> {
263        this.update(&mut cx, |this, cx| {
264            if this.project.read(cx).is_via_collab() {
265                return Err(anyhow!("only the host can synchronize contexts"));
266            }
267
268            let mut local_versions = Vec::new();
269            for remote_version_proto in envelope.payload.contexts {
270                let remote_version = ContextVersion::from_proto(&remote_version_proto);
271                let context_id = ContextId::from_proto(remote_version_proto.context_id);
272                if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
273                    let context = context.read(cx);
274                    let operations = context.serialize_ops(&remote_version, cx);
275                    local_versions.push(context.version(cx).to_proto(context_id.clone()));
276                    let client = this.client.clone();
277                    let project_id = envelope.payload.project_id;
278                    cx.background_executor()
279                        .spawn(async move {
280                            let operations = operations.await;
281                            for operation in operations {
282                                client.send(proto::UpdateContext {
283                                    project_id,
284                                    context_id: context_id.to_proto(),
285                                    operation: Some(operation),
286                                })?;
287                            }
288                            anyhow::Ok(())
289                        })
290                        .detach_and_log_err(cx);
291                }
292            }
293
294            this.advertise_contexts(cx);
295
296            anyhow::Ok(proto::SynchronizeContextsResponse {
297                contexts: local_versions,
298            })
299        })?
300    }
301
302    fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
303        let is_shared = self.project.read(cx).is_shared();
304        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
305        if is_shared == was_shared {
306            return;
307        }
308
309        if is_shared {
310            self.contexts.retain_mut(|context| {
311                if let Some(strong_context) = context.upgrade() {
312                    *context = ContextHandle::Strong(strong_context);
313                    true
314                } else {
315                    false
316                }
317            });
318            let remote_id = self.project.read(cx).remote_id().unwrap();
319            self.client_subscription = self
320                .client
321                .subscribe_to_entity(remote_id)
322                .log_err()
323                .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
324            self.advertise_contexts(cx);
325        } else {
326            self.client_subscription = None;
327        }
328    }
329
330    fn handle_project_event(
331        &mut self,
332        _: Model<Project>,
333        event: &project::Event,
334        cx: &mut ModelContext<Self>,
335    ) {
336        match event {
337            project::Event::Reshared => {
338                self.advertise_contexts(cx);
339            }
340            project::Event::HostReshared | project::Event::Rejoined => {
341                self.synchronize_contexts(cx);
342            }
343            project::Event::DisconnectedFromHost => {
344                self.contexts.retain_mut(|context| {
345                    if let Some(strong_context) = context.upgrade() {
346                        *context = ContextHandle::Weak(context.downgrade());
347                        strong_context.update(cx, |context, cx| {
348                            if context.replica_id() != ReplicaId::default() {
349                                context.set_capability(language::Capability::ReadOnly, cx);
350                            }
351                        });
352                        true
353                    } else {
354                        false
355                    }
356                });
357                self.host_contexts.clear();
358                cx.notify();
359            }
360            _ => {}
361        }
362    }
363
364    pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
365        let context = cx.new_model(|cx| {
366            Context::local(
367                self.languages.clone(),
368                Some(self.project.clone()),
369                Some(self.telemetry.clone()),
370                self.prompt_builder.clone(),
371                self.slash_commands.clone(),
372                self.tools.clone(),
373                cx,
374            )
375        });
376        self.register_context(&context, cx);
377        context
378    }
379
380    pub fn create_remote_context(
381        &mut self,
382        cx: &mut ModelContext<Self>,
383    ) -> Task<Result<Model<Context>>> {
384        let project = self.project.read(cx);
385        let Some(project_id) = project.remote_id() else {
386            return Task::ready(Err(anyhow!("project was not remote")));
387        };
388
389        let replica_id = project.replica_id();
390        let capability = project.capability();
391        let language_registry = self.languages.clone();
392        let project = self.project.clone();
393        let telemetry = self.telemetry.clone();
394        let prompt_builder = self.prompt_builder.clone();
395        let slash_commands = self.slash_commands.clone();
396        let tools = self.tools.clone();
397        let request = self.client.request(proto::CreateContext { project_id });
398        cx.spawn(|this, mut cx| async move {
399            let response = request.await?;
400            let context_id = ContextId::from_proto(response.context_id);
401            let context_proto = response.context.context("invalid context")?;
402            let context = cx.new_model(|cx| {
403                Context::new(
404                    context_id.clone(),
405                    replica_id,
406                    capability,
407                    language_registry,
408                    prompt_builder,
409                    slash_commands,
410                    tools,
411                    Some(project),
412                    Some(telemetry),
413                    cx,
414                )
415            })?;
416            let operations = cx
417                .background_executor()
418                .spawn(async move {
419                    context_proto
420                        .operations
421                        .into_iter()
422                        .map(ContextOperation::from_proto)
423                        .collect::<Result<Vec<_>>>()
424                })
425                .await?;
426            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
427            this.update(&mut cx, |this, cx| {
428                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
429                    existing_context
430                } else {
431                    this.register_context(&context, cx);
432                    this.synchronize_contexts(cx);
433                    context
434                }
435            })
436        })
437    }
438
439    pub fn open_local_context(
440        &mut self,
441        path: PathBuf,
442        cx: &ModelContext<Self>,
443    ) -> Task<Result<Model<Context>>> {
444        if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
445            return Task::ready(Ok(existing_context));
446        }
447
448        let fs = self.fs.clone();
449        let languages = self.languages.clone();
450        let project = self.project.clone();
451        let telemetry = self.telemetry.clone();
452        let load = cx.background_executor().spawn({
453            let path = path.clone();
454            async move {
455                let saved_context = fs.load(&path).await?;
456                SavedContext::from_json(&saved_context)
457            }
458        });
459        let prompt_builder = self.prompt_builder.clone();
460        let slash_commands = self.slash_commands.clone();
461        let tools = self.tools.clone();
462
463        cx.spawn(|this, mut cx| async move {
464            let saved_context = load.await?;
465            let context = cx.new_model(|cx| {
466                Context::deserialize(
467                    saved_context,
468                    path.clone(),
469                    languages,
470                    prompt_builder,
471                    slash_commands,
472                    tools,
473                    Some(project),
474                    Some(telemetry),
475                    cx,
476                )
477            })?;
478            this.update(&mut cx, |this, cx| {
479                if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
480                    existing_context
481                } else {
482                    this.register_context(&context, cx);
483                    context
484                }
485            })
486        })
487    }
488
489    fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
490        self.contexts.iter().find_map(|context| {
491            let context = context.upgrade()?;
492            if context.read(cx).path() == Some(path) {
493                Some(context)
494            } else {
495                None
496            }
497        })
498    }
499
500    pub(super) fn loaded_context_for_id(
501        &self,
502        id: &ContextId,
503        cx: &AppContext,
504    ) -> Option<Model<Context>> {
505        self.contexts.iter().find_map(|context| {
506            let context = context.upgrade()?;
507            if context.read(cx).id() == id {
508                Some(context)
509            } else {
510                None
511            }
512        })
513    }
514
515    pub fn open_remote_context(
516        &mut self,
517        context_id: ContextId,
518        cx: &mut ModelContext<Self>,
519    ) -> Task<Result<Model<Context>>> {
520        let project = self.project.read(cx);
521        let Some(project_id) = project.remote_id() else {
522            return Task::ready(Err(anyhow!("project was not remote")));
523        };
524
525        if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
526            return Task::ready(Ok(context));
527        }
528
529        let replica_id = project.replica_id();
530        let capability = project.capability();
531        let language_registry = self.languages.clone();
532        let project = self.project.clone();
533        let telemetry = self.telemetry.clone();
534        let request = self.client.request(proto::OpenContext {
535            project_id,
536            context_id: context_id.to_proto(),
537        });
538        let prompt_builder = self.prompt_builder.clone();
539        let slash_commands = self.slash_commands.clone();
540        let tools = self.tools.clone();
541        cx.spawn(|this, mut cx| async move {
542            let response = request.await?;
543            let context_proto = response.context.context("invalid context")?;
544            let context = cx.new_model(|cx| {
545                Context::new(
546                    context_id.clone(),
547                    replica_id,
548                    capability,
549                    language_registry,
550                    prompt_builder,
551                    slash_commands,
552                    tools,
553                    Some(project),
554                    Some(telemetry),
555                    cx,
556                )
557            })?;
558            let operations = cx
559                .background_executor()
560                .spawn(async move {
561                    context_proto
562                        .operations
563                        .into_iter()
564                        .map(ContextOperation::from_proto)
565                        .collect::<Result<Vec<_>>>()
566                })
567                .await?;
568            context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
569            this.update(&mut cx, |this, cx| {
570                if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
571                    existing_context
572                } else {
573                    this.register_context(&context, cx);
574                    this.synchronize_contexts(cx);
575                    context
576                }
577            })
578        })
579    }
580
581    fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
582        let handle = if self.project_is_shared {
583            ContextHandle::Strong(context.clone())
584        } else {
585            ContextHandle::Weak(context.downgrade())
586        };
587        self.contexts.push(handle);
588        self.advertise_contexts(cx);
589        cx.subscribe(context, Self::handle_context_event).detach();
590    }
591
592    fn handle_context_event(
593        &mut self,
594        context: Model<Context>,
595        event: &ContextEvent,
596        cx: &mut ModelContext<Self>,
597    ) {
598        let Some(project_id) = self.project.read(cx).remote_id() else {
599            return;
600        };
601
602        match event {
603            ContextEvent::SummaryChanged => {
604                self.advertise_contexts(cx);
605            }
606            ContextEvent::Operation(operation) => {
607                let context_id = context.read(cx).id().to_proto();
608                let operation = operation.to_proto();
609                self.client
610                    .send(proto::UpdateContext {
611                        project_id,
612                        context_id,
613                        operation: Some(operation),
614                    })
615                    .log_err();
616            }
617            _ => {}
618        }
619    }
620
621    fn advertise_contexts(&self, cx: &AppContext) {
622        let Some(project_id) = self.project.read(cx).remote_id() else {
623            return;
624        };
625
626        // For now, only the host can advertise their open contexts.
627        if self.project.read(cx).is_via_collab() {
628            return;
629        }
630
631        let contexts = self
632            .contexts
633            .iter()
634            .rev()
635            .filter_map(|context| {
636                let context = context.upgrade()?.read(cx);
637                if context.replica_id() == ReplicaId::default() {
638                    Some(proto::ContextMetadata {
639                        context_id: context.id().to_proto(),
640                        summary: context.summary().map(|summary| summary.text.clone()),
641                    })
642                } else {
643                    None
644                }
645            })
646            .collect();
647        self.client
648            .send(proto::AdvertiseContexts {
649                project_id,
650                contexts,
651            })
652            .ok();
653    }
654
655    fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
656        let Some(project_id) = self.project.read(cx).remote_id() else {
657            return;
658        };
659
660        let contexts = self
661            .contexts
662            .iter()
663            .filter_map(|context| {
664                let context = context.upgrade()?.read(cx);
665                if context.replica_id() != ReplicaId::default() {
666                    Some(context.version(cx).to_proto(context.id().clone()))
667                } else {
668                    None
669                }
670            })
671            .collect();
672
673        let client = self.client.clone();
674        let request = self.client.request(proto::SynchronizeContexts {
675            project_id,
676            contexts,
677        });
678        cx.spawn(|this, cx| async move {
679            let response = request.await?;
680
681            let mut context_ids = Vec::new();
682            let mut operations = Vec::new();
683            this.read_with(&cx, |this, cx| {
684                for context_version_proto in response.contexts {
685                    let context_version = ContextVersion::from_proto(&context_version_proto);
686                    let context_id = ContextId::from_proto(context_version_proto.context_id);
687                    if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
688                        context_ids.push(context_id);
689                        operations.push(context.read(cx).serialize_ops(&context_version, cx));
690                    }
691                }
692            })?;
693
694            let operations = futures::future::join_all(operations).await;
695            for (context_id, operations) in context_ids.into_iter().zip(operations) {
696                for operation in operations {
697                    client.send(proto::UpdateContext {
698                        project_id,
699                        context_id: context_id.to_proto(),
700                        operation: Some(operation),
701                    })?;
702                }
703            }
704
705            anyhow::Ok(())
706        })
707        .detach_and_log_err(cx);
708    }
709
710    pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
711        let metadata = self.contexts_metadata.clone();
712        let executor = cx.background_executor().clone();
713        cx.background_executor().spawn(async move {
714            if query.is_empty() {
715                metadata
716            } else {
717                let candidates = metadata
718                    .iter()
719                    .enumerate()
720                    .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
721                    .collect::<Vec<_>>();
722                let matches = fuzzy::match_strings(
723                    &candidates,
724                    &query,
725                    false,
726                    100,
727                    &Default::default(),
728                    executor,
729                )
730                .await;
731
732                matches
733                    .into_iter()
734                    .map(|mat| metadata[mat.candidate_id].clone())
735                    .collect()
736            }
737        })
738    }
739
740    pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
741        &self.host_contexts
742    }
743
744    fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
745        let fs = self.fs.clone();
746        cx.spawn(|this, mut cx| async move {
747            fs.create_dir(contexts_dir()).await?;
748
749            let mut paths = fs.read_dir(contexts_dir()).await?;
750            let mut contexts = Vec::<SavedContextMetadata>::new();
751            while let Some(path) = paths.next().await {
752                let path = path?;
753                if path.extension() != Some(OsStr::new("json")) {
754                    continue;
755                }
756
757                static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
758                    LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
759
760                let metadata = fs.metadata(&path).await?;
761                if let Some((file_name, metadata)) = path
762                    .file_name()
763                    .and_then(|name| name.to_str())
764                    .zip(metadata)
765                {
766                    // This is used to filter out contexts saved by the new assistant.
767                    if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
768                        continue;
769                    }
770
771                    if let Some(title) = ASSISTANT_CONTEXT_REGEX
772                        .replace(file_name, "")
773                        .lines()
774                        .next()
775                    {
776                        contexts.push(SavedContextMetadata {
777                            title: title.to_string(),
778                            path,
779                            mtime: metadata.mtime.timestamp_for_user().into(),
780                        });
781                    }
782                }
783            }
784            contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
785
786            this.update(&mut cx, |this, cx| {
787                this.contexts_metadata = contexts;
788                cx.notify();
789            })
790        })
791    }
792
793    pub fn restart_context_servers(&mut self, cx: &mut ModelContext<Self>) {
794        cx.update_model(
795            &self.context_server_manager,
796            |context_server_manager, cx| {
797                for server in context_server_manager.servers() {
798                    context_server_manager
799                        .restart_server(&server.id(), cx)
800                        .detach_and_log_err(cx);
801                }
802            },
803        );
804    }
805
806    fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
807        cx.subscribe(
808            &self.context_server_manager.clone(),
809            Self::handle_context_server_event,
810        )
811        .detach();
812    }
813
814    fn handle_context_server_event(
815        &mut self,
816        context_server_manager: Model<ContextServerManager>,
817        event: &context_server::manager::Event,
818        cx: &mut ModelContext<Self>,
819    ) {
820        let slash_command_working_set = self.slash_commands.clone();
821        let tool_working_set = self.tools.clone();
822        match event {
823            context_server::manager::Event::ServerStarted { server_id } => {
824                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
825                    let context_server_manager = context_server_manager.clone();
826                    cx.spawn({
827                        let server = server.clone();
828                        let server_id = server_id.clone();
829                        |this, mut cx| async move {
830                            let Some(protocol) = server.client() else {
831                                return;
832                            };
833
834                            if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
835                                if let Some(prompts) = protocol.list_prompts().await.log_err() {
836                                    let slash_command_ids = prompts
837                                        .into_iter()
838                                        .filter(context_server_command::acceptable_prompt)
839                                        .map(|prompt| {
840                                            log::info!(
841                                                "registering context server command: {:?}",
842                                                prompt.name
843                                            );
844                                            slash_command_working_set.insert(Arc::new(
845                                                context_server_command::ContextServerSlashCommand::new(
846                                                    context_server_manager.clone(),
847                                                    &server,
848                                                    prompt,
849                                                ),
850                                            ))
851                                        })
852                                        .collect::<Vec<_>>();
853
854                                    this.update(&mut cx, |this, _cx| {
855                                        this.context_server_slash_command_ids
856                                            .insert(server_id.clone(), slash_command_ids);
857                                    })
858                                    .log_err();
859                                }
860                            }
861
862                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
863                                if let Some(tools) = protocol.list_tools().await.log_err() {
864                                    let tool_ids = tools.tools.into_iter().map(|tool| {
865                                        log::info!("registering context server tool: {:?}", tool.name);
866                                        tool_working_set.insert(
867                                            Arc::new(ContextServerTool::new(
868                                                context_server_manager.clone(),
869                                                server.id(),
870                                                tool,
871                                            )),
872                                        )
873
874                                    }).collect::<Vec<_>>();
875
876
877                                    this.update(&mut cx, |this, _cx| {
878                                        this.context_server_tool_ids
879                                            .insert(server_id, tool_ids);
880                                    })
881                                    .log_err();
882                                }
883                            }
884                        }
885                    })
886                    .detach();
887                }
888            }
889            context_server::manager::Event::ServerStopped { server_id } => {
890                if let Some(slash_command_ids) =
891                    self.context_server_slash_command_ids.remove(server_id)
892                {
893                    slash_command_working_set.remove(&slash_command_ids);
894                }
895
896                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
897                    tool_working_set.remove(&tool_ids);
898                }
899            }
900        }
901    }
902}