1use crate::slash_command::context_server_command;
2use crate::{
3 prompts::PromptBuilder, slash_command_working_set::SlashCommandWorkingSet, Context,
4 ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, SavedContextMetadata,
5};
6use crate::{tools, SlashCommandId, ToolId, ToolWorkingSet};
7use anyhow::{anyhow, Context as _, Result};
8use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
9use clock::ReplicaId;
10use collections::HashMap;
11use command_palette_hooks::CommandPaletteFilter;
12use context_servers::manager::{ContextServerManager, ContextServerSettings};
13use context_servers::{ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE};
14use fs::Fs;
15use futures::StreamExt;
16use fuzzy::StringMatchCandidate;
17use gpui::{
18 AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
19};
20use language::LanguageRegistry;
21use paths::contexts_dir;
22use project::Project;
23use regex::Regex;
24use rpc::AnyProtoClient;
25use settings::{Settings as _, SettingsStore};
26use std::{
27 cmp::Reverse,
28 ffi::OsStr,
29 mem,
30 path::{Path, PathBuf},
31 sync::Arc,
32 time::Duration,
33};
34use util::{ResultExt, TryFutureExt};
35
36pub fn init(client: &AnyProtoClient) {
37 client.add_model_message_handler(ContextStore::handle_advertise_contexts);
38 client.add_model_request_handler(ContextStore::handle_open_context);
39 client.add_model_request_handler(ContextStore::handle_create_context);
40 client.add_model_message_handler(ContextStore::handle_update_context);
41 client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
42}
43
/// Metadata for a context that the host of a shared project has advertised
/// (see `ContextStore::handle_advertise_contexts`).
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Identifier of the context on the host.
    pub id: ContextId,
    /// The context's summary text, if the host reported one.
    pub summary: Option<String>,
}
49
/// Tracks the assistant contexts that are open for a project.
///
/// It also holds the saved-context metadata found on disk and the contexts
/// advertised by a collaborating host, and it manages context servers.
///
/// NOTE: field order determines drop order; keep it stable.
pub struct ContextStore {
    /// Contexts currently registered with the store. Held strongly while the
    /// project is shared, weakly otherwise (see `register_context`).
    contexts: Vec<ContextHandle>,
    /// Saved contexts discovered in `contexts_dir()` by `reload`, newest first.
    contexts_metadata: Vec<SavedContextMetadata>,
    context_server_manager: Model<ContextServerManager>,
    /// Slash commands registered for each context server, keyed by server id,
    /// so they can be removed when the server stops.
    context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
    /// Tools registered for each context server, keyed by server id, so they
    /// can be removed when the server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Contexts advertised by the host when this store is a guest.
    host_contexts: Vec<RemoteContextMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    tools: Arc<ToolWorkingSet>,
    telemetry: Arc<Telemetry>,
    /// Background task that reloads `contexts_metadata` on file-system events.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Model<Project>,
    /// Shared state observed the last time `handle_project_changed` ran.
    project_is_shared: bool,
    /// Subscription to the project's remote entity; `Some` only while shared.
    client_subscription: Option<client::Subscription>,
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
70
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// A new context with the given id was created. In this file it is
    /// emitted by `ContextStore::handle_create_context`.
    ContextCreated(ContextId),
}

