1use anyhow::{Context, Result};
2use collections::HashMap;
3use futures::{
4 Future, FutureExt as _,
5 channel::oneshot,
6 future::{BoxFuture, LocalBoxFuture},
7};
8use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, BackgroundExecutor, Entity, FutureExt as _};
9use parking_lot::Mutex;
10use proto::{
11 AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, LspRequestId, LspRequestMessage,
12 RequestMessage, TypedEnvelope, error::ErrorExt as _,
13};
14use std::{
15 any::{Any, TypeId},
16 sync::{
17 Arc, OnceLock,
18 atomic::{self, AtomicU64},
19 },
20 time::Duration,
21};
22
/// A type-erased handle to a [`ProtoClient`] implementation.
///
/// The handle wraps a reference-counted `State`, so cloning it only clones the
/// `Arc` and every clone refers to the same underlying client.
#[derive(Clone)]
pub struct AnyProtoClient(Arc<State>);
25
/// Shared, mutex-protected map of in-flight LSP requests.
///
/// Each key is the id of a request issued by `AnyProtoClient::request_lsp`; the
/// value is the oneshot sender through which `AnyProtoClient::handle_lsp_response`
/// delivers the (type-erased) per-server responses to the waiting future.
type RequestIds = Arc<
    Mutex<
        HashMap<
            LspRequestId,
            oneshot::Sender<
                Result<
                    Option<TypedEnvelope<Vec<proto::ProtoLspResponse<Box<dyn AnyTypedEnvelope>>>>>,
                >,
            >,
        >,
    >,
>;
38
// Both statics are lazily initialized by `AnyProtoClient::new` and then cloned into
// each new `State`, so every `AnyProtoClient` in the process shares the same
// LSP request id counter and the same map of pending requests.
static NEXT_LSP_REQUEST_ID: OnceLock<Arc<AtomicU64>> = OnceLock::new();
static REQUEST_IDS: OnceLock<RequestIds> = OnceLock::new();
41
/// Data shared by all clones of one `AnyProtoClient`.
struct State {
    /// The wrapped, type-erased client that actually sends and receives envelopes.
    client: Arc<dyn ProtoClient>,
    /// Counter used to allocate ids for outgoing LSP requests (a clone of the
    /// process-wide `NEXT_LSP_REQUEST_ID`).
    next_lsp_request_id: Arc<AtomicU64>,
    /// Pending LSP requests awaiting a response (a clone of the process-wide
    /// `REQUEST_IDS`).
    request_ids: RequestIds,
}
47
/// Transport-level interface that `AnyProtoClient` wraps.
///
/// Implementations must be `Send + Sync` because the client is shared behind an `Arc`.
pub trait ProtoClient: Send + Sync {
    /// Issues `envelope` as a request and returns a future that resolves to the
    /// reply envelope (or an error).
    ///
    /// `request_type` is the message's type name; `AnyProtoClient` passes `T::NAME`.
    fn request(
        &self,
        envelope: Envelope,
        request_type: &'static str,
    ) -> BoxFuture<'static, Result<Envelope>>;

