1use anyhow::{Context, Result};
2use collections::HashMap;
3use futures::{
4 Future, FutureExt as _,
5 channel::oneshot,
6 future::{BoxFuture, LocalBoxFuture},
7};
8use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, BackgroundExecutor, Entity, FutureExt as _};
9use parking_lot::Mutex;
10use proto::{
11 AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, LspRequestId, LspRequestMessage,
12 RequestMessage, TypedEnvelope, error::ErrorExt as _,
13};
14use std::{
15 any::{Any, TypeId},
16 sync::{
17 Arc, OnceLock,
18 atomic::{self, AtomicU64},
19 },
20 time::Duration,
21};
22
/// A type-erased handle to a [`ProtoClient`] implementation.
///
/// The handle wraps an `Arc` of shared state, so cloning it only bumps a
/// reference count and every clone talks to the same underlying client.
#[derive(Clone)]
pub struct AnyProtoClient(Arc<State>);
25
/// Shared, lock-protected map of LSP requests that are still awaiting a response.
///
/// Keys are LSP request ids. Each value is the sending half of a oneshot
/// channel: `request_lsp` inserts it and awaits the receiving half, while
/// `handle_lsp_response` removes it and uses it to deliver the type-erased
/// per-server responses.
type RequestIds = Arc<
    Mutex<
        HashMap<
            LspRequestId,
            oneshot::Sender<
                Result<
                    Option<TypedEnvelope<Vec<proto::ProtoLspResponse<Box<dyn AnyTypedEnvelope>>>>>,
                >,
            >,
        >,
    >,
>;
38
// Both statics are initialized lazily by `AnyProtoClient::new`, which stores a
// clone of the same `Arc` in every client it creates. As a result the id
// counter and the pending-request map are shared by all clients in the process.

/// Process-wide counter used to allocate LSP request ids.
static NEXT_LSP_REQUEST_ID: OnceLock<Arc<AtomicU64>> = OnceLock::new();
/// Process-wide map of LSP requests awaiting a response.
static REQUEST_IDS: OnceLock<RequestIds> = OnceLock::new();
41
/// Inner state shared by all clones of one [`AnyProtoClient`].
struct State {
    /// The type-erased client that every send/request is forwarded to.
    client: Arc<dyn ProtoClient>,
    /// Counter from which `request_lsp` draws a fresh LSP request id.
    next_lsp_request_id: Arc<AtomicU64>,
    /// Response senders for LSP requests that have not been answered yet.
    request_ids: RequestIds,
}
47
/// The operations [`AnyProtoClient`] needs from a concrete client: moving raw
/// [`Envelope`]s and exposing the registry of handlers for incoming messages.
pub trait ProtoClient: Send + Sync {
    /// Issues a request carrying `envelope` and returns a future that resolves
    /// to the reply envelope (or an error).
    ///
    /// `request_type` is the static name of the message; callers in this file
    /// pass `T::NAME`.
    fn request(
        &self,
        envelope: Envelope,
        request_type: &'static str,
    ) -> BoxFuture<'static, Result<Envelope>>;

