1use crate::{
2 AssistantContext, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
3 SavedContextMetadata,
4};
5use anyhow::{anyhow, Context as _, Result};
6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
7use assistant_tool::{ToolId, ToolWorkingSet};
8use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
9use clock::ReplicaId;
10use collections::HashMap;
11use context_server::manager::ContextServerManager;
12use context_server::{ContextServerFactoryRegistry, ContextServerTool};
13use fs::Fs;
14use futures::StreamExt;
15use fuzzy::StringMatchCandidate;
16use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
17use language::LanguageRegistry;
18use paths::contexts_dir;
19use project::Project;
20use prompt_library::PromptBuilder;
21use regex::Regex;
22use rpc::AnyProtoClient;
23use std::sync::LazyLock;
24use std::{
25 cmp::Reverse,
26 ffi::OsStr,
27 mem,
28 path::{Path, PathBuf},
29 sync::Arc,
30 time::Duration,
31};
32use util::{ResultExt, TryFutureExt};
33
34pub(crate) fn init(client: &AnyProtoClient) {
35 client.add_model_message_handler(ContextStore::handle_advertise_contexts);
36 client.add_model_request_handler(ContextStore::handle_open_context);
37 client.add_model_request_handler(ContextStore::handle_create_context);
38 client.add_model_message_handler(ContextStore::handle_update_context);
39 client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
40}
41
/// Identifier and optional summary of a context received through an
/// `AdvertiseContexts` message.
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Identifier of the advertised context.
    pub id: ContextId,
    /// Summary text carried by the advertisement, if any.
    pub summary: Option<String>,
}
47
/// Tracks the assistant contexts of a project: contexts loaded in memory,
/// metadata for contexts saved on disk, and contexts advertised by a host.
pub struct ContextStore {
    /// Handles to every registered context; held strongly while the project
    /// is shared and weakly otherwise.
    contexts: Vec<ContextHandle>,
    /// Metadata of saved contexts found in the contexts directory, refreshed by
    /// `reload` and sorted by descending modification time.
    contexts_metadata: Vec<SavedContextMetadata>,
    /// Manager whose server start/stop events drive slash-command and tool
    /// registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Slash commands registered per context server, removed when it stops.
    context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
    /// Tools registered per context server, removed when it stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Contexts most recently received via `AdvertiseContexts`; cleared when
    /// the project disconnects from its host.
    host_contexts: Vec<RemoteContextMetadata>,
    /// File system used to watch, list, and load saved contexts.
    fs: Arc<dyn Fs>,
    /// Language registry handed to every context this store creates.
    languages: Arc<LanguageRegistry>,
    /// Slash-command working set shared with created contexts.
    slash_commands: Arc<SlashCommandWorkingSet>,
    /// Tool working set shared with created contexts.
    tools: Arc<ToolWorkingSet>,
    /// Telemetry handle handed to every context this store creates.
    telemetry: Arc<Telemetry>,
    /// Task that calls `reload` each time the contexts-directory watcher fires.
    _watch_updates: Task<Option<()>>,
    /// RPC client used to send and request context messages.
    client: Arc<Client>,
    /// Project these contexts belong to.
    project: Entity<Project>,
    /// Last observed value of `project.is_shared()`.
    project_is_shared: bool,
    /// Entity subscription on the client, present only while the project is shared.
    client_subscription: Option<client::Subscription>,
    /// Keeps the project observer and event subscription alive.
    _project_subscriptions: Vec<gpui::Subscription>,
    /// Prompt builder handed to every context this store creates.
    prompt_builder: Arc<PromptBuilder>,
}
68
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// A context with the given id was created while handling a
    /// `CreateContext` request.
    ContextCreated(ContextId),
}

