1use anyhow::{Context, Result};
2use collections::HashMap;
3use futures::{
4 Future, FutureExt as _,
5 channel::oneshot,
6 future::{BoxFuture, LocalBoxFuture},
7};
8use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, BackgroundExecutor, Entity, FutureExt as _};
9use parking_lot::Mutex;
10use proto::{
11 AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, LspRequestId, LspRequestMessage,
12 RequestMessage, TypedEnvelope, error::ErrorExt as _,
13};
14use std::{
15 any::{Any, TypeId},
16 sync::{
17 Arc, OnceLock,
18 atomic::{self, AtomicU64},
19 },
20 time::Duration,
21};
22
23#[derive(Debug, Clone)]
24pub struct AnyProtoClient(Arc<State>);
25
26type RequestIds = Arc<
27 Mutex<
28 HashMap<
29 LspRequestId,
30 oneshot::Sender<
31 Result<
32 Option<TypedEnvelope<Vec<proto::ProtoLspResponse<Box<dyn AnyTypedEnvelope>>>>>,
33 >,
34 >,
35 >,
36 >,
37>;
38
39static NEXT_LSP_REQUEST_ID: OnceLock<Arc<AtomicU64>> = OnceLock::new();
40static REQUEST_IDS: OnceLock<RequestIds> = OnceLock::new();
41
42struct State {
43 client: Arc<dyn ProtoClient>,
44 next_lsp_request_id: Arc<AtomicU64>,
45 request_ids: RequestIds,
46}
47
48impl std::fmt::Debug for State {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 f.debug_struct("State")
51 .field("next_lsp_request_id", &self.next_lsp_request_id)
52 .field("request_ids", &self.request_ids)
53 .finish_non_exhaustive()
54 }
55}
56
57pub trait ProtoClient: Send + Sync {
58 fn request(
59 &self,
60 envelope: Envelope,
61 request_type: &'static str,
62 ) -> BoxFuture<'static, Result<Envelope>>;
63
64 fn send(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
65
66 fn send_response(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
67
68 fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
69
70 fn is_via_collab(&self) -> bool;
71 fn has_wsl_interop(&self) -> bool;
72}
73
74#[derive(Default)]
75pub struct ProtoMessageHandlerSet {
76 pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
77 pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
78 pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
79 pub entities_by_message_type: HashMap<TypeId, AnyWeakEntity>,
80 pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
81}
82
83pub type ProtoMessageHandler = Arc<
84 dyn Send
85 + Sync
86 + Fn(
87 AnyEntity,
88 Box<dyn AnyTypedEnvelope>,
89 AnyProtoClient,
90 AsyncApp,
91 ) -> LocalBoxFuture<'static, Result<()>>,
92>;
93
94impl ProtoMessageHandlerSet {
95 pub fn clear(&mut self) {
96 self.message_handlers.clear();
97 self.entities_by_message_type.clear();
98 self.entities_by_type_and_remote_id.clear();
99 self.entity_id_extractors.clear();
100 }
101
102 fn add_message_handler(
103 &mut self,
104 message_type_id: TypeId,
105 entity: gpui::AnyWeakEntity,
106 handler: ProtoMessageHandler,
107 ) {
108 self.entities_by_message_type
109 .insert(message_type_id, entity);
110 let prev_handler = self.message_handlers.insert(message_type_id, handler);
111 if prev_handler.is_some() {
112 panic!("registered handler for the same message twice");
113 }
114 }
115
116 fn add_entity_message_handler(
117 &mut self,
118 message_type_id: TypeId,
119 entity_type_id: TypeId,
120 entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
121 handler: ProtoMessageHandler,
122 ) {
123 self.entity_id_extractors
124 .entry(message_type_id)
125 .or_insert(entity_id_extractor);
126 self.entity_types_by_message_type
127 .insert(message_type_id, entity_type_id);
128 let prev_handler = self.message_handlers.insert(message_type_id, handler);
129 if prev_handler.is_some() {
130 panic!("registered handler for the same message twice");
131 }
132 }
133
134 pub fn handle_message(
135 this: &parking_lot::Mutex<Self>,
136 message: Box<dyn AnyTypedEnvelope>,
137 client: AnyProtoClient,
138 cx: AsyncApp,
139 ) -> Option<LocalBoxFuture<'static, Result<()>>> {
140 let payload_type_id = message.payload_type_id();
141 let mut this = this.lock();
142 let handler = this.message_handlers.get(&payload_type_id)?.clone();
143 let entity = if let Some(entity) = this.entities_by_message_type.get(&payload_type_id) {
144 entity.upgrade()?
145 } else {
146 let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
147 let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
148 let entity_id = (extract_entity_id)(message.as_ref());
149 match this
150 .entities_by_type_and_remote_id
151 .get_mut(&(entity_type_id, entity_id))?
152 {
153 EntityMessageSubscriber::Pending(pending) => {
154 pending.push(message);
155 return None;
156 }
157 EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
158 }
159 };
160 drop(this);
161 Some(handler(entity, message, client, cx))
162 }
163}
164
165pub enum EntityMessageSubscriber {
166 Entity { handle: AnyWeakEntity },
167 Pending(Vec<Box<dyn AnyTypedEnvelope>>),
168}
169
170impl std::fmt::Debug for EntityMessageSubscriber {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 match self {
173 EntityMessageSubscriber::Entity { handle } => f
174 .debug_struct("EntityMessageSubscriber::Entity")
175 .field("handle", handle)
176 .finish(),
177 EntityMessageSubscriber::Pending(vec) => f
178 .debug_struct("EntityMessageSubscriber::Pending")
179 .field(
180 "envelopes",
181 &vec.iter()
182 .map(|envelope| envelope.payload_type_name())
183 .collect::<Vec<_>>(),
184 )
185 .finish(),
186 }
187 }
188}
189
190impl<T> From<Arc<T>> for AnyProtoClient
191where
192 T: ProtoClient + 'static,
193{
194 fn from(client: Arc<T>) -> Self {
195 Self::new(client)
196 }
197}
198
199impl AnyProtoClient {
200 pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
201 Self(Arc::new(State {
202 client,
203 next_lsp_request_id: NEXT_LSP_REQUEST_ID
204 .get_or_init(|| Arc::new(AtomicU64::new(0)))
205 .clone(),
206 request_ids: REQUEST_IDS.get_or_init(RequestIds::default).clone(),
207 }))
208 }
209
210 pub fn is_via_collab(&self) -> bool {
211 self.0.client.is_via_collab()
212 }
213
214 pub fn request<T: RequestMessage>(
215 &self,
216 request: T,
217 ) -> impl Future<Output = Result<T::Response>> + use<T> {
218 let envelope = request.into_envelope(0, None, None);
219 let response = self.0.client.request(envelope, T::NAME);
220 async move {
221 T::Response::from_envelope(response.await?)
222 .context("received response of the wrong type")
223 }
224 }
225
226 pub fn send<T: EnvelopedMessage>(&self, request: T) -> Result<()> {
227 let envelope = request.into_envelope(0, None, None);
228 self.0.client.send(envelope, T::NAME)
229 }
230
231 pub fn send_response<T: EnvelopedMessage>(&self, request_id: u32, request: T) -> Result<()> {
232 let envelope = request.into_envelope(0, Some(request_id), None);
233 self.0.client.send(envelope, T::NAME)
234 }
235
236 pub fn request_lsp<T>(
237 &self,
238 project_id: u64,
239 server_id: Option<u64>,
240 timeout: Duration,
241 executor: BackgroundExecutor,
242 request: T,
243 ) -> impl Future<
244 Output = Result<Option<TypedEnvelope<Vec<proto::ProtoLspResponse<T::Response>>>>>,
245 > + use<T>
246 where
247 T: LspRequestMessage,
248 {
249 let new_id = LspRequestId(
250 self.0
251 .next_lsp_request_id
252 .fetch_add(1, atomic::Ordering::Acquire),
253 );
254 let (tx, rx) = oneshot::channel();
255 {
256 self.0.request_ids.lock().insert(new_id, tx);
257 }
258
259 let query = proto::LspQuery {
260 project_id,
261 server_id,
262 lsp_request_id: new_id.0,
263 request: Some(request.to_proto_query()),
264 };
265 let request = self.request(query);
266 let request_ids = self.0.request_ids.clone();
267 async move {
268 match request.await {
269 Ok(_request_enqueued) => {}
270 Err(e) => {
271 request_ids.lock().remove(&new_id);
272 return Err(e).context("sending LSP proto request");
273 }
274 }
275
276 let response = rx.with_timeout(timeout, &executor).await;
277 {
278 request_ids.lock().remove(&new_id);
279 }
280 match response {
281 Ok(Ok(response)) => {
282 let response = response
283 .context("waiting for LSP proto response")?
284 .map(|response| {
285 anyhow::Ok(TypedEnvelope {
286 payload: response
287 .payload
288 .into_iter()
289 .map(|lsp_response| lsp_response.into_response::<T>())
290 .collect::<Result<Vec<_>>>()?,
291 sender_id: response.sender_id,
292 original_sender_id: response.original_sender_id,
293 message_id: response.message_id,
294 received_at: response.received_at,
295 })
296 })
297 .transpose()
298 .context("converting LSP proto response")?;
299 Ok(response)
300 }
301 Err(_cancelled_due_timeout) => Ok(None),
302 Ok(Err(_channel_dropped)) => Ok(None),
303 }
304 }
305 }
306
307 pub fn send_lsp_response<T: LspRequestMessage>(
308 &self,
309 project_id: u64,
310 lsp_request_id: LspRequestId,
311 server_responses: HashMap<u64, T::Response>,
312 ) -> Result<()> {
313 self.send(proto::LspQueryResponse {
314 project_id,
315 lsp_request_id: lsp_request_id.0,
316 responses: server_responses
317 .into_iter()
318 .map(|(server_id, response)| proto::LspResponse {
319 server_id,
320 response: Some(T::response_to_proto_query(response)),
321 })
322 .collect(),
323 })
324 }
325
326 pub fn handle_lsp_response(&self, mut envelope: TypedEnvelope<proto::LspQueryResponse>) {
327 let request_id = LspRequestId(envelope.payload.lsp_request_id);
328 let mut response_senders = self.0.request_ids.lock();
329 if let Some(tx) = response_senders.remove(&request_id) {
330 let responses = envelope.payload.responses.drain(..).collect::<Vec<_>>();
331 tx.send(Ok(Some(proto::TypedEnvelope {
332 sender_id: envelope.sender_id,
333 original_sender_id: envelope.original_sender_id,
334 message_id: envelope.message_id,
335 received_at: envelope.received_at,
336 payload: responses
337 .into_iter()
338 .filter_map(|response| {
339 use proto::lsp_response::Response;
340
341 let server_id = response.server_id;
342 let response = match response.response? {
343 Response::GetReferencesResponse(response) => {
344 to_any_envelope(&envelope, response)
345 }
346 Response::GetDocumentColorResponse(response) => {
347 to_any_envelope(&envelope, response)
348 }
349 Response::GetHoverResponse(response) => {
350 to_any_envelope(&envelope, response)
351 }
352 Response::GetCodeActionsResponse(response) => {
353 to_any_envelope(&envelope, response)
354 }
355 Response::GetSignatureHelpResponse(response) => {
356 to_any_envelope(&envelope, response)
357 }
358 Response::GetCodeLensResponse(response) => {
359 to_any_envelope(&envelope, response)
360 }
361 Response::GetDocumentDiagnosticsResponse(response) => {
362 to_any_envelope(&envelope, response)
363 }
364 Response::GetDefinitionResponse(response) => {
365 to_any_envelope(&envelope, response)
366 }
367 Response::GetDeclarationResponse(response) => {
368 to_any_envelope(&envelope, response)
369 }
370 Response::GetTypeDefinitionResponse(response) => {
371 to_any_envelope(&envelope, response)
372 }
373 Response::GetImplementationResponse(response) => {
374 to_any_envelope(&envelope, response)
375 }
376 Response::InlayHintsResponse(response) => {
377 to_any_envelope(&envelope, response)
378 }
379 Response::SemanticTokensResponse(response) => {
380 to_any_envelope(&envelope, response)
381 }
382 Response::GetFoldingRangesResponse(response) => {
383 to_any_envelope(&envelope, response)
384 }
385 Response::GetDocumentSymbolsResponse(response) => {
386 to_any_envelope(&envelope, response)
387 }
388 };
389 Some(proto::ProtoLspResponse {
390 server_id,
391 response,
392 })
393 })
394 .collect(),
395 })))
396 .ok();
397 }
398 }
399
400 pub fn add_request_handler<M, E, H, F>(&self, entity: gpui::WeakEntity<E>, handler: H)
401 where
402 M: RequestMessage,
403 E: 'static,
404 H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
405 F: 'static + Future<Output = Result<M::Response>>,
406 {
407 self.0
408 .client
409 .message_handler_set()
410 .lock()
411 .add_message_handler(
412 TypeId::of::<M>(),
413 entity.into(),
414 Arc::new(move |entity, envelope, client, cx| {
415 let entity = entity.downcast::<E>().unwrap();
416 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
417 let request_id = envelope.message_id();
418 handler(entity, *envelope, cx)
419 .then(move |result| async move {
420 match result {
421 Ok(response) => {
422 client.send_response(request_id, response)?;
423 Ok(())
424 }
425 Err(error) => {
426 client.send_response(request_id, error.to_proto())?;
427 Err(error)
428 }
429 }
430 })
431 .boxed_local()
432 }),
433 )
434 }
435
436 pub fn add_entity_request_handler<M, E, H, F>(&self, handler: H)
437 where
438 M: EnvelopedMessage + RequestMessage + EntityMessage,
439 E: 'static,
440 H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
441 F: 'static + Future<Output = Result<M::Response>>,
442 {
443 let message_type_id = TypeId::of::<M>();
444 let entity_type_id = TypeId::of::<E>();
445 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
446 (envelope as &dyn Any)
447 .downcast_ref::<TypedEnvelope<M>>()
448 .unwrap()
449 .payload
450 .remote_entity_id()
451 };
452 self.0
453 .client
454 .message_handler_set()
455 .lock()
456 .add_entity_message_handler(
457 message_type_id,
458 entity_type_id,
459 entity_id_extractor,
460 Arc::new(move |entity, envelope, client, cx| {
461 let entity = entity.downcast::<E>().unwrap();
462 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
463 let request_id = envelope.message_id();
464 handler(entity, *envelope, cx)
465 .then(move |result| async move {
466 match result {
467 Ok(response) => {
468 client.send_response(request_id, response)?;
469 Ok(())
470 }
471 Err(error) => {
472 client.send_response(request_id, error.to_proto())?;
473 Err(error)
474 }
475 }
476 })
477 .boxed_local()
478 }),
479 );
480 }
481
482 pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
483 where
484 M: EnvelopedMessage + EntityMessage,
485 E: 'static,
486 H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
487 F: 'static + Future<Output = Result<()>>,
488 {
489 let message_type_id = TypeId::of::<M>();
490 let entity_type_id = TypeId::of::<E>();
491 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
492 (envelope as &dyn Any)
493 .downcast_ref::<TypedEnvelope<M>>()
494 .unwrap()
495 .payload
496 .remote_entity_id()
497 };
498 self.0
499 .client
500 .message_handler_set()
501 .lock()
502 .add_entity_message_handler(
503 message_type_id,
504 entity_type_id,
505 entity_id_extractor,
506 Arc::new(move |entity, envelope, _, cx| {
507 let entity = entity.downcast::<E>().unwrap();
508 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
509 handler(entity, *envelope, cx).boxed_local()
510 }),
511 );
512 }
513
514 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
515 let id = (TypeId::of::<E>(), remote_id);
516
517 let mut message_handlers = self.0.client.message_handler_set().lock();
518 if message_handlers
519 .entities_by_type_and_remote_id
520 .contains_key(&id)
521 {
522 panic!("already subscribed to entity");
523 }
524
525 message_handlers.entities_by_type_and_remote_id.insert(
526 id,
527 EntityMessageSubscriber::Entity {
528 handle: entity.downgrade().into(),
529 },
530 );
531 }
532
533 pub fn has_wsl_interop(&self) -> bool {
534 self.0.client.has_wsl_interop()
535 }
536}
537
538fn to_any_envelope<T: EnvelopedMessage>(
539 envelope: &TypedEnvelope<proto::LspQueryResponse>,
540 response: T,
541) -> Box<dyn AnyTypedEnvelope> {
542 Box::new(proto::TypedEnvelope {
543 sender_id: envelope.sender_id,
544 original_sender_id: envelope.original_sender_id,
545 message_id: envelope.message_id,
546 received_at: envelope.received_at,
547 payload: response,
548 }) as Box<_>
549}