    /// Sends `envelope` without waiting for a reply.
    ///
    /// `message_type` is the static name of the message.
    fn send(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;

    /// Sends `envelope` as the response to a previously received request.
    ///
    /// NOTE(review): how this differs from [`ProtoClient::send`] is up to the
    /// implementor and is not visible from this file.
    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;

    /// Returns the handler registry that the `add_*_handler` and
    /// `subscribe_to_entity` methods of [`AnyProtoClient`] write into.
    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;

    /// Reports whether this client goes through collab.
    ///
    /// NOTE(review): exact semantics are defined by implementors — confirm there.
    fn is_via_collab(&self) -> bool;
}
63
/// Registry of handlers for incoming messages, keyed by the `TypeId` of the
/// message payload.
///
/// A message type is either bound to one specific entity
/// (`entities_by_message_type`) or routed to an entity chosen by a remote id
/// extracted from the message itself (the remaining lookup maps).
#[derive(Default)]
pub struct ProtoMessageHandlerSet {
    /// Message type -> type of the entity that handles it, for messages routed
    /// by remote entity id.
    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// (entity type, remote entity id) -> the subscriber for that entity.
    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
    /// Message type -> function that reads the remote entity id out of an envelope.
    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Message type -> the single entity that handles every message of that type.
    pub entities_by_message_type: HashMap<TypeId, AnyWeakEntity>,
    /// Message type -> the type-erased handler to invoke.
    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
}
72
/// A type-erased, shareable message handler.
///
/// It receives the target entity, the incoming envelope, the client the
/// message arrived on, and an async app context, and returns a non-`Send`
/// boxed future that resolves once the message has been handled.
pub type ProtoMessageHandler = Arc<
    dyn Send
        + Sync
        + Fn(
            AnyEntity,
            Box<dyn AnyTypedEnvelope>,
            AnyProtoClient,
            AsyncApp,
        ) -> LocalBoxFuture<'static, Result<()>>,
>;
83
84impl ProtoMessageHandlerSet {
85 pub fn clear(&mut self) {
86 self.message_handlers.clear();
87 self.entities_by_message_type.clear();
88 self.entities_by_type_and_remote_id.clear();
89 self.entity_id_extractors.clear();
90 }
91
92 fn add_message_handler(
93 &mut self,
94 message_type_id: TypeId,
95 entity: gpui::AnyWeakEntity,
96 handler: ProtoMessageHandler,
97 ) {
98 self.entities_by_message_type
99 .insert(message_type_id, entity);
100 let prev_handler = self.message_handlers.insert(message_type_id, handler);
101 if prev_handler.is_some() {
102 panic!("registered handler for the same message twice");
103 }
104 }
105
106 fn add_entity_message_handler(
107 &mut self,
108 message_type_id: TypeId,
109 entity_type_id: TypeId,
110 entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
111 handler: ProtoMessageHandler,
112 ) {
113 self.entity_id_extractors
114 .entry(message_type_id)
115 .or_insert(entity_id_extractor);
116 self.entity_types_by_message_type
117 .insert(message_type_id, entity_type_id);
118 let prev_handler = self.message_handlers.insert(message_type_id, handler);
119 if prev_handler.is_some() {
120 panic!("registered handler for the same message twice");
121 }
122 }
123
124 pub fn handle_message(
125 this: &parking_lot::Mutex<Self>,
126 message: Box<dyn AnyTypedEnvelope>,
127 client: AnyProtoClient,
128 cx: AsyncApp,
129 ) -> Option<LocalBoxFuture<'static, Result<()>>> {
130 let payload_type_id = message.payload_type_id();
131 let mut this = this.lock();
132 let handler = this.message_handlers.get(&payload_type_id)?.clone();
133 let entity = if let Some(entity) = this.entities_by_message_type.get(&payload_type_id) {
134 entity.upgrade()?
135 } else {
136 let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
137 let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
138 let entity_id = (extract_entity_id)(message.as_ref());
139 match this
140 .entities_by_type_and_remote_id
141 .get_mut(&(entity_type_id, entity_id))?
142 {
143 EntityMessageSubscriber::Pending(pending) => {
144 pending.push(message);
145 return None;
146 }
147 EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
148 }
149 };
150 drop(this);
151 Some(handler(entity, message, client, cx))
152 }
153}
154
/// The receiver registered for one `(entity type, remote id)` pair.
pub enum EntityMessageSubscriber {
    /// A subscribed entity, held weakly so the registry does not keep it alive.
    Entity { handle: AnyWeakEntity },
    /// No entity is attached yet; `ProtoMessageHandlerSet::handle_message`
    /// appends incoming messages here instead of dispatching them.
    /// NOTE(review): the code that drains this queue is not in this file.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
159
160impl std::fmt::Debug for EntityMessageSubscriber {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 match self {
163 EntityMessageSubscriber::Entity { handle } => f
164 .debug_struct("EntityMessageSubscriber::Entity")
165 .field("handle", handle)
166 .finish(),
167 EntityMessageSubscriber::Pending(vec) => f
168 .debug_struct("EntityMessageSubscriber::Pending")
169 .field(
170 "envelopes",
171 &vec.iter()
172 .map(|envelope| envelope.payload_type_name())
173 .collect::<Vec<_>>(),
174 )
175 .finish(),
176 }
177 }
178}
179
180impl<T> From<Arc<T>> for AnyProtoClient
181where
182 T: ProtoClient + 'static,
183{
184 fn from(client: Arc<T>) -> Self {
185 Self::new(client)
186 }
187}
188
189impl AnyProtoClient {
190 pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
191 Self(Arc::new(State {
192 client,
193 next_lsp_request_id: NEXT_LSP_REQUEST_ID
194 .get_or_init(|| Arc::new(AtomicU64::new(0)))
195 .clone(),
196 request_ids: REQUEST_IDS.get_or_init(RequestIds::default).clone(),
197 }))
198 }
199
200 pub fn is_via_collab(&self) -> bool {
201 self.0.client.is_via_collab()
202 }
203
204 pub fn request<T: RequestMessage>(
205 &self,
206 request: T,
207 ) -> impl Future<Output = Result<T::Response>> + use<T> {
208 let envelope = request.into_envelope(0, None, None);
209 let response = self.0.client.request(envelope, T::NAME);
210 async move {
211 T::Response::from_envelope(response.await?)
212 .context("received response of the wrong type")
213 }
214 }
215
216 pub fn send<T: EnvelopedMessage>(&self, request: T) -> Result<()> {
217 let envelope = request.into_envelope(0, None, None);
218 self.0.client.send(envelope, T::NAME)
219 }
220
221 pub fn send_response<T: EnvelopedMessage>(&self, request_id: u32, request: T) -> Result<()> {
222 let envelope = request.into_envelope(0, Some(request_id), None);
223 self.0.client.send(envelope, T::NAME)
224 }
225
226 pub fn request_lsp<T>(
227 &self,
228 project_id: u64,
229 server_id: Option<u64>,
230 timeout: Duration,
231 executor: BackgroundExecutor,
232 request: T,
233 ) -> impl Future<
234 Output = Result<Option<TypedEnvelope<Vec<proto::ProtoLspResponse<T::Response>>>>>,
235 > + use<T>
236 where
237 T: LspRequestMessage,
238 {
239 let new_id = LspRequestId(
240 self.0
241 .next_lsp_request_id
242 .fetch_add(1, atomic::Ordering::Acquire),
243 );
244 let (tx, rx) = oneshot::channel();
245 {
246 self.0.request_ids.lock().insert(new_id, tx);
247 }
248
249 let query = proto::LspQuery {
250 project_id,
251 server_id,
252 lsp_request_id: new_id.0,
253 request: Some(request.to_proto_query()),
254 };
255 let request = self.request(query);
256 let request_ids = self.0.request_ids.clone();
257 async move {
258 match request.await {
259 Ok(_request_enqueued) => {}
260 Err(e) => {
261 request_ids.lock().remove(&new_id);
262 return Err(e).context("sending LSP proto request");
263 }
264 }
265
266 let response = rx.with_timeout(timeout, &executor).await;
267 {
268 request_ids.lock().remove(&new_id);
269 }
270 match response {
271 Ok(Ok(response)) => {
272 let response = response
273 .context("waiting for LSP proto response")?
274 .map(|response| {
275 anyhow::Ok(TypedEnvelope {
276 payload: response
277 .payload
278 .into_iter()
279 .map(|lsp_response| lsp_response.into_response::<T>())
280 .collect::<Result<Vec<_>>>()?,
281 sender_id: response.sender_id,
282 original_sender_id: response.original_sender_id,
283 message_id: response.message_id,
284 received_at: response.received_at,
285 })
286 })
287 .transpose()
288 .context("converting LSP proto response")?;
289 Ok(response)
290 }
291 Err(_cancelled_due_timeout) => Ok(None),
292 Ok(Err(_channel_dropped)) => Ok(None),
293 }
294 }
295 }
296
297 pub fn send_lsp_response<T: LspRequestMessage>(
298 &self,
299 project_id: u64,
300 lsp_request_id: LspRequestId,
301 server_responses: HashMap<u64, T::Response>,
302 ) -> Result<()> {
303 self.send(proto::LspQueryResponse {
304 project_id,
305 lsp_request_id: lsp_request_id.0,
306 responses: server_responses
307 .into_iter()
308 .map(|(server_id, response)| proto::LspResponse {
309 server_id,
310 response: Some(T::response_to_proto_query(response)),
311 })
312 .collect(),
313 })
314 }
315
316 pub fn handle_lsp_response(&self, mut envelope: TypedEnvelope<proto::LspQueryResponse>) {
317 let request_id = LspRequestId(envelope.payload.lsp_request_id);
318 let mut response_senders = self.0.request_ids.lock();
319 if let Some(tx) = response_senders.remove(&request_id) {
320 let responses = envelope.payload.responses.drain(..).collect::<Vec<_>>();
321 tx.send(Ok(Some(proto::TypedEnvelope {
322 sender_id: envelope.sender_id,
323 original_sender_id: envelope.original_sender_id,
324 message_id: envelope.message_id,
325 received_at: envelope.received_at,
326 payload: responses
327 .into_iter()
328 .filter_map(|response| {
329 use proto::lsp_response::Response;
330
331 let server_id = response.server_id;
332 let response = match response.response? {
333 Response::GetReferencesResponse(response) => {
334 to_any_envelope(&envelope, response)
335 }
336 Response::GetDocumentColorResponse(response) => {
337 to_any_envelope(&envelope, response)
338 }
339 Response::GetHoverResponse(response) => {
340 to_any_envelope(&envelope, response)
341 }
342 Response::GetCodeActionsResponse(response) => {
343 to_any_envelope(&envelope, response)
344 }
345 Response::GetSignatureHelpResponse(response) => {
346 to_any_envelope(&envelope, response)
347 }
348 Response::GetCodeLensResponse(response) => {
349 to_any_envelope(&envelope, response)
350 }
351 Response::GetDocumentDiagnosticsResponse(response) => {
352 to_any_envelope(&envelope, response)
353 }
354 Response::GetDefinitionResponse(response) => {
355 to_any_envelope(&envelope, response)
356 }
357 Response::GetDeclarationResponse(response) => {
358 to_any_envelope(&envelope, response)
359 }
360 Response::GetTypeDefinitionResponse(response) => {
361 to_any_envelope(&envelope, response)
362 }
363 Response::GetImplementationResponse(response) => {
364 to_any_envelope(&envelope, response)
365 }
366 Response::InlayHintsResponse(response) => {
367 to_any_envelope(&envelope, response)
368 }
369 };
370 Some(proto::ProtoLspResponse {
371 server_id,
372 response,
373 })
374 })
375 .collect(),
376 })))
377 .ok();
378 }
379 }
380
381 pub fn add_request_handler<M, E, H, F>(&self, entity: gpui::WeakEntity<E>, handler: H)
382 where
383 M: RequestMessage,
384 E: 'static,
385 H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
386 F: 'static + Future<Output = Result<M::Response>>,
387 {
388 self.0
389 .client
390 .message_handler_set()
391 .lock()
392 .add_message_handler(
393 TypeId::of::<M>(),
394 entity.into(),
395 Arc::new(move |entity, envelope, client, cx| {
396 let entity = entity.downcast::<E>().unwrap();
397 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
398 let request_id = envelope.message_id();
399 handler(entity, *envelope, cx)
400 .then(move |result| async move {
401 match result {
402 Ok(response) => {
403 client.send_response(request_id, response)?;
404 Ok(())
405 }
406 Err(error) => {
407 client.send_response(request_id, error.to_proto())?;
408 Err(error)
409 }
410 }
411 })
412 .boxed_local()
413 }),
414 )
415 }
416
417 pub fn add_entity_request_handler<M, E, H, F>(&self, handler: H)
418 where
419 M: EnvelopedMessage + RequestMessage + EntityMessage,
420 E: 'static,
421 H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
422 F: 'static + Future<Output = Result<M::Response>>,
423 {
424 let message_type_id = TypeId::of::<M>();
425 let entity_type_id = TypeId::of::<E>();
426 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
427 (envelope as &dyn Any)
428 .downcast_ref::<TypedEnvelope<M>>()
429 .unwrap()
430 .payload
431 .remote_entity_id()
432 };
433 self.0
434 .client
435 .message_handler_set()
436 .lock()
437 .add_entity_message_handler(
438 message_type_id,
439 entity_type_id,
440 entity_id_extractor,
441 Arc::new(move |entity, envelope, client, cx| {
442 let entity = entity.downcast::<E>().unwrap();
443 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
444 let request_id = envelope.message_id();
445 handler(entity, *envelope, cx)
446 .then(move |result| async move {
447 match result {
448 Ok(response) => {
449 client.send_response(request_id, response)?;
450 Ok(())
451 }
452 Err(error) => {
453 client.send_response(request_id, error.to_proto())?;
454 Err(error)
455 }
456 }
457 })
458 .boxed_local()
459 }),
460 );
461 }
462
463 pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
464 where
465 M: EnvelopedMessage + EntityMessage,
466 E: 'static,
467 H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
468 F: 'static + Future<Output = Result<()>>,
469 {
470 let message_type_id = TypeId::of::<M>();
471 let entity_type_id = TypeId::of::<E>();
472 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
473 (envelope as &dyn Any)
474 .downcast_ref::<TypedEnvelope<M>>()
475 .unwrap()
476 .payload
477 .remote_entity_id()
478 };
479 self.0
480 .client
481 .message_handler_set()
482 .lock()
483 .add_entity_message_handler(
484 message_type_id,
485 entity_type_id,
486 entity_id_extractor,
487 Arc::new(move |entity, envelope, _, cx| {
488 let entity = entity.downcast::<E>().unwrap();
489 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
490 handler(entity, *envelope, cx).boxed_local()
491 }),
492 );
493 }
494
495 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
496 let id = (TypeId::of::<E>(), remote_id);
497
498 let mut message_handlers = self.0.client.message_handler_set().lock();
499 if message_handlers
500 .entities_by_type_and_remote_id
501 .contains_key(&id)
502 {
503 panic!("already subscribed to entity");
504 }
505
506 message_handlers.entities_by_type_and_remote_id.insert(
507 id,
508 EntityMessageSubscriber::Entity {
509 handle: entity.downgrade().into(),
510 },
511 );
512 }
513}
514
515fn to_any_envelope<T: EnvelopedMessage>(
516 envelope: &TypedEnvelope<proto::LspQueryResponse>,
517 response: T,
518) -> Box<dyn AnyTypedEnvelope> {
519 Box::new(proto::TypedEnvelope {
520 sender_id: envelope.sender_id,
521 original_sender_id: envelope.original_sender_id,
522 message_id: envelope.message_id,
523 received_at: envelope.received_at,
524 payload: response,
525 }) as Box<_>
526}