// Lets the store emit `ContextStoreEvent`s through `cx.emit`.
impl EventEmitter<ContextStoreEvent> for ContextStore {}
74
/// A handle to a registered context that is held either weakly or strongly.
enum ContextHandle {
    /// Weak handle; used while the project is not shared.
    Weak(WeakEntity<AssistantContext>),
    /// Strong handle; used while the project is shared.
    Strong(Entity<AssistantContext>),
}
79
80impl ContextHandle {
81 fn upgrade(&self) -> Option<Entity<AssistantContext>> {
82 match self {
83 ContextHandle::Weak(weak) => weak.upgrade(),
84 ContextHandle::Strong(strong) => Some(strong.clone()),
85 }
86 }
87
88 fn downgrade(&self) -> WeakEntity<AssistantContext> {
89 match self {
90 ContextHandle::Weak(weak) => weak.clone(),
91 ContextHandle::Strong(strong) => strong.downgrade(),
92 }
93 }
94}
95
96impl ContextStore {
/// Creates a `ContextStore` for `project`.
///
/// The returned task starts watching the contexts directory, builds the
/// store entity, runs an initial `reload`, and resolves to the entity.
/// A failure of the initial reload goes through `log_err` and is not returned.
pub fn new(
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    tools: Arc<ToolWorkingSet>,
    cx: &mut App,
) -> Task<Result<Entity<Self>>> {
    let fs = project.read(cx).fs().clone();
    let languages = project.read(cx).languages().clone();
    let telemetry = project.read(cx).client().telemetry().clone();
    cx.spawn(|mut cx| async move {
        const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
        // Only the event stream is kept; the second element of the tuple is discarded.
        let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;

        let this = cx.new(|cx: &mut Context<Self>| {
            let context_server_factory_registry =
                ContextServerFactoryRegistry::default_global(cx);
            let context_server_manager = cx.new(|cx| {
                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
            });
            let mut this = Self {
                contexts: Vec::new(),
                contexts_metadata: Vec::new(),
                context_server_manager,
                context_server_slash_command_ids: HashMap::default(),
                context_server_tool_ids: HashMap::default(),
                host_contexts: Vec::new(),
                fs,
                languages,
                slash_commands,
                tools,
                telemetry,
                // Reload the saved-context metadata on every watcher event. The loop
                // ends when the event stream ends or the store can no longer be updated.
                _watch_updates: cx.spawn(|this, mut cx| {
                    async move {
                        while events.next().await.is_some() {
                            this.update(&mut cx, |this, cx| this.reload(cx))?
                                .await
                                .log_err();
                        }
                        anyhow::Ok(())
                    }
                    .log_err()
                }),
                client_subscription: None,
                _project_subscriptions: vec![
                    cx.observe(&project, Self::handle_project_changed),
                    cx.subscribe(&project, Self::handle_project_event),
                ],
                // Starts as `false` so the `handle_project_changed` call below
                // detects a project that is already shared.
                project_is_shared: false,
                client: project.read(cx).client(),
                project: project.clone(),
                prompt_builder,
            };
            this.handle_project_changed(project.clone(), cx);
            this.synchronize_contexts(cx);
            this.register_context_server_handlers(cx);
            this
        })?;
        // Populate the saved-context metadata once before handing the store out.
        this.update(&mut cx, |this, cx| this.reload(cx))?
            .await
            .log_err();

        Ok(this)
    })
}
162
163 async fn handle_advertise_contexts(
164 this: Entity<Self>,
165 envelope: TypedEnvelope<proto::AdvertiseContexts>,
166 mut cx: AsyncApp,
167 ) -> Result<()> {
168 this.update(&mut cx, |this, cx| {
169 this.host_contexts = envelope
170 .payload
171 .contexts
172 .into_iter()
173 .map(|context| RemoteContextMetadata {
174 id: ContextId::from_proto(context.context_id),
175 summary: context.summary,
176 })
177 .collect();
178 cx.notify();
179 })
180 }
181
182 async fn handle_open_context(
183 this: Entity<Self>,
184 envelope: TypedEnvelope<proto::OpenContext>,
185 mut cx: AsyncApp,
186 ) -> Result<proto::OpenContextResponse> {
187 let context_id = ContextId::from_proto(envelope.payload.context_id);
188 let operations = this.update(&mut cx, |this, cx| {
189 if this.project.read(cx).is_via_collab() {
190 return Err(anyhow!("only the host contexts can be opened"));
191 }
192
193 let context = this
194 .loaded_context_for_id(&context_id, cx)
195 .context("context not found")?;
196 if context.read(cx).replica_id() != ReplicaId::default() {
197 return Err(anyhow!("context must be opened via the host"));
198 }
199
200 anyhow::Ok(
201 context
202 .read(cx)
203 .serialize_ops(&ContextVersion::default(), cx),
204 )
205 })??;
206 let operations = operations.await;
207 Ok(proto::OpenContextResponse {
208 context: Some(proto::Context { operations }),
209 })
210 }
211
212 async fn handle_create_context(
213 this: Entity<Self>,
214 _: TypedEnvelope<proto::CreateContext>,
215 mut cx: AsyncApp,
216 ) -> Result<proto::CreateContextResponse> {
217 let (context_id, operations) = this.update(&mut cx, |this, cx| {
218 if this.project.read(cx).is_via_collab() {
219 return Err(anyhow!("can only create contexts as the host"));
220 }
221
222 let context = this.create(cx);
223 let context_id = context.read(cx).id().clone();
224 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
225
226 anyhow::Ok((
227 context_id,
228 context
229 .read(cx)
230 .serialize_ops(&ContextVersion::default(), cx),
231 ))
232 })??;
233 let operations = operations.await;
234 Ok(proto::CreateContextResponse {
235 context_id: context_id.to_proto(),
236 context: Some(proto::Context { operations }),
237 })
238 }
239
240 async fn handle_update_context(
241 this: Entity<Self>,
242 envelope: TypedEnvelope<proto::UpdateContext>,
243 mut cx: AsyncApp,
244 ) -> Result<()> {
245 this.update(&mut cx, |this, cx| {
246 let context_id = ContextId::from_proto(envelope.payload.context_id);
247 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
248 let operation_proto = envelope.payload.operation.context("invalid operation")?;
249 let operation = ContextOperation::from_proto(operation_proto)?;
250 context.update(cx, |context, cx| context.apply_ops([operation], cx));
251 }
252 Ok(())
253 })?
254 }
255
256 async fn handle_synchronize_contexts(
257 this: Entity<Self>,
258 envelope: TypedEnvelope<proto::SynchronizeContexts>,
259 mut cx: AsyncApp,
260 ) -> Result<proto::SynchronizeContextsResponse> {
261 this.update(&mut cx, |this, cx| {
262 if this.project.read(cx).is_via_collab() {
263 return Err(anyhow!("only the host can synchronize contexts"));
264 }
265
266 let mut local_versions = Vec::new();
267 for remote_version_proto in envelope.payload.contexts {
268 let remote_version = ContextVersion::from_proto(&remote_version_proto);
269 let context_id = ContextId::from_proto(remote_version_proto.context_id);
270 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
271 let context = context.read(cx);
272 let operations = context.serialize_ops(&remote_version, cx);
273 local_versions.push(context.version(cx).to_proto(context_id.clone()));
274 let client = this.client.clone();
275 let project_id = envelope.payload.project_id;
276 cx.background_executor()
277 .spawn(async move {
278 let operations = operations.await;
279 for operation in operations {
280 client.send(proto::UpdateContext {
281 project_id,
282 context_id: context_id.to_proto(),
283 operation: Some(operation),
284 })?;
285 }
286 anyhow::Ok(())
287 })
288 .detach_and_log_err(cx);
289 }
290 }
291
292 this.advertise_contexts(cx);
293
294 anyhow::Ok(proto::SynchronizeContextsResponse {
295 contexts: local_versions,
296 })
297 })?
298 }
299
300 fn handle_project_changed(&mut self, _: Entity<Project>, cx: &mut Context<Self>) {
301 let is_shared = self.project.read(cx).is_shared();
302 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
303 if is_shared == was_shared {
304 return;
305 }
306
307 if is_shared {
308 self.contexts.retain_mut(|context| {
309 if let Some(strong_context) = context.upgrade() {
310 *context = ContextHandle::Strong(strong_context);
311 true
312 } else {
313 false
314 }
315 });
316 let remote_id = self.project.read(cx).remote_id().unwrap();
317 self.client_subscription = self
318 .client
319 .subscribe_to_entity(remote_id)
320 .log_err()
321 .map(|subscription| subscription.set_model(&cx.entity(), &mut cx.to_async()));
322 self.advertise_contexts(cx);
323 } else {
324 self.client_subscription = None;
325 }
326 }
327
328 fn handle_project_event(
329 &mut self,
330 _: Entity<Project>,
331 event: &project::Event,
332 cx: &mut Context<Self>,
333 ) {
334 match event {
335 project::Event::Reshared => {
336 self.advertise_contexts(cx);
337 }
338 project::Event::HostReshared | project::Event::Rejoined => {
339 self.synchronize_contexts(cx);
340 }
341 project::Event::DisconnectedFromHost => {
342 self.contexts.retain_mut(|context| {
343 if let Some(strong_context) = context.upgrade() {
344 *context = ContextHandle::Weak(context.downgrade());
345 strong_context.update(cx, |context, cx| {
346 if context.replica_id() != ReplicaId::default() {
347 context.set_capability(language::Capability::ReadOnly, cx);
348 }
349 });
350 true
351 } else {
352 false
353 }
354 });
355 self.host_contexts.clear();
356 cx.notify();
357 }
358 _ => {}
359 }
360 }
361
362 pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<AssistantContext> {
363 let context = cx.new(|cx| {
364 AssistantContext::local(
365 self.languages.clone(),
366 Some(self.project.clone()),
367 Some(self.telemetry.clone()),
368 self.prompt_builder.clone(),
369 self.slash_commands.clone(),
370 self.tools.clone(),
371 cx,
372 )
373 });
374 self.register_context(&context, cx);
375 context
376 }
377
378 pub fn create_remote_context(
379 &mut self,
380 cx: &mut Context<Self>,
381 ) -> Task<Result<Entity<AssistantContext>>> {
382 let project = self.project.read(cx);
383 let Some(project_id) = project.remote_id() else {
384 return Task::ready(Err(anyhow!("project was not remote")));
385 };
386
387 let replica_id = project.replica_id();
388 let capability = project.capability();
389 let language_registry = self.languages.clone();
390 let project = self.project.clone();
391 let telemetry = self.telemetry.clone();
392 let prompt_builder = self.prompt_builder.clone();
393 let slash_commands = self.slash_commands.clone();
394 let tools = self.tools.clone();
395 let request = self.client.request(proto::CreateContext { project_id });
396 cx.spawn(|this, mut cx| async move {
397 let response = request.await?;
398 let context_id = ContextId::from_proto(response.context_id);
399 let context_proto = response.context.context("invalid context")?;
400 let context = cx.new(|cx| {
401 AssistantContext::new(
402 context_id.clone(),
403 replica_id,
404 capability,
405 language_registry,
406 prompt_builder,
407 slash_commands,
408 tools,
409 Some(project),
410 Some(telemetry),
411 cx,
412 )
413 })?;
414 let operations = cx
415 .background_executor()
416 .spawn(async move {
417 context_proto
418 .operations
419 .into_iter()
420 .map(ContextOperation::from_proto)
421 .collect::<Result<Vec<_>>>()
422 })
423 .await?;
424 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
425 this.update(&mut cx, |this, cx| {
426 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
427 existing_context
428 } else {
429 this.register_context(&context, cx);
430 this.synchronize_contexts(cx);
431 context
432 }
433 })
434 })
435 }
436
437 pub fn open_local_context(
438 &mut self,
439 path: PathBuf,
440 cx: &Context<Self>,
441 ) -> Task<Result<Entity<AssistantContext>>> {
442 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
443 return Task::ready(Ok(existing_context));
444 }
445
446 let fs = self.fs.clone();
447 let languages = self.languages.clone();
448 let project = self.project.clone();
449 let telemetry = self.telemetry.clone();
450 let load = cx.background_executor().spawn({
451 let path = path.clone();
452 async move {
453 let saved_context = fs.load(&path).await?;
454 SavedContext::from_json(&saved_context)
455 }
456 });
457 let prompt_builder = self.prompt_builder.clone();
458 let slash_commands = self.slash_commands.clone();
459 let tools = self.tools.clone();
460
461 cx.spawn(|this, mut cx| async move {
462 let saved_context = load.await?;
463 let context = cx.new(|cx| {
464 AssistantContext::deserialize(
465 saved_context,
466 path.clone(),
467 languages,
468 prompt_builder,
469 slash_commands,
470 tools,
471 Some(project),
472 Some(telemetry),
473 cx,
474 )
475 })?;
476 this.update(&mut cx, |this, cx| {
477 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
478 existing_context
479 } else {
480 this.register_context(&context, cx);
481 context
482 }
483 })
484 })
485 }
486
487 fn loaded_context_for_path(&self, path: &Path, cx: &App) -> Option<Entity<AssistantContext>> {
488 self.contexts.iter().find_map(|context| {
489 let context = context.upgrade()?;
490 if context.read(cx).path() == Some(path) {
491 Some(context)
492 } else {
493 None
494 }
495 })
496 }
497
498 pub fn loaded_context_for_id(
499 &self,
500 id: &ContextId,
501 cx: &App,
502 ) -> Option<Entity<AssistantContext>> {
503 self.contexts.iter().find_map(|context| {
504 let context = context.upgrade()?;
505 if context.read(cx).id() == id {
506 Some(context)
507 } else {
508 None
509 }
510 })
511 }
512
513 pub fn open_remote_context(
514 &mut self,
515 context_id: ContextId,
516 cx: &mut Context<Self>,
517 ) -> Task<Result<Entity<AssistantContext>>> {
518 let project = self.project.read(cx);
519 let Some(project_id) = project.remote_id() else {
520 return Task::ready(Err(anyhow!("project was not remote")));
521 };
522
523 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
524 return Task::ready(Ok(context));
525 }
526
527 let replica_id = project.replica_id();
528 let capability = project.capability();
529 let language_registry = self.languages.clone();
530 let project = self.project.clone();
531 let telemetry = self.telemetry.clone();
532 let request = self.client.request(proto::OpenContext {
533 project_id,
534 context_id: context_id.to_proto(),
535 });
536 let prompt_builder = self.prompt_builder.clone();
537 let slash_commands = self.slash_commands.clone();
538 let tools = self.tools.clone();
539 cx.spawn(|this, mut cx| async move {
540 let response = request.await?;
541 let context_proto = response.context.context("invalid context")?;
542 let context = cx.new(|cx| {
543 AssistantContext::new(
544 context_id.clone(),
545 replica_id,
546 capability,
547 language_registry,
548 prompt_builder,
549 slash_commands,
550 tools,
551 Some(project),
552 Some(telemetry),
553 cx,
554 )
555 })?;
556 let operations = cx
557 .background_executor()
558 .spawn(async move {
559 context_proto
560 .operations
561 .into_iter()
562 .map(ContextOperation::from_proto)
563 .collect::<Result<Vec<_>>>()
564 })
565 .await?;
566 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
567 this.update(&mut cx, |this, cx| {
568 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
569 existing_context
570 } else {
571 this.register_context(&context, cx);
572 this.synchronize_contexts(cx);
573 context
574 }
575 })
576 })
577 }
578
579 fn register_context(&mut self, context: &Entity<AssistantContext>, cx: &mut Context<Self>) {
580 let handle = if self.project_is_shared {
581 ContextHandle::Strong(context.clone())
582 } else {
583 ContextHandle::Weak(context.downgrade())
584 };
585 self.contexts.push(handle);
586 self.advertise_contexts(cx);
587 cx.subscribe(context, Self::handle_context_event).detach();
588 }
589
590 fn handle_context_event(
591 &mut self,
592 context: Entity<AssistantContext>,
593 event: &ContextEvent,
594 cx: &mut Context<Self>,
595 ) {
596 let Some(project_id) = self.project.read(cx).remote_id() else {
597 return;
598 };
599
600 match event {
601 ContextEvent::SummaryChanged => {
602 self.advertise_contexts(cx);
603 }
604 ContextEvent::Operation(operation) => {
605 let context_id = context.read(cx).id().to_proto();
606 let operation = operation.to_proto();
607 self.client
608 .send(proto::UpdateContext {
609 project_id,
610 context_id,
611 operation: Some(operation),
612 })
613 .log_err();
614 }
615 _ => {}
616 }
617 }
618
619 fn advertise_contexts(&self, cx: &App) {
620 let Some(project_id) = self.project.read(cx).remote_id() else {
621 return;
622 };
623
624 // For now, only the host can advertise their open contexts.
625 if self.project.read(cx).is_via_collab() {
626 return;
627 }
628
629 let contexts = self
630 .contexts
631 .iter()
632 .rev()
633 .filter_map(|context| {
634 let context = context.upgrade()?.read(cx);
635 if context.replica_id() == ReplicaId::default() {
636 Some(proto::ContextMetadata {
637 context_id: context.id().to_proto(),
638 summary: context.summary().map(|summary| summary.text.clone()),
639 })
640 } else {
641 None
642 }
643 })
644 .collect();
645 self.client
646 .send(proto::AdvertiseContexts {
647 project_id,
648 contexts,
649 })
650 .ok();
651 }
652
653 fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
654 let Some(project_id) = self.project.read(cx).remote_id() else {
655 return;
656 };
657
658 let contexts = self
659 .contexts
660 .iter()
661 .filter_map(|context| {
662 let context = context.upgrade()?.read(cx);
663 if context.replica_id() != ReplicaId::default() {
664 Some(context.version(cx).to_proto(context.id().clone()))
665 } else {
666 None
667 }
668 })
669 .collect();
670
671 let client = self.client.clone();
672 let request = self.client.request(proto::SynchronizeContexts {
673 project_id,
674 contexts,
675 });
676 cx.spawn(|this, cx| async move {
677 let response = request.await?;
678
679 let mut context_ids = Vec::new();
680 let mut operations = Vec::new();
681 this.read_with(&cx, |this, cx| {
682 for context_version_proto in response.contexts {
683 let context_version = ContextVersion::from_proto(&context_version_proto);
684 let context_id = ContextId::from_proto(context_version_proto.context_id);
685 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
686 context_ids.push(context_id);
687 operations.push(context.read(cx).serialize_ops(&context_version, cx));
688 }
689 }
690 })?;
691
692 let operations = futures::future::join_all(operations).await;
693 for (context_id, operations) in context_ids.into_iter().zip(operations) {
694 for operation in operations {
695 client.send(proto::UpdateContext {
696 project_id,
697 context_id: context_id.to_proto(),
698 operation: Some(operation),
699 })?;
700 }
701 }
702
703 anyhow::Ok(())
704 })
705 .detach_and_log_err(cx);
706 }
707
708 pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedContextMetadata>> {
709 let metadata = self.contexts_metadata.clone();
710 let executor = cx.background_executor().clone();
711 cx.background_executor().spawn(async move {
712 if query.is_empty() {
713 metadata
714 } else {
715 let candidates = metadata
716 .iter()
717 .enumerate()
718 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
719 .collect::<Vec<_>>();
720 let matches = fuzzy::match_strings(
721 &candidates,
722 &query,
723 false,
724 100,
725 &Default::default(),
726 executor,
727 )
728 .await;
729
730 matches
731 .into_iter()
732 .map(|mat| metadata[mat.candidate_id].clone())
733 .collect()
734 }
735 })
736 }
737
738 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
739 &self.host_contexts
740 }
741
742 fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
743 let fs = self.fs.clone();
744 cx.spawn(|this, mut cx| async move {
745 fs.create_dir(contexts_dir()).await?;
746
747 let mut paths = fs.read_dir(contexts_dir()).await?;
748 let mut contexts = Vec::<SavedContextMetadata>::new();
749 while let Some(path) = paths.next().await {
750 let path = path?;
751 if path.extension() != Some(OsStr::new("json")) {
752 continue;
753 }
754
755 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
756 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
757
758 let metadata = fs.metadata(&path).await?;
759 if let Some((file_name, metadata)) = path
760 .file_name()
761 .and_then(|name| name.to_str())
762 .zip(metadata)
763 {
764 // This is used to filter out contexts saved by the new assistant.
765 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
766 continue;
767 }
768
769 if let Some(title) = ASSISTANT_CONTEXT_REGEX
770 .replace(file_name, "")
771 .lines()
772 .next()
773 {
774 contexts.push(SavedContextMetadata {
775 title: title.to_string(),
776 path,
777 mtime: metadata.mtime.timestamp_for_user().into(),
778 });
779 }
780 }
781 }
782 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
783
784 this.update(&mut cx, |this, cx| {
785 this.contexts_metadata = contexts;
786 cx.notify();
787 })
788 })
789 }
790
791 pub fn restart_context_servers(&mut self, cx: &mut Context<Self>) {
792 cx.update_entity(
793 &self.context_server_manager,
794 |context_server_manager, cx| {
795 for server in context_server_manager.servers() {
796 context_server_manager
797 .restart_server(&server.id(), cx)
798 .detach_and_log_err(cx);
799 }
800 },
801 );
802 }
803
804 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
805 cx.subscribe(
806 &self.context_server_manager.clone(),
807 Self::handle_context_server_event,
808 )
809 .detach();
810 }
811
/// Keeps the slash-command and tool working sets in sync with context servers.
///
/// * `ServerStarted`: spawns a detached task that, for each capability the
///   server reports (prompts, tools), lists the items, inserts them into the
///   matching working set, and records the returned ids under the server id.
/// * `ServerStopped`: removes the ids recorded for that server from both
///   working sets.
fn handle_context_server_event(
    &mut self,
    context_server_manager: Entity<ContextServerManager>,
    event: &context_server::manager::Event,
    cx: &mut Context<Self>,
) {
    let slash_command_working_set = self.slash_commands.clone();
    let tool_working_set = self.tools.clone();
    match event {
        context_server::manager::Event::ServerStarted { server_id } => {
            if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
                let context_server_manager = context_server_manager.clone();
                cx.spawn({
                    let server = server.clone();
                    let server_id = server_id.clone();
                    |this, mut cx| async move {
                        // Without a client there is nothing to query.
                        let Some(protocol) = server.client() else {
                            return;
                        };

                        if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
                            // A failed listing goes through `log_err` and skips registration.
                            if let Some(prompts) = protocol.list_prompts().await.log_err() {
                                let slash_command_ids = prompts
                                    .into_iter()
                                    .filter(assistant_slash_commands::acceptable_prompt)
                                    .map(|prompt| {
                                        log::info!(
                                            "registering context server command: {:?}",
                                            prompt.name
                                        );
                                        slash_command_working_set.insert(Arc::new(
                                            assistant_slash_commands::ContextServerSlashCommand::new(
                                                context_server_manager.clone(),
                                                &server,
                                                prompt,
                                            ),
                                        ))
                                    })
                                    .collect::<Vec<_>>();

                                // Record the ids so they can be removed when the server stops.
                                this.update(&mut cx, |this, _cx| {
                                    this.context_server_slash_command_ids
                                        .insert(server_id.clone(), slash_command_ids);
                                })
                                .log_err();
                            }
                        }

                        if protocol.capable(context_server::protocol::ServerCapability::Tools) {
                            if let Some(tools) = protocol.list_tools().await.log_err() {
                                let tool_ids = tools.tools.into_iter().map(|tool| {
                                    log::info!("registering context server tool: {:?}", tool.name);
                                    tool_working_set.insert(
                                        Arc::new(ContextServerTool::new(
                                            context_server_manager.clone(),
                                            server.id(),
                                            tool,
                                        )),
                                    )

                                }).collect::<Vec<_>>();


                                // Record the ids so they can be removed when the server stops.
                                this.update(&mut cx, |this, _cx| {
                                    this.context_server_tool_ids
                                        .insert(server_id, tool_ids);
                                })
                                .log_err();
                            }
                        }
                    }
                })
                .detach();
            }
        }
        context_server::manager::Event::ServerStopped { server_id } => {
            // Unregister whatever was recorded for this server, if anything.
            if let Some(slash_command_ids) =
                self.context_server_slash_command_ids.remove(server_id)
            {
                slash_command_working_set.remove(&slash_command_ids);
            }

            if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
                tool_working_set.remove(&tool_ids);
            }
        }
    }
}
900}