1use crate::slash_command::context_server_command;
2use crate::SlashCommandId;
3use crate::{
4 prompts::PromptBuilder, slash_command_working_set::SlashCommandWorkingSet, Context,
5 ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, SavedContextMetadata,
6};
7use anyhow::{anyhow, Context as _, Result};
8use assistant_tool::{ToolId, ToolWorkingSet};
9use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
10use clock::ReplicaId;
11use collections::HashMap;
12use context_server::manager::ContextServerManager;
13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
14use fs::Fs;
15use futures::StreamExt;
16use fuzzy::StringMatchCandidate;
17use gpui::{
18 AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
19};
20use language::LanguageRegistry;
21use paths::contexts_dir;
22use project::Project;
23use regex::Regex;
24use rpc::AnyProtoClient;
25use std::sync::LazyLock;
26use std::{
27 cmp::Reverse,
28 ffi::OsStr,
29 mem,
30 path::{Path, PathBuf},
31 sync::Arc,
32 time::Duration,
33};
34use util::{ResultExt, TryFutureExt};
35
36pub fn init(client: &AnyProtoClient) {
37 client.add_model_message_handler(ContextStore::handle_advertise_contexts);
38 client.add_model_request_handler(ContextStore::handle_open_context);
39 client.add_model_request_handler(ContextStore::handle_create_context);
40 client.add_model_message_handler(ContextStore::handle_update_context);
41 client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
42}
43
44#[derive(Clone)]
45pub struct RemoteContextMetadata {
46 pub id: ContextId,
47 pub summary: Option<String>,
48}
49
50pub struct ContextStore {
51 contexts: Vec<ContextHandle>,
52 contexts_metadata: Vec<SavedContextMetadata>,
53 context_server_manager: Model<ContextServerManager>,
54 context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
55 context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
56 host_contexts: Vec<RemoteContextMetadata>,
57 fs: Arc<dyn Fs>,
58 languages: Arc<LanguageRegistry>,
59 slash_commands: Arc<SlashCommandWorkingSet>,
60 tools: Arc<ToolWorkingSet>,
61 telemetry: Arc<Telemetry>,
62 _watch_updates: Task<Option<()>>,
63 client: Arc<Client>,
64 project: Model<Project>,
65 project_is_shared: bool,
66 client_subscription: Option<client::Subscription>,
67 _project_subscriptions: Vec<gpui::Subscription>,
68 prompt_builder: Arc<PromptBuilder>,
69}
70
71pub enum ContextStoreEvent {
72 ContextCreated(ContextId),
73}
74
75impl EventEmitter<ContextStoreEvent> for ContextStore {}
76
77enum ContextHandle {
78 Weak(WeakModel<Context>),
79 Strong(Model<Context>),
80}
81
82impl ContextHandle {
83 fn upgrade(&self) -> Option<Model<Context>> {
84 match self {
85 ContextHandle::Weak(weak) => weak.upgrade(),
86 ContextHandle::Strong(strong) => Some(strong.clone()),
87 }
88 }
89
90 fn downgrade(&self) -> WeakModel<Context> {
91 match self {
92 ContextHandle::Weak(weak) => weak.clone(),
93 ContextHandle::Strong(strong) => strong.downgrade(),
94 }
95 }
96}
97
98impl ContextStore {
99 pub fn new(
100 project: Model<Project>,
101 prompt_builder: Arc<PromptBuilder>,
102 slash_commands: Arc<SlashCommandWorkingSet>,
103 tools: Arc<ToolWorkingSet>,
104 cx: &mut AppContext,
105 ) -> Task<Result<Model<Self>>> {
106 let fs = project.read(cx).fs().clone();
107 let languages = project.read(cx).languages().clone();
108 let telemetry = project.read(cx).client().telemetry().clone();
109 cx.spawn(|mut cx| async move {
110 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
111 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
112
113 let this = cx.new_model(|cx: &mut ModelContext<Self>| {
114 let context_server_factory_registry =
115 ContextServerFactoryRegistry::default_global(cx);
116 let context_server_manager = cx.new_model(|cx| {
117 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
118 });
119 let mut this = Self {
120 contexts: Vec::new(),
121 contexts_metadata: Vec::new(),
122 context_server_manager,
123 context_server_slash_command_ids: HashMap::default(),
124 context_server_tool_ids: HashMap::default(),
125 host_contexts: Vec::new(),
126 fs,
127 languages,
128 slash_commands,
129 tools,
130 telemetry,
131 _watch_updates: cx.spawn(|this, mut cx| {
132 async move {
133 while events.next().await.is_some() {
134 this.update(&mut cx, |this, cx| this.reload(cx))?
135 .await
136 .log_err();
137 }
138 anyhow::Ok(())
139 }
140 .log_err()
141 }),
142 client_subscription: None,
143 _project_subscriptions: vec![
144 cx.observe(&project, Self::handle_project_changed),
145 cx.subscribe(&project, Self::handle_project_event),
146 ],
147 project_is_shared: false,
148 client: project.read(cx).client(),
149 project: project.clone(),
150 prompt_builder,
151 };
152 this.handle_project_changed(project.clone(), cx);
153 this.synchronize_contexts(cx);
154 this.register_context_server_handlers(cx);
155 this
156 })?;
157 this.update(&mut cx, |this, cx| this.reload(cx))?
158 .await
159 .log_err();
160
161 Ok(this)
162 })
163 }
164
165 async fn handle_advertise_contexts(
166 this: Model<Self>,
167 envelope: TypedEnvelope<proto::AdvertiseContexts>,
168 mut cx: AsyncAppContext,
169 ) -> Result<()> {
170 this.update(&mut cx, |this, cx| {
171 this.host_contexts = envelope
172 .payload
173 .contexts
174 .into_iter()
175 .map(|context| RemoteContextMetadata {
176 id: ContextId::from_proto(context.context_id),
177 summary: context.summary,
178 })
179 .collect();
180 cx.notify();
181 })
182 }
183
184 async fn handle_open_context(
185 this: Model<Self>,
186 envelope: TypedEnvelope<proto::OpenContext>,
187 mut cx: AsyncAppContext,
188 ) -> Result<proto::OpenContextResponse> {
189 let context_id = ContextId::from_proto(envelope.payload.context_id);
190 let operations = this.update(&mut cx, |this, cx| {
191 if this.project.read(cx).is_via_collab() {
192 return Err(anyhow!("only the host contexts can be opened"));
193 }
194
195 let context = this
196 .loaded_context_for_id(&context_id, cx)
197 .context("context not found")?;
198 if context.read(cx).replica_id() != ReplicaId::default() {
199 return Err(anyhow!("context must be opened via the host"));
200 }
201
202 anyhow::Ok(
203 context
204 .read(cx)
205 .serialize_ops(&ContextVersion::default(), cx),
206 )
207 })??;
208 let operations = operations.await;
209 Ok(proto::OpenContextResponse {
210 context: Some(proto::Context { operations }),
211 })
212 }
213
214 async fn handle_create_context(
215 this: Model<Self>,
216 _: TypedEnvelope<proto::CreateContext>,
217 mut cx: AsyncAppContext,
218 ) -> Result<proto::CreateContextResponse> {
219 let (context_id, operations) = this.update(&mut cx, |this, cx| {
220 if this.project.read(cx).is_via_collab() {
221 return Err(anyhow!("can only create contexts as the host"));
222 }
223
224 let context = this.create(cx);
225 let context_id = context.read(cx).id().clone();
226 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
227
228 anyhow::Ok((
229 context_id,
230 context
231 .read(cx)
232 .serialize_ops(&ContextVersion::default(), cx),
233 ))
234 })??;
235 let operations = operations.await;
236 Ok(proto::CreateContextResponse {
237 context_id: context_id.to_proto(),
238 context: Some(proto::Context { operations }),
239 })
240 }
241
242 async fn handle_update_context(
243 this: Model<Self>,
244 envelope: TypedEnvelope<proto::UpdateContext>,
245 mut cx: AsyncAppContext,
246 ) -> Result<()> {
247 this.update(&mut cx, |this, cx| {
248 let context_id = ContextId::from_proto(envelope.payload.context_id);
249 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
250 let operation_proto = envelope.payload.operation.context("invalid operation")?;
251 let operation = ContextOperation::from_proto(operation_proto)?;
252 context.update(cx, |context, cx| context.apply_ops([operation], cx));
253 }
254 Ok(())
255 })?
256 }
257
258 async fn handle_synchronize_contexts(
259 this: Model<Self>,
260 envelope: TypedEnvelope<proto::SynchronizeContexts>,
261 mut cx: AsyncAppContext,
262 ) -> Result<proto::SynchronizeContextsResponse> {
263 this.update(&mut cx, |this, cx| {
264 if this.project.read(cx).is_via_collab() {
265 return Err(anyhow!("only the host can synchronize contexts"));
266 }
267
268 let mut local_versions = Vec::new();
269 for remote_version_proto in envelope.payload.contexts {
270 let remote_version = ContextVersion::from_proto(&remote_version_proto);
271 let context_id = ContextId::from_proto(remote_version_proto.context_id);
272 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
273 let context = context.read(cx);
274 let operations = context.serialize_ops(&remote_version, cx);
275 local_versions.push(context.version(cx).to_proto(context_id.clone()));
276 let client = this.client.clone();
277 let project_id = envelope.payload.project_id;
278 cx.background_executor()
279 .spawn(async move {
280 let operations = operations.await;
281 for operation in operations {
282 client.send(proto::UpdateContext {
283 project_id,
284 context_id: context_id.to_proto(),
285 operation: Some(operation),
286 })?;
287 }
288 anyhow::Ok(())
289 })
290 .detach_and_log_err(cx);
291 }
292 }
293
294 this.advertise_contexts(cx);
295
296 anyhow::Ok(proto::SynchronizeContextsResponse {
297 contexts: local_versions,
298 })
299 })?
300 }
301
302 fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
303 let is_shared = self.project.read(cx).is_shared();
304 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
305 if is_shared == was_shared {
306 return;
307 }
308
309 if is_shared {
310 self.contexts.retain_mut(|context| {
311 if let Some(strong_context) = context.upgrade() {
312 *context = ContextHandle::Strong(strong_context);
313 true
314 } else {
315 false
316 }
317 });
318 let remote_id = self.project.read(cx).remote_id().unwrap();
319 self.client_subscription = self
320 .client
321 .subscribe_to_entity(remote_id)
322 .log_err()
323 .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
324 self.advertise_contexts(cx);
325 } else {
326 self.client_subscription = None;
327 }
328 }
329
330 fn handle_project_event(
331 &mut self,
332 _: Model<Project>,
333 event: &project::Event,
334 cx: &mut ModelContext<Self>,
335 ) {
336 match event {
337 project::Event::Reshared => {
338 self.advertise_contexts(cx);
339 }
340 project::Event::HostReshared | project::Event::Rejoined => {
341 self.synchronize_contexts(cx);
342 }
343 project::Event::DisconnectedFromHost => {
344 self.contexts.retain_mut(|context| {
345 if let Some(strong_context) = context.upgrade() {
346 *context = ContextHandle::Weak(context.downgrade());
347 strong_context.update(cx, |context, cx| {
348 if context.replica_id() != ReplicaId::default() {
349 context.set_capability(language::Capability::ReadOnly, cx);
350 }
351 });
352 true
353 } else {
354 false
355 }
356 });
357 self.host_contexts.clear();
358 cx.notify();
359 }
360 _ => {}
361 }
362 }
363
364 pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
365 let context = cx.new_model(|cx| {
366 Context::local(
367 self.languages.clone(),
368 Some(self.project.clone()),
369 Some(self.telemetry.clone()),
370 self.prompt_builder.clone(),
371 self.slash_commands.clone(),
372 self.tools.clone(),
373 cx,
374 )
375 });
376 self.register_context(&context, cx);
377 context
378 }
379
380 pub fn create_remote_context(
381 &mut self,
382 cx: &mut ModelContext<Self>,
383 ) -> Task<Result<Model<Context>>> {
384 let project = self.project.read(cx);
385 let Some(project_id) = project.remote_id() else {
386 return Task::ready(Err(anyhow!("project was not remote")));
387 };
388
389 let replica_id = project.replica_id();
390 let capability = project.capability();
391 let language_registry = self.languages.clone();
392 let project = self.project.clone();
393 let telemetry = self.telemetry.clone();
394 let prompt_builder = self.prompt_builder.clone();
395 let slash_commands = self.slash_commands.clone();
396 let tools = self.tools.clone();
397 let request = self.client.request(proto::CreateContext { project_id });
398 cx.spawn(|this, mut cx| async move {
399 let response = request.await?;
400 let context_id = ContextId::from_proto(response.context_id);
401 let context_proto = response.context.context("invalid context")?;
402 let context = cx.new_model(|cx| {
403 Context::new(
404 context_id.clone(),
405 replica_id,
406 capability,
407 language_registry,
408 prompt_builder,
409 slash_commands,
410 tools,
411 Some(project),
412 Some(telemetry),
413 cx,
414 )
415 })?;
416 let operations = cx
417 .background_executor()
418 .spawn(async move {
419 context_proto
420 .operations
421 .into_iter()
422 .map(ContextOperation::from_proto)
423 .collect::<Result<Vec<_>>>()
424 })
425 .await?;
426 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
427 this.update(&mut cx, |this, cx| {
428 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
429 existing_context
430 } else {
431 this.register_context(&context, cx);
432 this.synchronize_contexts(cx);
433 context
434 }
435 })
436 })
437 }
438
439 pub fn open_local_context(
440 &mut self,
441 path: PathBuf,
442 cx: &ModelContext<Self>,
443 ) -> Task<Result<Model<Context>>> {
444 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
445 return Task::ready(Ok(existing_context));
446 }
447
448 let fs = self.fs.clone();
449 let languages = self.languages.clone();
450 let project = self.project.clone();
451 let telemetry = self.telemetry.clone();
452 let load = cx.background_executor().spawn({
453 let path = path.clone();
454 async move {
455 let saved_context = fs.load(&path).await?;
456 SavedContext::from_json(&saved_context)
457 }
458 });
459 let prompt_builder = self.prompt_builder.clone();
460 let slash_commands = self.slash_commands.clone();
461 let tools = self.tools.clone();
462
463 cx.spawn(|this, mut cx| async move {
464 let saved_context = load.await?;
465 let context = cx.new_model(|cx| {
466 Context::deserialize(
467 saved_context,
468 path.clone(),
469 languages,
470 prompt_builder,
471 slash_commands,
472 tools,
473 Some(project),
474 Some(telemetry),
475 cx,
476 )
477 })?;
478 this.update(&mut cx, |this, cx| {
479 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
480 existing_context
481 } else {
482 this.register_context(&context, cx);
483 context
484 }
485 })
486 })
487 }
488
489 fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
490 self.contexts.iter().find_map(|context| {
491 let context = context.upgrade()?;
492 if context.read(cx).path() == Some(path) {
493 Some(context)
494 } else {
495 None
496 }
497 })
498 }
499
500 pub(super) fn loaded_context_for_id(
501 &self,
502 id: &ContextId,
503 cx: &AppContext,
504 ) -> Option<Model<Context>> {
505 self.contexts.iter().find_map(|context| {
506 let context = context.upgrade()?;
507 if context.read(cx).id() == id {
508 Some(context)
509 } else {
510 None
511 }
512 })
513 }
514
515 pub fn open_remote_context(
516 &mut self,
517 context_id: ContextId,
518 cx: &mut ModelContext<Self>,
519 ) -> Task<Result<Model<Context>>> {
520 let project = self.project.read(cx);
521 let Some(project_id) = project.remote_id() else {
522 return Task::ready(Err(anyhow!("project was not remote")));
523 };
524
525 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
526 return Task::ready(Ok(context));
527 }
528
529 let replica_id = project.replica_id();
530 let capability = project.capability();
531 let language_registry = self.languages.clone();
532 let project = self.project.clone();
533 let telemetry = self.telemetry.clone();
534 let request = self.client.request(proto::OpenContext {
535 project_id,
536 context_id: context_id.to_proto(),
537 });
538 let prompt_builder = self.prompt_builder.clone();
539 let slash_commands = self.slash_commands.clone();
540 let tools = self.tools.clone();
541 cx.spawn(|this, mut cx| async move {
542 let response = request.await?;
543 let context_proto = response.context.context("invalid context")?;
544 let context = cx.new_model(|cx| {
545 Context::new(
546 context_id.clone(),
547 replica_id,
548 capability,
549 language_registry,
550 prompt_builder,
551 slash_commands,
552 tools,
553 Some(project),
554 Some(telemetry),
555 cx,
556 )
557 })?;
558 let operations = cx
559 .background_executor()
560 .spawn(async move {
561 context_proto
562 .operations
563 .into_iter()
564 .map(ContextOperation::from_proto)
565 .collect::<Result<Vec<_>>>()
566 })
567 .await?;
568 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
569 this.update(&mut cx, |this, cx| {
570 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
571 existing_context
572 } else {
573 this.register_context(&context, cx);
574 this.synchronize_contexts(cx);
575 context
576 }
577 })
578 })
579 }
580
581 fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
582 let handle = if self.project_is_shared {
583 ContextHandle::Strong(context.clone())
584 } else {
585 ContextHandle::Weak(context.downgrade())
586 };
587 self.contexts.push(handle);
588 self.advertise_contexts(cx);
589 cx.subscribe(context, Self::handle_context_event).detach();
590 }
591
592 fn handle_context_event(
593 &mut self,
594 context: Model<Context>,
595 event: &ContextEvent,
596 cx: &mut ModelContext<Self>,
597 ) {
598 let Some(project_id) = self.project.read(cx).remote_id() else {
599 return;
600 };
601
602 match event {
603 ContextEvent::SummaryChanged => {
604 self.advertise_contexts(cx);
605 }
606 ContextEvent::Operation(operation) => {
607 let context_id = context.read(cx).id().to_proto();
608 let operation = operation.to_proto();
609 self.client
610 .send(proto::UpdateContext {
611 project_id,
612 context_id,
613 operation: Some(operation),
614 })
615 .log_err();
616 }
617 _ => {}
618 }
619 }
620
621 fn advertise_contexts(&self, cx: &AppContext) {
622 let Some(project_id) = self.project.read(cx).remote_id() else {
623 return;
624 };
625
626 // For now, only the host can advertise their open contexts.
627 if self.project.read(cx).is_via_collab() {
628 return;
629 }
630
631 let contexts = self
632 .contexts
633 .iter()
634 .rev()
635 .filter_map(|context| {
636 let context = context.upgrade()?.read(cx);
637 if context.replica_id() == ReplicaId::default() {
638 Some(proto::ContextMetadata {
639 context_id: context.id().to_proto(),
640 summary: context.summary().map(|summary| summary.text.clone()),
641 })
642 } else {
643 None
644 }
645 })
646 .collect();
647 self.client
648 .send(proto::AdvertiseContexts {
649 project_id,
650 contexts,
651 })
652 .ok();
653 }
654
655 fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
656 let Some(project_id) = self.project.read(cx).remote_id() else {
657 return;
658 };
659
660 let contexts = self
661 .contexts
662 .iter()
663 .filter_map(|context| {
664 let context = context.upgrade()?.read(cx);
665 if context.replica_id() != ReplicaId::default() {
666 Some(context.version(cx).to_proto(context.id().clone()))
667 } else {
668 None
669 }
670 })
671 .collect();
672
673 let client = self.client.clone();
674 let request = self.client.request(proto::SynchronizeContexts {
675 project_id,
676 contexts,
677 });
678 cx.spawn(|this, cx| async move {
679 let response = request.await?;
680
681 let mut context_ids = Vec::new();
682 let mut operations = Vec::new();
683 this.read_with(&cx, |this, cx| {
684 for context_version_proto in response.contexts {
685 let context_version = ContextVersion::from_proto(&context_version_proto);
686 let context_id = ContextId::from_proto(context_version_proto.context_id);
687 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
688 context_ids.push(context_id);
689 operations.push(context.read(cx).serialize_ops(&context_version, cx));
690 }
691 }
692 })?;
693
694 let operations = futures::future::join_all(operations).await;
695 for (context_id, operations) in context_ids.into_iter().zip(operations) {
696 for operation in operations {
697 client.send(proto::UpdateContext {
698 project_id,
699 context_id: context_id.to_proto(),
700 operation: Some(operation),
701 })?;
702 }
703 }
704
705 anyhow::Ok(())
706 })
707 .detach_and_log_err(cx);
708 }
709
710 pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
711 let metadata = self.contexts_metadata.clone();
712 let executor = cx.background_executor().clone();
713 cx.background_executor().spawn(async move {
714 if query.is_empty() {
715 metadata
716 } else {
717 let candidates = metadata
718 .iter()
719 .enumerate()
720 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
721 .collect::<Vec<_>>();
722 let matches = fuzzy::match_strings(
723 &candidates,
724 &query,
725 false,
726 100,
727 &Default::default(),
728 executor,
729 )
730 .await;
731
732 matches
733 .into_iter()
734 .map(|mat| metadata[mat.candidate_id].clone())
735 .collect()
736 }
737 })
738 }
739
740 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
741 &self.host_contexts
742 }
743
744 fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
745 let fs = self.fs.clone();
746 cx.spawn(|this, mut cx| async move {
747 fs.create_dir(contexts_dir()).await?;
748
749 let mut paths = fs.read_dir(contexts_dir()).await?;
750 let mut contexts = Vec::<SavedContextMetadata>::new();
751 while let Some(path) = paths.next().await {
752 let path = path?;
753 if path.extension() != Some(OsStr::new("json")) {
754 continue;
755 }
756
757 static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
758 LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
759
760 let metadata = fs.metadata(&path).await?;
761 if let Some((file_name, metadata)) = path
762 .file_name()
763 .and_then(|name| name.to_str())
764 .zip(metadata)
765 {
766 // This is used to filter out contexts saved by the new assistant.
767 if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
768 continue;
769 }
770
771 if let Some(title) = ASSISTANT_CONTEXT_REGEX
772 .replace(file_name, "")
773 .lines()
774 .next()
775 {
776 contexts.push(SavedContextMetadata {
777 title: title.to_string(),
778 path,
779 mtime: metadata.mtime.timestamp_for_user().into(),
780 });
781 }
782 }
783 }
784 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
785
786 this.update(&mut cx, |this, cx| {
787 this.contexts_metadata = contexts;
788 cx.notify();
789 })
790 })
791 }
792
793 pub fn restart_context_servers(&mut self, cx: &mut ModelContext<Self>) {
794 cx.update_model(
795 &self.context_server_manager,
796 |context_server_manager, cx| {
797 for server in context_server_manager.servers() {
798 context_server_manager
799 .restart_server(&server.id(), cx)
800 .detach_and_log_err(cx);
801 }
802 },
803 );
804 }
805
806 fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
807 cx.subscribe(
808 &self.context_server_manager.clone(),
809 Self::handle_context_server_event,
810 )
811 .detach();
812 }
813
814 fn handle_context_server_event(
815 &mut self,
816 context_server_manager: Model<ContextServerManager>,
817 event: &context_server::manager::Event,
818 cx: &mut ModelContext<Self>,
819 ) {
820 let slash_command_working_set = self.slash_commands.clone();
821 let tool_working_set = self.tools.clone();
822 match event {
823 context_server::manager::Event::ServerStarted { server_id } => {
824 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
825 let context_server_manager = context_server_manager.clone();
826 cx.spawn({
827 let server = server.clone();
828 let server_id = server_id.clone();
829 |this, mut cx| async move {
830 let Some(protocol) = server.client() else {
831 return;
832 };
833
834 if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
835 if let Some(prompts) = protocol.list_prompts().await.log_err() {
836 let slash_command_ids = prompts
837 .into_iter()
838 .filter(context_server_command::acceptable_prompt)
839 .map(|prompt| {
840 log::info!(
841 "registering context server command: {:?}",
842 prompt.name
843 );
844 slash_command_working_set.insert(Arc::new(
845 context_server_command::ContextServerSlashCommand::new(
846 context_server_manager.clone(),
847 &server,
848 prompt,
849 ),
850 ))
851 })
852 .collect::<Vec<_>>();
853
854 this.update(&mut cx, |this, _cx| {
855 this.context_server_slash_command_ids
856 .insert(server_id.clone(), slash_command_ids);
857 })
858 .log_err();
859 }
860 }
861
862 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
863 if let Some(tools) = protocol.list_tools().await.log_err() {
864 let tool_ids = tools.tools.into_iter().map(|tool| {
865 log::info!("registering context server tool: {:?}", tool.name);
866 tool_working_set.insert(
867 Arc::new(ContextServerTool::new(
868 context_server_manager.clone(),
869 server.id(),
870 tool,
871 )),
872 )
873
874 }).collect::<Vec<_>>();
875
876
877 this.update(&mut cx, |this, _cx| {
878 this.context_server_tool_ids
879 .insert(server_id, tool_ids);
880 })
881 .log_err();
882 }
883 }
884 }
885 })
886 .detach();
887 }
888 }
889 context_server::manager::Event::ServerStopped { server_id } => {
890 if let Some(slash_command_ids) =
891 self.context_server_slash_command_ids.remove(server_id)
892 {
893 slash_command_working_set.remove(&slash_command_ids);
894 }
895
896 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
897 tool_working_set.remove(&tool_ids);
898 }
899 }
900 }
901 }
902}