1use crate::slash_command::context_server_command;
2use crate::SlashCommandId;
3use crate::{
4 prompts::PromptBuilder, slash_command_working_set::SlashCommandWorkingSet, Context,
5 ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, SavedContextMetadata,
6};
7use anyhow::{anyhow, Context as _, Result};
8use assistant_tool::{ToolId, ToolWorkingSet};
9use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
10use clock::ReplicaId;
11use collections::HashMap;
12use context_server::manager::ContextServerManager;
13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
14use fs::Fs;
15use futures::StreamExt;
16use fuzzy::StringMatchCandidate;
17use gpui::{
18 AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
19};
20use language::LanguageRegistry;
21use paths::contexts_dir;
22use project::Project;
23use regex::Regex;
24use rpc::AnyProtoClient;
25use std::{
26 cmp::Reverse,
27 ffi::OsStr,
28 mem,
29 path::{Path, PathBuf},
30 sync::Arc,
31 time::Duration,
32};
33use util::{ResultExt, TryFutureExt};
34
35pub fn init(client: &AnyProtoClient) {
36 client.add_model_message_handler(ContextStore::handle_advertise_contexts);
37 client.add_model_request_handler(ContextStore::handle_open_context);
38 client.add_model_request_handler(ContextStore::handle_create_context);
39 client.add_model_message_handler(ContextStore::handle_update_context);
40 client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
41}
42
/// Metadata for a context that the project host has advertised to this client.
#[derive(Clone)]
pub struct RemoteContextMetadata {
    /// Identifier of the advertised context.
    pub id: ContextId,
    /// The context's summary text, if the host sent one.
    pub summary: Option<String>,
}
48
/// Tracks open assistant contexts for a project.
///
/// It also tracks saved-context metadata on disk, contexts advertised by a collaboration host,
/// and the slash commands and tools contributed by context servers.
pub struct ContextStore {
    /// Contexts registered with this store. Handles are strong while the project is shared and
    /// weak otherwise.
    contexts: Vec<ContextHandle>,
    /// Metadata of saved contexts found in the contexts directory, newest first.
    contexts_metadata: Vec<SavedContextMetadata>,
    context_server_manager: Model<ContextServerManager>,
    /// Slash commands registered for each context server, keyed by server id, so they can be
    /// removed when the server stops.
    context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
    /// Tools registered for each context server, keyed by server id, so they can be removed
    /// when the server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Contexts most recently advertised by the host; cleared on disconnect.
    host_contexts: Vec<RemoteContextMetadata>,
    fs: Arc<dyn Fs>,
    languages: Arc<LanguageRegistry>,
    slash_commands: Arc<SlashCommandWorkingSet>,
    tools: Arc<ToolWorkingSet>,
    telemetry: Arc<Telemetry>,
    /// Task that reloads `contexts_metadata` whenever the contexts directory changes.
    _watch_updates: Task<Option<()>>,
    client: Arc<Client>,
    project: Model<Project>,
    /// The project's shared state as last observed by `handle_project_changed`.
    project_is_shared: bool,
    /// Subscription to the project's remote entity; only set while the project is shared.
    client_subscription: Option<client::Subscription>,
    /// Observation of, and event subscription to, `project`.
    _project_subscriptions: Vec<gpui::Subscription>,
    prompt_builder: Arc<PromptBuilder>,
}
69
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// A context with the given id was created in response to a remote `CreateContext`
    /// request (emitted by `handle_create_context`).
    ContextCreated(ContextId),
}