    /// Sends `envelope` without waiting for a reply.
    ///
    /// `message_type` is the message's type name; `AnyProtoClient` passes `T::NAME`.
    fn send(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;

    /// Sends `envelope` as a response.
    // NOTE(review): presumably the envelope is expected to carry a `responding_to`
    // id; how this differs from `send` is up to the implementor — confirm.
    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;

    /// Returns the handler registry that `AnyProtoClient` registers message
    /// handlers and entity subscriptions in.
    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;

    /// Exposed through `AnyProtoClient::is_via_collab`.
    // NOTE(review): presumably true when the connection goes through the collab
    // server — confirm against implementors.
    fn is_via_collab(&self) -> bool;
    /// Exposed through `AnyProtoClient::has_wsl_interop`.
    // NOTE(review): exact meaning is defined by implementors — confirm.
    fn has_wsl_interop(&self) -> bool;
}
64
/// Registry used to route an incoming message to the entity and handler
/// registered for its payload type.
///
/// All maps keyed by a single `TypeId` are keyed by the message payload's type.
#[derive(Default)]
pub struct ProtoMessageHandlerSet {
    /// Message type -> type of the entity that handles it (for handlers
    /// registered via `add_entity_message_handler`).
    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// (entity type, remote entity id) -> the subscriber for that remote entity.
    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
    /// Message type -> function that reads the remote entity id out of an envelope.
    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Message type -> the single entity bound to it (for handlers registered
    /// via `add_message_handler`).
    pub entities_by_message_type: HashMap<TypeId, AnyWeakEntity>,
    /// Message type -> the handler to invoke.
    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
}
73
/// A shareable, type-erased message handler.
///
/// It receives the target entity, the incoming envelope, the client the message
/// arrived on, and an async app context, and returns a (non-`Send`) boxed future
/// that performs the handling.
pub type ProtoMessageHandler = Arc<
    dyn Send
        + Sync
        + Fn(
            AnyEntity,
            Box<dyn AnyTypedEnvelope>,
            AnyProtoClient,
            AsyncApp,
        ) -> LocalBoxFuture<'static, Result<()>>,
>;
84
85impl ProtoMessageHandlerSet {
86 pub fn clear(&mut self) {
87 self.message_handlers.clear();
88 self.entities_by_message_type.clear();
89 self.entities_by_type_and_remote_id.clear();
90 self.entity_id_extractors.clear();
91 }
92
93 fn add_message_handler(
94 &mut self,
95 message_type_id: TypeId,
96 entity: gpui::AnyWeakEntity,
97 handler: ProtoMessageHandler,
98 ) {
99 self.entities_by_message_type
100 .insert(message_type_id, entity);
101 let prev_handler = self.message_handlers.insert(message_type_id, handler);
102 if prev_handler.is_some() {
103 panic!("registered handler for the same message twice");
104 }
105 }
106
107 fn add_entity_message_handler(
108 &mut self,
109 message_type_id: TypeId,
110 entity_type_id: TypeId,
111 entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
112 handler: ProtoMessageHandler,
113 ) {
114 self.entity_id_extractors
115 .entry(message_type_id)
116 .or_insert(entity_id_extractor);
117 self.entity_types_by_message_type
118 .insert(message_type_id, entity_type_id);
119 let prev_handler = self.message_handlers.insert(message_type_id, handler);
120 if prev_handler.is_some() {
121 panic!("registered handler for the same message twice");
122 }
123 }
124
125 pub fn handle_message(
126 this: &parking_lot::Mutex<Self>,
127 message: Box<dyn AnyTypedEnvelope>,
128 client: AnyProtoClient,
129 cx: AsyncApp,
130 ) -> Option<LocalBoxFuture<'static, Result<()>>> {
131 let payload_type_id = message.payload_type_id();
132 let mut this = this.lock();
133 let handler = this.message_handlers.get(&payload_type_id)?.clone();
134 let entity = if let Some(entity) = this.entities_by_message_type.get(&payload_type_id) {
135 entity.upgrade()?
136 } else {
137 let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
138 let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
139 let entity_id = (extract_entity_id)(message.as_ref());
140 match this
141 .entities_by_type_and_remote_id
142 .get_mut(&(entity_type_id, entity_id))?
143 {
144 EntityMessageSubscriber::Pending(pending) => {
145 pending.push(message);
146 return None;
147 }
148 EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
149 }
150 };
151 drop(this);
152 Some(handler(entity, message, client, cx))
153 }
154}
155
/// The state of a subscription for messages addressed to one remote entity.
pub enum EntityMessageSubscriber {
    /// A live subscription; messages are dispatched to the entity behind
    /// `handle` for as long as it can still be upgraded.
    Entity { handle: AnyWeakEntity },
    /// No entity is attached yet; `ProtoMessageHandlerSet::handle_message`
    /// appends incoming envelopes to this queue instead of handling them.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
160
161impl std::fmt::Debug for EntityMessageSubscriber {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 match self {
164 EntityMessageSubscriber::Entity { handle } => f
165 .debug_struct("EntityMessageSubscriber::Entity")
166 .field("handle", handle)
167 .finish(),
168 EntityMessageSubscriber::Pending(vec) => f
169 .debug_struct("EntityMessageSubscriber::Pending")
170 .field(
171 "envelopes",
172 &vec.iter()
173 .map(|envelope| envelope.payload_type_name())
174 .collect::<Vec<_>>(),
175 )
176 .finish(),
177 }
178 }
179}
180
181impl<T> From<Arc<T>> for AnyProtoClient
182where
183 T: ProtoClient + 'static,
184{
185 fn from(client: Arc<T>) -> Self {
186 Self::new(client)
187 }
188}
189
190impl AnyProtoClient {
191 pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
192 Self(Arc::new(State {
193 client,
194 next_lsp_request_id: NEXT_LSP_REQUEST_ID
195 .get_or_init(|| Arc::new(AtomicU64::new(0)))
196 .clone(),
197 request_ids: REQUEST_IDS.get_or_init(RequestIds::default).clone(),
198 }))
199 }
200
201 pub fn is_via_collab(&self) -> bool {
202 self.0.client.is_via_collab()
203 }
204
205 pub fn request<T: RequestMessage>(
206 &self,
207 request: T,
208 ) -> impl Future<Output = Result<T::Response>> + use<T> {
209 let envelope = request.into_envelope(0, None, None);
210 let response = self.0.client.request(envelope, T::NAME);
211 async move {
212 T::Response::from_envelope(response.await?)
213 .context("received response of the wrong type")
214 }
215 }
216
217 pub fn send<T: EnvelopedMessage>(&self, request: T) -> Result<()> {
218 let envelope = request.into_envelope(0, None, None);
219 self.0.client.send(envelope, T::NAME)
220 }
221
222 pub fn send_response<T: EnvelopedMessage>(&self, request_id: u32, request: T) -> Result<()> {
223 let envelope = request.into_envelope(0, Some(request_id), None);
224 self.0.client.send(envelope, T::NAME)
225 }
226
227 pub fn request_lsp<T>(
228 &self,
229 project_id: u64,
230 server_id: Option<u64>,
231 timeout: Duration,
232 executor: BackgroundExecutor,
233 request: T,
234 ) -> impl Future<
235 Output = Result<Option<TypedEnvelope<Vec<proto::ProtoLspResponse<T::Response>>>>>,
236 > + use<T>
237 where
238 T: LspRequestMessage,
239 {
240 let new_id = LspRequestId(
241 self.0
242 .next_lsp_request_id
243 .fetch_add(1, atomic::Ordering::Acquire),
244 );
245 let (tx, rx) = oneshot::channel();
246 {
247 self.0.request_ids.lock().insert(new_id, tx);
248 }
249
250 let query = proto::LspQuery {
251 project_id,
252 server_id,
253 lsp_request_id: new_id.0,
254 request: Some(request.to_proto_query()),
255 };
256 let request = self.request(query);
257 let request_ids = self.0.request_ids.clone();
258 async move {
259 match request.await {
260 Ok(_request_enqueued) => {}
261 Err(e) => {
262 request_ids.lock().remove(&new_id);
263 return Err(e).context("sending LSP proto request");
264 }
265 }
266
267 let response = rx.with_timeout(timeout, &executor).await;
268 {
269 request_ids.lock().remove(&new_id);
270 }
271 match response {
272 Ok(Ok(response)) => {
273 let response = response
274 .context("waiting for LSP proto response")?
275 .map(|response| {
276 anyhow::Ok(TypedEnvelope {
277 payload: response
278 .payload
279 .into_iter()
280 .map(|lsp_response| lsp_response.into_response::<T>())
281 .collect::<Result<Vec<_>>>()?,
282 sender_id: response.sender_id,
283 original_sender_id: response.original_sender_id,
284 message_id: response.message_id,
285 received_at: response.received_at,
286 })
287 })
288 .transpose()
289 .context("converting LSP proto response")?;
290 Ok(response)
291 }
292 Err(_cancelled_due_timeout) => Ok(None),
293 Ok(Err(_channel_dropped)) => Ok(None),
294 }
295 }
296 }
297
298 pub fn send_lsp_response<T: LspRequestMessage>(
299 &self,
300 project_id: u64,
301 lsp_request_id: LspRequestId,
302 server_responses: HashMap<u64, T::Response>,
303 ) -> Result<()> {
304 self.send(proto::LspQueryResponse {
305 project_id,
306 lsp_request_id: lsp_request_id.0,
307 responses: server_responses
308 .into_iter()
309 .map(|(server_id, response)| proto::LspResponse {
310 server_id,
311 response: Some(T::response_to_proto_query(response)),
312 })
313 .collect(),
314 })
315 }
316
317 pub fn handle_lsp_response(&self, mut envelope: TypedEnvelope<proto::LspQueryResponse>) {
318 let request_id = LspRequestId(envelope.payload.lsp_request_id);
319 let mut response_senders = self.0.request_ids.lock();
320 if let Some(tx) = response_senders.remove(&request_id) {
321 let responses = envelope.payload.responses.drain(..).collect::<Vec<_>>();
322 tx.send(Ok(Some(proto::TypedEnvelope {
323 sender_id: envelope.sender_id,
324 original_sender_id: envelope.original_sender_id,
325 message_id: envelope.message_id,
326 received_at: envelope.received_at,
327 payload: responses
328 .into_iter()
329 .filter_map(|response| {
330 use proto::lsp_response::Response;
331
332 let server_id = response.server_id;
333 let response = match response.response? {
334 Response::GetReferencesResponse(response) => {
335 to_any_envelope(&envelope, response)
336 }
337 Response::GetDocumentColorResponse(response) => {
338 to_any_envelope(&envelope, response)
339 }
340 Response::GetHoverResponse(response) => {
341 to_any_envelope(&envelope, response)
342 }
343 Response::GetCodeActionsResponse(response) => {
344 to_any_envelope(&envelope, response)
345 }
346 Response::GetSignatureHelpResponse(response) => {
347 to_any_envelope(&envelope, response)
348 }
349 Response::GetCodeLensResponse(response) => {
350 to_any_envelope(&envelope, response)
351 }
352 Response::GetDocumentDiagnosticsResponse(response) => {
353 to_any_envelope(&envelope, response)
354 }
355 Response::GetDefinitionResponse(response) => {
356 to_any_envelope(&envelope, response)
357 }
358 Response::GetDeclarationResponse(response) => {
359 to_any_envelope(&envelope, response)
360 }
361 Response::GetTypeDefinitionResponse(response) => {
362 to_any_envelope(&envelope, response)
363 }
364 Response::GetImplementationResponse(response) => {
365 to_any_envelope(&envelope, response)
366 }
367 Response::InlayHintsResponse(response) => {
368 to_any_envelope(&envelope, response)
369 }
370 };
371 Some(proto::ProtoLspResponse {
372 server_id,
373 response,
374 })
375 })
376 .collect(),
377 })))
378 .ok();
379 }
380 }
381
382 pub fn add_request_handler<M, E, H, F>(&self, entity: gpui::WeakEntity<E>, handler: H)
383 where
384 M: RequestMessage,
385 E: 'static,
386 H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
387 F: 'static + Future<Output = Result<M::Response>>,
388 {
389 self.0
390 .client
391 .message_handler_set()
392 .lock()
393 .add_message_handler(
394 TypeId::of::<M>(),
395 entity.into(),
396 Arc::new(move |entity, envelope, client, cx| {
397 let entity = entity.downcast::<E>().unwrap();
398 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
399 let request_id = envelope.message_id();
400 handler(entity, *envelope, cx)
401 .then(move |result| async move {
402 match result {
403 Ok(response) => {
404 client.send_response(request_id, response)?;
405 Ok(())
406 }
407 Err(error) => {
408 client.send_response(request_id, error.to_proto())?;
409 Err(error)
410 }
411 }
412 })
413 .boxed_local()
414 }),
415 )
416 }
417
418 pub fn add_entity_request_handler<M, E, H, F>(&self, handler: H)
419 where
420 M: EnvelopedMessage + RequestMessage + EntityMessage,
421 E: 'static,
422 H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
423 F: 'static + Future<Output = Result<M::Response>>,
424 {
425 let message_type_id = TypeId::of::<M>();
426 let entity_type_id = TypeId::of::<E>();
427 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
428 (envelope as &dyn Any)
429 .downcast_ref::<TypedEnvelope<M>>()
430 .unwrap()
431 .payload
432 .remote_entity_id()
433 };
434 self.0
435 .client
436 .message_handler_set()
437 .lock()
438 .add_entity_message_handler(
439 message_type_id,
440 entity_type_id,
441 entity_id_extractor,
442 Arc::new(move |entity, envelope, client, cx| {
443 let entity = entity.downcast::<E>().unwrap();
444 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
445 let request_id = envelope.message_id();
446 handler(entity, *envelope, cx)
447 .then(move |result| async move {
448 match result {
449 Ok(response) => {
450 client.send_response(request_id, response)?;
451 Ok(())
452 }
453 Err(error) => {
454 client.send_response(request_id, error.to_proto())?;
455 Err(error)
456 }
457 }
458 })
459 .boxed_local()
460 }),
461 );
462 }
463
464 pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
465 where
466 M: EnvelopedMessage + EntityMessage,
467 E: 'static,
468 H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
469 F: 'static + Future<Output = Result<()>>,
470 {
471 let message_type_id = TypeId::of::<M>();
472 let entity_type_id = TypeId::of::<E>();
473 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
474 (envelope as &dyn Any)
475 .downcast_ref::<TypedEnvelope<M>>()
476 .unwrap()
477 .payload
478 .remote_entity_id()
479 };
480 self.0
481 .client
482 .message_handler_set()
483 .lock()
484 .add_entity_message_handler(
485 message_type_id,
486 entity_type_id,
487 entity_id_extractor,
488 Arc::new(move |entity, envelope, _, cx| {
489 let entity = entity.downcast::<E>().unwrap();
490 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
491 handler(entity, *envelope, cx).boxed_local()
492 }),
493 );
494 }
495
496 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
497 let id = (TypeId::of::<E>(), remote_id);
498
499 let mut message_handlers = self.0.client.message_handler_set().lock();
500 if message_handlers
501 .entities_by_type_and_remote_id
502 .contains_key(&id)
503 {
504 panic!("already subscribed to entity");
505 }
506
507 message_handlers.entities_by_type_and_remote_id.insert(
508 id,
509 EntityMessageSubscriber::Entity {
510 handle: entity.downgrade().into(),
511 },
512 );
513 }
514
515 pub fn has_wsl_interop(&self) -> bool {
516 self.0.client.has_wsl_interop()
517 }
518}
519
520fn to_any_envelope<T: EnvelopedMessage>(
521 envelope: &TypedEnvelope<proto::LspQueryResponse>,
522 response: T,
523) -> Box<dyn AnyTypedEnvelope> {
524 Box::new(proto::TypedEnvelope {
525 sender_id: envelope.sender_id,
526 original_sender_id: envelope.original_sender_id,
527 message_id: envelope.message_id,
528 received_at: envelope.received_at,
529 payload: response,
530 }) as Box<_>
531}