impl EventEmitter<ContextStoreEvent> for ContextStore {}
76
/// A reference to a registered context that is either weak (the store does not
/// keep the context alive) or strong (the store keeps it loaded).
enum ContextHandle {
    Weak(WeakModel<Context>),
    Strong(Model<Context>),
}
81
82impl ContextHandle {
83 fn upgrade(&self) -> Option<Model<Context>> {
84 match self {
85 ContextHandle::Weak(weak) => weak.upgrade(),
86 ContextHandle::Strong(strong) => Some(strong.clone()),
87 }
88 }
89
90 fn downgrade(&self) -> WeakModel<Context> {
91 match self {
92 ContextHandle::Weak(weak) => weak.clone(),
93 ContextHandle::Strong(strong) => strong.downgrade(),
94 }
95 }
96}
97
98impl ContextStore {
99 pub fn new(
100 project: Model<Project>,
101 prompt_builder: Arc<PromptBuilder>,
102 slash_commands: Arc<SlashCommandWorkingSet>,
103 tools: Arc<ToolWorkingSet>,
104 cx: &mut AppContext,
105 ) -> Task<Result<Model<Self>>> {
106 let fs = project.read(cx).fs().clone();
107 let languages = project.read(cx).languages().clone();
108 let telemetry = project.read(cx).client().telemetry().clone();
109 cx.spawn(|mut cx| async move {
110 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
111 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
112
113 let this = cx.new_model(|cx: &mut ModelContext<Self>| {
114 let context_server_manager = cx.new_model(|_cx| ContextServerManager::new());
115 let mut this = Self {
116 contexts: Vec::new(),
117 contexts_metadata: Vec::new(),
118 context_server_manager,
119 context_server_slash_command_ids: HashMap::default(),
120 context_server_tool_ids: HashMap::default(),
121 host_contexts: Vec::new(),
122 fs,
123 languages,
124 slash_commands,
125 tools,
126 telemetry,
127 _watch_updates: cx.spawn(|this, mut cx| {
128 async move {
129 while events.next().await.is_some() {
130 this.update(&mut cx, |this, cx| this.reload(cx))?
131 .await
132 .log_err();
133 }
134 anyhow::Ok(())
135 }
136 .log_err()
137 }),
138 client_subscription: None,
139 _project_subscriptions: vec![
140 cx.observe(&project, Self::handle_project_changed),
141 cx.subscribe(&project, Self::handle_project_event),
142 ],
143 project_is_shared: false,
144 client: project.read(cx).client(),
145 project: project.clone(),
146 prompt_builder,
147 };
148 this.handle_project_changed(project.clone(), cx);
149 this.synchronize_contexts(cx);
150 this.register_context_server_handlers(cx);
151
152 if project.read(cx).is_local() {
153 // TODO: At the time when we construct the `ContextStore` we may not have yet initialized the extensions.
154 // In order to register the context servers when the extension is loaded, we're periodically looping to
155 // see if there are context servers to register.
156 //
157 // I tried doing this in a subscription on the `ExtensionStore`, but it never seemed to fire.
158 //
159 // We should find a more elegant way to do this.
160 let context_server_factory_registry =
161 ContextServerFactoryRegistry::default_global(cx);
162 cx.spawn(|context_store, mut cx| async move {
163 loop {
164 let mut servers_to_register = Vec::new();
165 for (_id, factory) in
166 context_server_factory_registry.context_server_factories()
167 {
168 if let Some(server) = factory(project.clone(), &cx).await.log_err()
169 {
170 servers_to_register.push(server);
171 }
172 }
173
174 let Some(_) = context_store
175 .update(&mut cx, |this, cx| {
176 this.context_server_manager.update(cx, |this, cx| {
177 for server in servers_to_register {
178 this.add_server(server, cx).detach_and_log_err(cx);
179 }
180 })
181 })
182 .log_err()
183 else {
184 break;
185 };
186
187 smol::Timer::after(Duration::from_millis(100)).await;
188 }
189
190 anyhow::Ok(())
191 })
192 .detach_and_log_err(cx);
193 }
194
195 this
196 })?;
197 this.update(&mut cx, |this, cx| this.reload(cx))?
198 .await
199 .log_err();
200
201 this.update(&mut cx, |this, cx| {
202 this.watch_context_server_settings(cx);
203 })
204 .log_err();
205
206 Ok(this)
207 })
208 }
209
210 fn watch_context_server_settings(&self, cx: &mut ModelContext<Self>) {
211 cx.observe_global::<SettingsStore>(move |this, cx| {
212 this.context_server_manager.update(cx, |manager, cx| {
213 let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
214 settings::SettingsLocation {
215 worktree_id: worktree.read(cx).id(),
216 path: Path::new(""),
217 }
218 });
219 let settings = ContextServerSettings::get(location, cx);
220
221 manager.maintain_servers(settings, cx);
222
223 let has_any_context_servers = !manager.servers().is_empty();
224 CommandPaletteFilter::update_global(cx, |filter, _cx| {
225 if has_any_context_servers {
226 filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
227 } else {
228 filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
229 }
230 });
231 })
232 })
233 .detach();
234 }
235
236 async fn handle_advertise_contexts(
237 this: Model<Self>,
238 envelope: TypedEnvelope<proto::AdvertiseContexts>,
239 mut cx: AsyncAppContext,
240 ) -> Result<()> {
241 this.update(&mut cx, |this, cx| {
242 this.host_contexts = envelope
243 .payload
244 .contexts
245 .into_iter()
246 .map(|context| RemoteContextMetadata {
247 id: ContextId::from_proto(context.context_id),
248 summary: context.summary,
249 })
250 .collect();
251 cx.notify();
252 })
253 }
254
255 async fn handle_open_context(
256 this: Model<Self>,
257 envelope: TypedEnvelope<proto::OpenContext>,
258 mut cx: AsyncAppContext,
259 ) -> Result<proto::OpenContextResponse> {
260 let context_id = ContextId::from_proto(envelope.payload.context_id);
261 let operations = this.update(&mut cx, |this, cx| {
262 if this.project.read(cx).is_via_collab() {
263 return Err(anyhow!("only the host contexts can be opened"));
264 }
265
266 let context = this
267 .loaded_context_for_id(&context_id, cx)
268 .context("context not found")?;
269 if context.read(cx).replica_id() != ReplicaId::default() {
270 return Err(anyhow!("context must be opened via the host"));
271 }
272
273 anyhow::Ok(
274 context
275 .read(cx)
276 .serialize_ops(&ContextVersion::default(), cx),
277 )
278 })??;
279 let operations = operations.await;
280 Ok(proto::OpenContextResponse {
281 context: Some(proto::Context { operations }),
282 })
283 }
284
285 async fn handle_create_context(
286 this: Model<Self>,
287 _: TypedEnvelope<proto::CreateContext>,
288 mut cx: AsyncAppContext,
289 ) -> Result<proto::CreateContextResponse> {
290 let (context_id, operations) = this.update(&mut cx, |this, cx| {
291 if this.project.read(cx).is_via_collab() {
292 return Err(anyhow!("can only create contexts as the host"));
293 }
294
295 let context = this.create(cx);
296 let context_id = context.read(cx).id().clone();
297 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
298
299 anyhow::Ok((
300 context_id,
301 context
302 .read(cx)
303 .serialize_ops(&ContextVersion::default(), cx),
304 ))
305 })??;
306 let operations = operations.await;
307 Ok(proto::CreateContextResponse {
308 context_id: context_id.to_proto(),
309 context: Some(proto::Context { operations }),
310 })
311 }
312
313 async fn handle_update_context(
314 this: Model<Self>,
315 envelope: TypedEnvelope<proto::UpdateContext>,
316 mut cx: AsyncAppContext,
317 ) -> Result<()> {
318 this.update(&mut cx, |this, cx| {
319 let context_id = ContextId::from_proto(envelope.payload.context_id);
320 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
321 let operation_proto = envelope.payload.operation.context("invalid operation")?;
322 let operation = ContextOperation::from_proto(operation_proto)?;
323 context.update(cx, |context, cx| context.apply_ops([operation], cx));
324 }
325 Ok(())
326 })?
327 }
328
329 async fn handle_synchronize_contexts(
330 this: Model<Self>,
331 envelope: TypedEnvelope<proto::SynchronizeContexts>,
332 mut cx: AsyncAppContext,
333 ) -> Result<proto::SynchronizeContextsResponse> {
334 this.update(&mut cx, |this, cx| {
335 if this.project.read(cx).is_via_collab() {
336 return Err(anyhow!("only the host can synchronize contexts"));
337 }
338
339 let mut local_versions = Vec::new();
340 for remote_version_proto in envelope.payload.contexts {
341 let remote_version = ContextVersion::from_proto(&remote_version_proto);
342 let context_id = ContextId::from_proto(remote_version_proto.context_id);
343 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
344 let context = context.read(cx);
345 let operations = context.serialize_ops(&remote_version, cx);
346 local_versions.push(context.version(cx).to_proto(context_id.clone()));
347 let client = this.client.clone();
348 let project_id = envelope.payload.project_id;
349 cx.background_executor()
350 .spawn(async move {
351 let operations = operations.await;
352 for operation in operations {
353 client.send(proto::UpdateContext {
354 project_id,
355 context_id: context_id.to_proto(),
356 operation: Some(operation),
357 })?;
358 }
359 anyhow::Ok(())
360 })
361 .detach_and_log_err(cx);
362 }
363 }
364
365 this.advertise_contexts(cx);
366
367 anyhow::Ok(proto::SynchronizeContextsResponse {
368 contexts: local_versions,
369 })
370 })?
371 }
372
373 fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
374 let is_shared = self.project.read(cx).is_shared();
375 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
376 if is_shared == was_shared {
377 return;
378 }
379
380 if is_shared {
381 self.contexts.retain_mut(|context| {
382 if let Some(strong_context) = context.upgrade() {
383 *context = ContextHandle::Strong(strong_context);
384 true
385 } else {
386 false
387 }
388 });
389 let remote_id = self.project.read(cx).remote_id().unwrap();
390 self.client_subscription = self
391 .client
392 .subscribe_to_entity(remote_id)
393 .log_err()
394 .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
395 self.advertise_contexts(cx);
396 } else {
397 self.client_subscription = None;
398 }
399 }
400
401 fn handle_project_event(
402 &mut self,
403 _: Model<Project>,
404 event: &project::Event,
405 cx: &mut ModelContext<Self>,
406 ) {
407 match event {
408 project::Event::Reshared => {
409 self.advertise_contexts(cx);
410 }
411 project::Event::HostReshared | project::Event::Rejoined => {
412 self.synchronize_contexts(cx);
413 }
414 project::Event::DisconnectedFromHost => {
415 self.contexts.retain_mut(|context| {
416 if let Some(strong_context) = context.upgrade() {
417 *context = ContextHandle::Weak(context.downgrade());
418 strong_context.update(cx, |context, cx| {
419 if context.replica_id() != ReplicaId::default() {
420 context.set_capability(language::Capability::ReadOnly, cx);
421 }
422 });
423 true
424 } else {
425 false
426 }
427 });
428 self.host_contexts.clear();
429 cx.notify();
430 }
431 _ => {}
432 }
433 }
434
435 pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
436 let context = cx.new_model(|cx| {
437 Context::local(
438 self.languages.clone(),
439 Some(self.project.clone()),
440 Some(self.telemetry.clone()),
441 self.prompt_builder.clone(),
442 self.slash_commands.clone(),
443 self.tools.clone(),
444 cx,
445 )
446 });
447 self.register_context(&context, cx);
448 context
449 }
450
451 pub fn create_remote_context(
452 &mut self,
453 cx: &mut ModelContext<Self>,
454 ) -> Task<Result<Model<Context>>> {
455 let project = self.project.read(cx);
456 let Some(project_id) = project.remote_id() else {
457 return Task::ready(Err(anyhow!("project was not remote")));
458 };
459
460 let replica_id = project.replica_id();
461 let capability = project.capability();
462 let language_registry = self.languages.clone();
463 let project = self.project.clone();
464 let telemetry = self.telemetry.clone();
465 let prompt_builder = self.prompt_builder.clone();
466 let slash_commands = self.slash_commands.clone();
467 let tools = self.tools.clone();
468 let request = self.client.request(proto::CreateContext { project_id });
469 cx.spawn(|this, mut cx| async move {
470 let response = request.await?;
471 let context_id = ContextId::from_proto(response.context_id);
472 let context_proto = response.context.context("invalid context")?;
473 let context = cx.new_model(|cx| {
474 Context::new(
475 context_id.clone(),
476 replica_id,
477 capability,
478 language_registry,
479 prompt_builder,
480 slash_commands,
481 tools,
482 Some(project),
483 Some(telemetry),
484 cx,
485 )
486 })?;
487 let operations = cx
488 .background_executor()
489 .spawn(async move {
490 context_proto
491 .operations
492 .into_iter()
493 .map(ContextOperation::from_proto)
494 .collect::<Result<Vec<_>>>()
495 })
496 .await?;
497 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
498 this.update(&mut cx, |this, cx| {
499 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
500 existing_context
501 } else {
502 this.register_context(&context, cx);
503 this.synchronize_contexts(cx);
504 context
505 }
506 })
507 })
508 }
509
510 pub fn open_local_context(
511 &mut self,
512 path: PathBuf,
513 cx: &ModelContext<Self>,
514 ) -> Task<Result<Model<Context>>> {
515 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
516 return Task::ready(Ok(existing_context));
517 }
518
519 let fs = self.fs.clone();
520 let languages = self.languages.clone();
521 let project = self.project.clone();
522 let telemetry = self.telemetry.clone();
523 let load = cx.background_executor().spawn({
524 let path = path.clone();
525 async move {
526 let saved_context = fs.load(&path).await?;
527 SavedContext::from_json(&saved_context)
528 }
529 });
530 let prompt_builder = self.prompt_builder.clone();
531 let slash_commands = self.slash_commands.clone();
532 let tools = self.tools.clone();
533
534 cx.spawn(|this, mut cx| async move {
535 let saved_context = load.await?;
536 let context = cx.new_model(|cx| {
537 Context::deserialize(
538 saved_context,
539 path.clone(),
540 languages,
541 prompt_builder,
542 slash_commands,
543 tools,
544 Some(project),
545 Some(telemetry),
546 cx,
547 )
548 })?;
549 this.update(&mut cx, |this, cx| {
550 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
551 existing_context
552 } else {
553 this.register_context(&context, cx);
554 context
555 }
556 })
557 })
558 }
559
560 fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
561 self.contexts.iter().find_map(|context| {
562 let context = context.upgrade()?;
563 if context.read(cx).path() == Some(path) {
564 Some(context)
565 } else {
566 None
567 }
568 })
569 }
570
571 pub(super) fn loaded_context_for_id(
572 &self,
573 id: &ContextId,
574 cx: &AppContext,
575 ) -> Option<Model<Context>> {
576 self.contexts.iter().find_map(|context| {
577 let context = context.upgrade()?;
578 if context.read(cx).id() == id {
579 Some(context)
580 } else {
581 None
582 }
583 })
584 }
585
586 pub fn open_remote_context(
587 &mut self,
588 context_id: ContextId,
589 cx: &mut ModelContext<Self>,
590 ) -> Task<Result<Model<Context>>> {
591 let project = self.project.read(cx);
592 let Some(project_id) = project.remote_id() else {
593 return Task::ready(Err(anyhow!("project was not remote")));
594 };
595
596 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
597 return Task::ready(Ok(context));
598 }
599
600 let replica_id = project.replica_id();
601 let capability = project.capability();
602 let language_registry = self.languages.clone();
603 let project = self.project.clone();
604 let telemetry = self.telemetry.clone();
605 let request = self.client.request(proto::OpenContext {
606 project_id,
607 context_id: context_id.to_proto(),
608 });
609 let prompt_builder = self.prompt_builder.clone();
610 let slash_commands = self.slash_commands.clone();
611 let tools = self.tools.clone();
612 cx.spawn(|this, mut cx| async move {
613 let response = request.await?;
614 let context_proto = response.context.context("invalid context")?;
615 let context = cx.new_model(|cx| {
616 Context::new(
617 context_id.clone(),
618 replica_id,
619 capability,
620 language_registry,
621 prompt_builder,
622 slash_commands,
623 tools,
624 Some(project),
625 Some(telemetry),
626 cx,
627 )
628 })?;
629 let operations = cx
630 .background_executor()
631 .spawn(async move {
632 context_proto
633 .operations
634 .into_iter()
635 .map(ContextOperation::from_proto)
636 .collect::<Result<Vec<_>>>()
637 })
638 .await?;
639 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
640 this.update(&mut cx, |this, cx| {
641 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
642 existing_context
643 } else {
644 this.register_context(&context, cx);
645 this.synchronize_contexts(cx);
646 context
647 }
648 })
649 })
650 }
651
652 fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
653 let handle = if self.project_is_shared {
654 ContextHandle::Strong(context.clone())
655 } else {
656 ContextHandle::Weak(context.downgrade())
657 };
658 self.contexts.push(handle);
659 self.advertise_contexts(cx);
660 cx.subscribe(context, Self::handle_context_event).detach();
661 }
662
663 fn handle_context_event(
664 &mut self,
665 context: Model<Context>,
666 event: &ContextEvent,
667 cx: &mut ModelContext<Self>,
668 ) {
669 let Some(project_id) = self.project.read(cx).remote_id() else {
670 return;
671 };
672
673 match event {
674 ContextEvent::SummaryChanged => {
675 self.advertise_contexts(cx);
676 }
677 ContextEvent::Operation(operation) => {
678 let context_id = context.read(cx).id().to_proto();
679 let operation = operation.to_proto();
680 self.client
681 .send(proto::UpdateContext {
682 project_id,
683 context_id,
684 operation: Some(operation),
685 })
686 .log_err();
687 }
688 _ => {}
689 }
690 }
691
692 fn advertise_contexts(&self, cx: &AppContext) {
693 let Some(project_id) = self.project.read(cx).remote_id() else {
694 return;
695 };
696
697 // For now, only the host can advertise their open contexts.
698 if self.project.read(cx).is_via_collab() {
699 return;
700 }
701
702 let contexts = self
703 .contexts
704 .iter()
705 .rev()
706 .filter_map(|context| {
707 let context = context.upgrade()?.read(cx);
708 if context.replica_id() == ReplicaId::default() {
709 Some(proto::ContextMetadata {
710 context_id: context.id().to_proto(),
711 summary: context.summary().map(|summary| summary.text.clone()),
712 })
713 } else {
714 None
715 }
716 })
717 .collect();
718 self.client
719 .send(proto::AdvertiseContexts {
720 project_id,
721 contexts,
722 })
723 .ok();
724 }
725
726 fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
727 let Some(project_id) = self.project.read(cx).remote_id() else {
728 return;
729 };
730
731 let contexts = self
732 .contexts
733 .iter()
734 .filter_map(|context| {
735 let context = context.upgrade()?.read(cx);
736 if context.replica_id() != ReplicaId::default() {
737 Some(context.version(cx).to_proto(context.id().clone()))
738 } else {
739 None
740 }
741 })
742 .collect();
743
744 let client = self.client.clone();
745 let request = self.client.request(proto::SynchronizeContexts {
746 project_id,
747 contexts,
748 });
749 cx.spawn(|this, cx| async move {
750 let response = request.await?;
751
752 let mut context_ids = Vec::new();
753 let mut operations = Vec::new();
754 this.read_with(&cx, |this, cx| {
755 for context_version_proto in response.contexts {
756 let context_version = ContextVersion::from_proto(&context_version_proto);
757 let context_id = ContextId::from_proto(context_version_proto.context_id);
758 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
759 context_ids.push(context_id);
760 operations.push(context.read(cx).serialize_ops(&context_version, cx));
761 }
762 }
763 })?;
764
765 let operations = futures::future::join_all(operations).await;
766 for (context_id, operations) in context_ids.into_iter().zip(operations) {
767 for operation in operations {
768 client.send(proto::UpdateContext {
769 project_id,
770 context_id: context_id.to_proto(),
771 operation: Some(operation),
772 })?;
773 }
774 }
775
776 anyhow::Ok(())
777 })
778 .detach_and_log_err(cx);
779 }
780
781 pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
782 let metadata = self.contexts_metadata.clone();
783 let executor = cx.background_executor().clone();
784 cx.background_executor().spawn(async move {
785 if query.is_empty() {
786 metadata
787 } else {
788 let candidates = metadata
789 .iter()
790 .enumerate()
791 .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone()))
792 .collect::<Vec<_>>();
793 let matches = fuzzy::match_strings(
794 &candidates,
795 &query,
796 false,
797 100,
798 &Default::default(),
799 executor,
800 )
801 .await;
802
803 matches
804 .into_iter()
805 .map(|mat| metadata[mat.candidate_id].clone())
806 .collect()
807 }
808 })
809 }
810
811 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
812 &self.host_contexts
813 }
814
815 fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
816 let fs = self.fs.clone();
817 cx.spawn(|this, mut cx| async move {
818 fs.create_dir(contexts_dir()).await?;
819
820 let mut paths = fs.read_dir(contexts_dir()).await?;
821 let mut contexts = Vec::<SavedContextMetadata>::new();
822 while let Some(path) = paths.next().await {
823 let path = path?;
824 if path.extension() != Some(OsStr::new("json")) {
825 continue;
826 }
827
828 let pattern = r" - \d+.zed.json$";
829 let re = Regex::new(pattern).unwrap();
830
831 let metadata = fs.metadata(&path).await?;
832 if let Some((file_name, metadata)) = path
833 .file_name()
834 .and_then(|name| name.to_str())
835 .zip(metadata)
836 {
837 // This is used to filter out contexts saved by the new assistant.
838 if !re.is_match(file_name) {
839 continue;
840 }
841
842 if let Some(title) = re.replace(file_name, "").lines().next() {
843 contexts.push(SavedContextMetadata {
844 title: title.to_string(),
845 path,
846 mtime: metadata.mtime.into(),
847 });
848 }
849 }
850 }
851 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
852
853 this.update(&mut cx, |this, cx| {
854 this.contexts_metadata = contexts;
855 cx.notify();
856 })
857 })
858 }
859
860 pub fn restart_context_servers(&mut self, cx: &mut ModelContext<Self>) {
861 cx.update_model(
862 &self.context_server_manager,
863 |context_server_manager, cx| {
864 for server in context_server_manager.servers() {
865 context_server_manager
866 .restart_server(&server.id(), cx)
867 .detach_and_log_err(cx);
868 }
869 },
870 );
871 }
872
873 fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
874 cx.subscribe(
875 &self.context_server_manager.clone(),
876 Self::handle_context_server_event,
877 )
878 .detach();
879 }
880
881 fn handle_context_server_event(
882 &mut self,
883 context_server_manager: Model<ContextServerManager>,
884 event: &context_servers::manager::Event,
885 cx: &mut ModelContext<Self>,
886 ) {
887 let slash_command_working_set = self.slash_commands.clone();
888 let tool_working_set = self.tools.clone();
889 match event {
890 context_servers::manager::Event::ServerStarted { server_id } => {
891 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
892 let context_server_manager = context_server_manager.clone();
893 cx.spawn({
894 let server = server.clone();
895 let server_id = server_id.clone();
896 |this, mut cx| async move {
897 let Some(protocol) = server.client() else {
898 return;
899 };
900
901 if protocol.capable(context_servers::protocol::ServerCapability::Prompts) {
902 if let Some(prompts) = protocol.list_prompts().await.log_err() {
903 let slash_command_ids = prompts
904 .into_iter()
905 .filter(context_server_command::acceptable_prompt)
906 .map(|prompt| {
907 log::info!(
908 "registering context server command: {:?}",
909 prompt.name
910 );
911 slash_command_working_set.insert(Arc::new(
912 context_server_command::ContextServerSlashCommand::new(
913 context_server_manager.clone(),
914 &server,
915 prompt,
916 ),
917 ))
918 })
919 .collect::<Vec<_>>();
920
921 this.update(&mut cx, |this, _cx| {
922 this.context_server_slash_command_ids
923 .insert(server_id.clone(), slash_command_ids);
924 })
925 .log_err();
926 }
927 }
928
929 if protocol.capable(context_servers::protocol::ServerCapability::Tools) {
930 if let Some(tools) = protocol.list_tools().await.log_err() {
931 let tool_ids = tools.tools.into_iter().map(|tool| {
932 log::info!("registering context server tool: {:?}", tool.name);
933 tool_working_set.insert(
934 Arc::new(tools::context_server_tool::ContextServerTool::new(
935 context_server_manager.clone(),
936 server.id(),
937 tool,
938 )),
939 )
940
941 }).collect::<Vec<_>>();
942
943
944 this.update(&mut cx, |this, _cx| {
945 this.context_server_tool_ids
946 .insert(server_id, tool_ids);
947 })
948 .log_err();
949 }
950 }
951 }
952 })
953 .detach();
954 }
955 }
956 context_servers::manager::Event::ServerStopped { server_id } => {
957 if let Some(slash_command_ids) =
958 self.context_server_slash_command_ids.remove(server_id)
959 {
960 slash_command_working_set.remove(&slash_command_ids);
961 }
962
963 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
964 tool_working_set.remove(&tool_ids);
965 }
966 }
967 }
968 }
969}