impl EventEmitter<ContextStoreEvent> for ContextStore {}
75
/// How the store holds a registered context.
///
/// A `Strong` handle keeps the context model alive. A `Weak` handle lets it be released once
/// nothing else references it.
enum ContextHandle {
    Weak(WeakModel<Context>),
    Strong(Model<Context>),
}
80
81impl ContextHandle {
82 fn upgrade(&self) -> Option<Model<Context>> {
83 match self {
84 ContextHandle::Weak(weak) => weak.upgrade(),
85 ContextHandle::Strong(strong) => Some(strong.clone()),
86 }
87 }
88
89 fn downgrade(&self) -> WeakModel<Context> {
90 match self {
91 ContextHandle::Weak(weak) => weak.clone(),
92 ContextHandle::Strong(strong) => strong.downgrade(),
93 }
94 }
95}
96
97impl ContextStore {
98 pub fn new(
99 project: Model<Project>,
100 prompt_builder: Arc<PromptBuilder>,
101 slash_commands: Arc<SlashCommandWorkingSet>,
102 tools: Arc<ToolWorkingSet>,
103 cx: &mut AppContext,
104 ) -> Task<Result<Model<Self>>> {
105 let fs = project.read(cx).fs().clone();
106 let languages = project.read(cx).languages().clone();
107 let telemetry = project.read(cx).client().telemetry().clone();
108 cx.spawn(|mut cx| async move {
109 const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
110 let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
111
112 let this = cx.new_model(|cx: &mut ModelContext<Self>| {
113 let context_server_factory_registry =
114 ContextServerFactoryRegistry::default_global(cx);
115 let context_server_manager = cx.new_model(|cx| {
116 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
117 });
118 let mut this = Self {
119 contexts: Vec::new(),
120 contexts_metadata: Vec::new(),
121 context_server_manager,
122 context_server_slash_command_ids: HashMap::default(),
123 context_server_tool_ids: HashMap::default(),
124 host_contexts: Vec::new(),
125 fs,
126 languages,
127 slash_commands,
128 tools,
129 telemetry,
130 _watch_updates: cx.spawn(|this, mut cx| {
131 async move {
132 while events.next().await.is_some() {
133 this.update(&mut cx, |this, cx| this.reload(cx))?
134 .await
135 .log_err();
136 }
137 anyhow::Ok(())
138 }
139 .log_err()
140 }),
141 client_subscription: None,
142 _project_subscriptions: vec![
143 cx.observe(&project, Self::handle_project_changed),
144 cx.subscribe(&project, Self::handle_project_event),
145 ],
146 project_is_shared: false,
147 client: project.read(cx).client(),
148 project: project.clone(),
149 prompt_builder,
150 };
151 this.handle_project_changed(project.clone(), cx);
152 this.synchronize_contexts(cx);
153 this.register_context_server_handlers(cx);
154 this
155 })?;
156 this.update(&mut cx, |this, cx| this.reload(cx))?
157 .await
158 .log_err();
159
160 Ok(this)
161 })
162 }
163
164 async fn handle_advertise_contexts(
165 this: Model<Self>,
166 envelope: TypedEnvelope<proto::AdvertiseContexts>,
167 mut cx: AsyncAppContext,
168 ) -> Result<()> {
169 this.update(&mut cx, |this, cx| {
170 this.host_contexts = envelope
171 .payload
172 .contexts
173 .into_iter()
174 .map(|context| RemoteContextMetadata {
175 id: ContextId::from_proto(context.context_id),
176 summary: context.summary,
177 })
178 .collect();
179 cx.notify();
180 })
181 }
182
183 async fn handle_open_context(
184 this: Model<Self>,
185 envelope: TypedEnvelope<proto::OpenContext>,
186 mut cx: AsyncAppContext,
187 ) -> Result<proto::OpenContextResponse> {
188 let context_id = ContextId::from_proto(envelope.payload.context_id);
189 let operations = this.update(&mut cx, |this, cx| {
190 if this.project.read(cx).is_via_collab() {
191 return Err(anyhow!("only the host contexts can be opened"));
192 }
193
194 let context = this
195 .loaded_context_for_id(&context_id, cx)
196 .context("context not found")?;
197 if context.read(cx).replica_id() != ReplicaId::default() {
198 return Err(anyhow!("context must be opened via the host"));
199 }
200
201 anyhow::Ok(
202 context
203 .read(cx)
204 .serialize_ops(&ContextVersion::default(), cx),
205 )
206 })??;
207 let operations = operations.await;
208 Ok(proto::OpenContextResponse {
209 context: Some(proto::Context { operations }),
210 })
211 }
212
213 async fn handle_create_context(
214 this: Model<Self>,
215 _: TypedEnvelope<proto::CreateContext>,
216 mut cx: AsyncAppContext,
217 ) -> Result<proto::CreateContextResponse> {
218 let (context_id, operations) = this.update(&mut cx, |this, cx| {
219 if this.project.read(cx).is_via_collab() {
220 return Err(anyhow!("can only create contexts as the host"));
221 }
222
223 let context = this.create(cx);
224 let context_id = context.read(cx).id().clone();
225 cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
226
227 anyhow::Ok((
228 context_id,
229 context
230 .read(cx)
231 .serialize_ops(&ContextVersion::default(), cx),
232 ))
233 })??;
234 let operations = operations.await;
235 Ok(proto::CreateContextResponse {
236 context_id: context_id.to_proto(),
237 context: Some(proto::Context { operations }),
238 })
239 }
240
241 async fn handle_update_context(
242 this: Model<Self>,
243 envelope: TypedEnvelope<proto::UpdateContext>,
244 mut cx: AsyncAppContext,
245 ) -> Result<()> {
246 this.update(&mut cx, |this, cx| {
247 let context_id = ContextId::from_proto(envelope.payload.context_id);
248 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
249 let operation_proto = envelope.payload.operation.context("invalid operation")?;
250 let operation = ContextOperation::from_proto(operation_proto)?;
251 context.update(cx, |context, cx| context.apply_ops([operation], cx));
252 }
253 Ok(())
254 })?
255 }
256
257 async fn handle_synchronize_contexts(
258 this: Model<Self>,
259 envelope: TypedEnvelope<proto::SynchronizeContexts>,
260 mut cx: AsyncAppContext,
261 ) -> Result<proto::SynchronizeContextsResponse> {
262 this.update(&mut cx, |this, cx| {
263 if this.project.read(cx).is_via_collab() {
264 return Err(anyhow!("only the host can synchronize contexts"));
265 }
266
267 let mut local_versions = Vec::new();
268 for remote_version_proto in envelope.payload.contexts {
269 let remote_version = ContextVersion::from_proto(&remote_version_proto);
270 let context_id = ContextId::from_proto(remote_version_proto.context_id);
271 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
272 let context = context.read(cx);
273 let operations = context.serialize_ops(&remote_version, cx);
274 local_versions.push(context.version(cx).to_proto(context_id.clone()));
275 let client = this.client.clone();
276 let project_id = envelope.payload.project_id;
277 cx.background_executor()
278 .spawn(async move {
279 let operations = operations.await;
280 for operation in operations {
281 client.send(proto::UpdateContext {
282 project_id,
283 context_id: context_id.to_proto(),
284 operation: Some(operation),
285 })?;
286 }
287 anyhow::Ok(())
288 })
289 .detach_and_log_err(cx);
290 }
291 }
292
293 this.advertise_contexts(cx);
294
295 anyhow::Ok(proto::SynchronizeContextsResponse {
296 contexts: local_versions,
297 })
298 })?
299 }
300
301 fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
302 let is_shared = self.project.read(cx).is_shared();
303 let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
304 if is_shared == was_shared {
305 return;
306 }
307
308 if is_shared {
309 self.contexts.retain_mut(|context| {
310 if let Some(strong_context) = context.upgrade() {
311 *context = ContextHandle::Strong(strong_context);
312 true
313 } else {
314 false
315 }
316 });
317 let remote_id = self.project.read(cx).remote_id().unwrap();
318 self.client_subscription = self
319 .client
320 .subscribe_to_entity(remote_id)
321 .log_err()
322 .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
323 self.advertise_contexts(cx);
324 } else {
325 self.client_subscription = None;
326 }
327 }
328
329 fn handle_project_event(
330 &mut self,
331 _: Model<Project>,
332 event: &project::Event,
333 cx: &mut ModelContext<Self>,
334 ) {
335 match event {
336 project::Event::Reshared => {
337 self.advertise_contexts(cx);
338 }
339 project::Event::HostReshared | project::Event::Rejoined => {
340 self.synchronize_contexts(cx);
341 }
342 project::Event::DisconnectedFromHost => {
343 self.contexts.retain_mut(|context| {
344 if let Some(strong_context) = context.upgrade() {
345 *context = ContextHandle::Weak(context.downgrade());
346 strong_context.update(cx, |context, cx| {
347 if context.replica_id() != ReplicaId::default() {
348 context.set_capability(language::Capability::ReadOnly, cx);
349 }
350 });
351 true
352 } else {
353 false
354 }
355 });
356 self.host_contexts.clear();
357 cx.notify();
358 }
359 _ => {}
360 }
361 }
362
363 pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
364 let context = cx.new_model(|cx| {
365 Context::local(
366 self.languages.clone(),
367 Some(self.project.clone()),
368 Some(self.telemetry.clone()),
369 self.prompt_builder.clone(),
370 self.slash_commands.clone(),
371 self.tools.clone(),
372 cx,
373 )
374 });
375 self.register_context(&context, cx);
376 context
377 }
378
379 pub fn create_remote_context(
380 &mut self,
381 cx: &mut ModelContext<Self>,
382 ) -> Task<Result<Model<Context>>> {
383 let project = self.project.read(cx);
384 let Some(project_id) = project.remote_id() else {
385 return Task::ready(Err(anyhow!("project was not remote")));
386 };
387
388 let replica_id = project.replica_id();
389 let capability = project.capability();
390 let language_registry = self.languages.clone();
391 let project = self.project.clone();
392 let telemetry = self.telemetry.clone();
393 let prompt_builder = self.prompt_builder.clone();
394 let slash_commands = self.slash_commands.clone();
395 let tools = self.tools.clone();
396 let request = self.client.request(proto::CreateContext { project_id });
397 cx.spawn(|this, mut cx| async move {
398 let response = request.await?;
399 let context_id = ContextId::from_proto(response.context_id);
400 let context_proto = response.context.context("invalid context")?;
401 let context = cx.new_model(|cx| {
402 Context::new(
403 context_id.clone(),
404 replica_id,
405 capability,
406 language_registry,
407 prompt_builder,
408 slash_commands,
409 tools,
410 Some(project),
411 Some(telemetry),
412 cx,
413 )
414 })?;
415 let operations = cx
416 .background_executor()
417 .spawn(async move {
418 context_proto
419 .operations
420 .into_iter()
421 .map(ContextOperation::from_proto)
422 .collect::<Result<Vec<_>>>()
423 })
424 .await?;
425 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
426 this.update(&mut cx, |this, cx| {
427 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
428 existing_context
429 } else {
430 this.register_context(&context, cx);
431 this.synchronize_contexts(cx);
432 context
433 }
434 })
435 })
436 }
437
438 pub fn open_local_context(
439 &mut self,
440 path: PathBuf,
441 cx: &ModelContext<Self>,
442 ) -> Task<Result<Model<Context>>> {
443 if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
444 return Task::ready(Ok(existing_context));
445 }
446
447 let fs = self.fs.clone();
448 let languages = self.languages.clone();
449 let project = self.project.clone();
450 let telemetry = self.telemetry.clone();
451 let load = cx.background_executor().spawn({
452 let path = path.clone();
453 async move {
454 let saved_context = fs.load(&path).await?;
455 SavedContext::from_json(&saved_context)
456 }
457 });
458 let prompt_builder = self.prompt_builder.clone();
459 let slash_commands = self.slash_commands.clone();
460 let tools = self.tools.clone();
461
462 cx.spawn(|this, mut cx| async move {
463 let saved_context = load.await?;
464 let context = cx.new_model(|cx| {
465 Context::deserialize(
466 saved_context,
467 path.clone(),
468 languages,
469 prompt_builder,
470 slash_commands,
471 tools,
472 Some(project),
473 Some(telemetry),
474 cx,
475 )
476 })?;
477 this.update(&mut cx, |this, cx| {
478 if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
479 existing_context
480 } else {
481 this.register_context(&context, cx);
482 context
483 }
484 })
485 })
486 }
487
488 fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
489 self.contexts.iter().find_map(|context| {
490 let context = context.upgrade()?;
491 if context.read(cx).path() == Some(path) {
492 Some(context)
493 } else {
494 None
495 }
496 })
497 }
498
499 pub(super) fn loaded_context_for_id(
500 &self,
501 id: &ContextId,
502 cx: &AppContext,
503 ) -> Option<Model<Context>> {
504 self.contexts.iter().find_map(|context| {
505 let context = context.upgrade()?;
506 if context.read(cx).id() == id {
507 Some(context)
508 } else {
509 None
510 }
511 })
512 }
513
514 pub fn open_remote_context(
515 &mut self,
516 context_id: ContextId,
517 cx: &mut ModelContext<Self>,
518 ) -> Task<Result<Model<Context>>> {
519 let project = self.project.read(cx);
520 let Some(project_id) = project.remote_id() else {
521 return Task::ready(Err(anyhow!("project was not remote")));
522 };
523
524 if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
525 return Task::ready(Ok(context));
526 }
527
528 let replica_id = project.replica_id();
529 let capability = project.capability();
530 let language_registry = self.languages.clone();
531 let project = self.project.clone();
532 let telemetry = self.telemetry.clone();
533 let request = self.client.request(proto::OpenContext {
534 project_id,
535 context_id: context_id.to_proto(),
536 });
537 let prompt_builder = self.prompt_builder.clone();
538 let slash_commands = self.slash_commands.clone();
539 let tools = self.tools.clone();
540 cx.spawn(|this, mut cx| async move {
541 let response = request.await?;
542 let context_proto = response.context.context("invalid context")?;
543 let context = cx.new_model(|cx| {
544 Context::new(
545 context_id.clone(),
546 replica_id,
547 capability,
548 language_registry,
549 prompt_builder,
550 slash_commands,
551 tools,
552 Some(project),
553 Some(telemetry),
554 cx,
555 )
556 })?;
557 let operations = cx
558 .background_executor()
559 .spawn(async move {
560 context_proto
561 .operations
562 .into_iter()
563 .map(ContextOperation::from_proto)
564 .collect::<Result<Vec<_>>>()
565 })
566 .await?;
567 context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
568 this.update(&mut cx, |this, cx| {
569 if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
570 existing_context
571 } else {
572 this.register_context(&context, cx);
573 this.synchronize_contexts(cx);
574 context
575 }
576 })
577 })
578 }
579
580 fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
581 let handle = if self.project_is_shared {
582 ContextHandle::Strong(context.clone())
583 } else {
584 ContextHandle::Weak(context.downgrade())
585 };
586 self.contexts.push(handle);
587 self.advertise_contexts(cx);
588 cx.subscribe(context, Self::handle_context_event).detach();
589 }
590
591 fn handle_context_event(
592 &mut self,
593 context: Model<Context>,
594 event: &ContextEvent,
595 cx: &mut ModelContext<Self>,
596 ) {
597 let Some(project_id) = self.project.read(cx).remote_id() else {
598 return;
599 };
600
601 match event {
602 ContextEvent::SummaryChanged => {
603 self.advertise_contexts(cx);
604 }
605 ContextEvent::Operation(operation) => {
606 let context_id = context.read(cx).id().to_proto();
607 let operation = operation.to_proto();
608 self.client
609 .send(proto::UpdateContext {
610 project_id,
611 context_id,
612 operation: Some(operation),
613 })
614 .log_err();
615 }
616 _ => {}
617 }
618 }
619
620 fn advertise_contexts(&self, cx: &AppContext) {
621 let Some(project_id) = self.project.read(cx).remote_id() else {
622 return;
623 };
624
625 // For now, only the host can advertise their open contexts.
626 if self.project.read(cx).is_via_collab() {
627 return;
628 }
629
630 let contexts = self
631 .contexts
632 .iter()
633 .rev()
634 .filter_map(|context| {
635 let context = context.upgrade()?.read(cx);
636 if context.replica_id() == ReplicaId::default() {
637 Some(proto::ContextMetadata {
638 context_id: context.id().to_proto(),
639 summary: context.summary().map(|summary| summary.text.clone()),
640 })
641 } else {
642 None
643 }
644 })
645 .collect();
646 self.client
647 .send(proto::AdvertiseContexts {
648 project_id,
649 contexts,
650 })
651 .ok();
652 }
653
654 fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
655 let Some(project_id) = self.project.read(cx).remote_id() else {
656 return;
657 };
658
659 let contexts = self
660 .contexts
661 .iter()
662 .filter_map(|context| {
663 let context = context.upgrade()?.read(cx);
664 if context.replica_id() != ReplicaId::default() {
665 Some(context.version(cx).to_proto(context.id().clone()))
666 } else {
667 None
668 }
669 })
670 .collect();
671
672 let client = self.client.clone();
673 let request = self.client.request(proto::SynchronizeContexts {
674 project_id,
675 contexts,
676 });
677 cx.spawn(|this, cx| async move {
678 let response = request.await?;
679
680 let mut context_ids = Vec::new();
681 let mut operations = Vec::new();
682 this.read_with(&cx, |this, cx| {
683 for context_version_proto in response.contexts {
684 let context_version = ContextVersion::from_proto(&context_version_proto);
685 let context_id = ContextId::from_proto(context_version_proto.context_id);
686 if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
687 context_ids.push(context_id);
688 operations.push(context.read(cx).serialize_ops(&context_version, cx));
689 }
690 }
691 })?;
692
693 let operations = futures::future::join_all(operations).await;
694 for (context_id, operations) in context_ids.into_iter().zip(operations) {
695 for operation in operations {
696 client.send(proto::UpdateContext {
697 project_id,
698 context_id: context_id.to_proto(),
699 operation: Some(operation),
700 })?;
701 }
702 }
703
704 anyhow::Ok(())
705 })
706 .detach_and_log_err(cx);
707 }
708
709 pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
710 let metadata = self.contexts_metadata.clone();
711 let executor = cx.background_executor().clone();
712 cx.background_executor().spawn(async move {
713 if query.is_empty() {
714 metadata
715 } else {
716 let candidates = metadata
717 .iter()
718 .enumerate()
719 .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
720 .collect::<Vec<_>>();
721 let matches = fuzzy::match_strings(
722 &candidates,
723 &query,
724 false,
725 100,
726 &Default::default(),
727 executor,
728 )
729 .await;
730
731 matches
732 .into_iter()
733 .map(|mat| metadata[mat.candidate_id].clone())
734 .collect()
735 }
736 })
737 }
738
739 pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
740 &self.host_contexts
741 }
742
743 fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
744 let fs = self.fs.clone();
745 cx.spawn(|this, mut cx| async move {
746 fs.create_dir(contexts_dir()).await?;
747
748 let mut paths = fs.read_dir(contexts_dir()).await?;
749 let mut contexts = Vec::<SavedContextMetadata>::new();
750 while let Some(path) = paths.next().await {
751 let path = path?;
752 if path.extension() != Some(OsStr::new("json")) {
753 continue;
754 }
755
756 let pattern = r" - \d+.zed.json$";
757 let re = Regex::new(pattern).unwrap();
758
759 let metadata = fs.metadata(&path).await?;
760 if let Some((file_name, metadata)) = path
761 .file_name()
762 .and_then(|name| name.to_str())
763 .zip(metadata)
764 {
765 // This is used to filter out contexts saved by the new assistant.
766 if !re.is_match(file_name) {
767 continue;
768 }
769
770 if let Some(title) = re.replace(file_name, "").lines().next() {
771 contexts.push(SavedContextMetadata {
772 title: title.to_string(),
773 path,
774 mtime: metadata.mtime.timestamp_for_user().into(),
775 });
776 }
777 }
778 }
779 contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
780
781 this.update(&mut cx, |this, cx| {
782 this.contexts_metadata = contexts;
783 cx.notify();
784 })
785 })
786 }
787
788 pub fn restart_context_servers(&mut self, cx: &mut ModelContext<Self>) {
789 cx.update_model(
790 &self.context_server_manager,
791 |context_server_manager, cx| {
792 for server in context_server_manager.servers() {
793 context_server_manager
794 .restart_server(&server.id(), cx)
795 .detach_and_log_err(cx);
796 }
797 },
798 );
799 }
800
801 fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
802 cx.subscribe(
803 &self.context_server_manager.clone(),
804 Self::handle_context_server_event,
805 )
806 .detach();
807 }
808
809 fn handle_context_server_event(
810 &mut self,
811 context_server_manager: Model<ContextServerManager>,
812 event: &context_server::manager::Event,
813 cx: &mut ModelContext<Self>,
814 ) {
815 let slash_command_working_set = self.slash_commands.clone();
816 let tool_working_set = self.tools.clone();
817 match event {
818 context_server::manager::Event::ServerStarted { server_id } => {
819 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
820 let context_server_manager = context_server_manager.clone();
821 cx.spawn({
822 let server = server.clone();
823 let server_id = server_id.clone();
824 |this, mut cx| async move {
825 let Some(protocol) = server.client() else {
826 return;
827 };
828
829 if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
830 if let Some(prompts) = protocol.list_prompts().await.log_err() {
831 let slash_command_ids = prompts
832 .into_iter()
833 .filter(context_server_command::acceptable_prompt)
834 .map(|prompt| {
835 log::info!(
836 "registering context server command: {:?}",
837 prompt.name
838 );
839 slash_command_working_set.insert(Arc::new(
840 context_server_command::ContextServerSlashCommand::new(
841 context_server_manager.clone(),
842 &server,
843 prompt,
844 ),
845 ))
846 })
847 .collect::<Vec<_>>();
848
849 this.update(&mut cx, |this, _cx| {
850 this.context_server_slash_command_ids
851 .insert(server_id.clone(), slash_command_ids);
852 })
853 .log_err();
854 }
855 }
856
857 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
858 if let Some(tools) = protocol.list_tools().await.log_err() {
859 let tool_ids = tools.tools.into_iter().map(|tool| {
860 log::info!("registering context server tool: {:?}", tool.name);
861 tool_working_set.insert(
862 Arc::new(ContextServerTool::new(
863 context_server_manager.clone(),
864 server.id(),
865 tool,
866 )),
867 )
868
869 }).collect::<Vec<_>>();
870
871
872 this.update(&mut cx, |this, _cx| {
873 this.context_server_tool_ids
874 .insert(server_id, tool_ids);
875 })
876 .log_err();
877 }
878 }
879 }
880 })
881 .detach();
882 }
883 }
884 context_server::manager::Event::ServerStopped { server_id } => {
885 if let Some(slash_command_ids) =
886 self.context_server_slash_command_ids.remove(server_id)
887 {
888 slash_command_working_set.remove(&slash_command_ids);
889 }
890
891 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
892 tool_working_set.remove(&tool_ids);
893 }
894 }
895 }
896 }
897}