1use crate::{
2 error::ErrorExt as _, AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage,
3 RequestMessage, TypedEnvelope,
4};
5use anyhow::anyhow;
6use collections::HashMap;
7use futures::{
8 future::{BoxFuture, LocalBoxFuture},
9 Future, FutureExt as _,
10};
11use gpui::{AnyModel, AnyWeakModel, AsyncAppContext, Model};
12pub use prost::Message;
13use std::{any::TypeId, sync::Arc};
14
/// Type-erased handle to a [`ProtoClient`] implementation.
///
/// Cloning only clones the inner `Arc`, so every clone refers to the same underlying client.
#[derive(Clone)]
pub struct AnyProtoClient(Arc<dyn ProtoClient>);
17
/// Transport used by [`AnyProtoClient`] to exchange protobuf [`Envelope`]s with a peer.
///
/// The trait requires `Send + Sync`; instances are shared through `Arc<dyn ProtoClient>`.
pub trait ProtoClient: Send + Sync {
    /// Sends `envelope` as a request and resolves to the peer's response envelope.
    ///
    /// `request_type` is the static name of the request message (callers in this file pass
    /// `RequestMessage::NAME`). NOTE(review): presumably used for logging/diagnostics — confirm
    /// against implementors.
    fn request(
        &self,
        envelope: Envelope,
        request_type: &'static str,
    ) -> BoxFuture<'static, anyhow::Result<Envelope>>;

    /// Sends `envelope` and returns synchronously, without a response future.
    ///
    /// `message_type` is the static name of the message being sent.
    fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;

    /// Sends `envelope` as the answer to a previously received request.
    ///
    /// `message_type` is the static name of the response message.
    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;

    /// Returns the registry of handlers used to dispatch incoming messages.
    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
}
31
/// Registry of handlers for incoming protobuf messages, keyed by the message's `TypeId`.
///
/// `handle_message` routes a message type in one of two ways:
/// - to the single model stored for it in `models_by_message_type`, or
/// - to an entity found by remote id, via `entity_id_extractors`,
///   `entity_types_by_message_type` and `entities_by_type_and_remote_id`.
#[derive(Default)]
pub struct ProtoMessageHandlerSet {
    // Message type -> `TypeId` of the model type targeted by entity-routed messages.
    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
    // (model type, remote entity id) -> subscriber for that entity.
    // NOTE(review): entries are added outside this file — confirm where subscriptions happen.
    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
    // Message type -> function that reads the remote entity id out of an envelope.
    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    // Message type -> model that receives every message of that type (non-entity routing).
    pub models_by_message_type: HashMap<TypeId, AnyWeakModel>,
    // Message type -> handler callback; registering a second one for the same type panics.
    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
}
40
41pub type ProtoMessageHandler = Arc<
42 dyn Send
43 + Sync
44 + Fn(
45 AnyModel,
46 Box<dyn AnyTypedEnvelope>,
47 AnyProtoClient,
48 AsyncAppContext,
49 ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
50>;
51
52impl ProtoMessageHandlerSet {
53 pub fn clear(&mut self) {
54 self.message_handlers.clear();
55 self.models_by_message_type.clear();
56 self.entities_by_type_and_remote_id.clear();
57 self.entity_id_extractors.clear();
58 }
59
60 fn add_message_handler(
61 &mut self,
62 message_type_id: TypeId,
63 model: gpui::AnyWeakModel,
64 handler: ProtoMessageHandler,
65 ) {
66 self.models_by_message_type.insert(message_type_id, model);
67 let prev_handler = self.message_handlers.insert(message_type_id, handler);
68 if prev_handler.is_some() {
69 panic!("registered handler for the same message twice");
70 }
71 }
72
73 fn add_entity_message_handler(
74 &mut self,
75 message_type_id: TypeId,
76 model_type_id: TypeId,
77 entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
78 handler: ProtoMessageHandler,
79 ) {
80 self.entity_id_extractors
81 .entry(message_type_id)
82 .or_insert(entity_id_extractor);
83 self.entity_types_by_message_type
84 .insert(message_type_id, model_type_id);
85 let prev_handler = self.message_handlers.insert(message_type_id, handler);
86 if prev_handler.is_some() {
87 panic!("registered handler for the same message twice");
88 }
89 }
90
91 pub fn handle_message(
92 this: &parking_lot::Mutex<Self>,
93 message: Box<dyn AnyTypedEnvelope>,
94 client: AnyProtoClient,
95 cx: AsyncAppContext,
96 ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
97 let payload_type_id = message.payload_type_id();
98 let mut this = this.lock();
99 let handler = this.message_handlers.get(&payload_type_id)?.clone();
100 let entity = if let Some(entity) = this.models_by_message_type.get(&payload_type_id) {
101 entity.upgrade()?
102 } else {
103 let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
104 let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
105 let entity_id = (extract_entity_id)(message.as_ref());
106
107 match this
108 .entities_by_type_and_remote_id
109 .get_mut(&(entity_type_id, entity_id))?
110 {
111 EntityMessageSubscriber::Pending(pending) => {
112 pending.push(message);
113 return None;
114 }
115 EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
116 }
117 };
118 drop(this);
119 Some(handler(entity, message, client, cx))
120 }
121}
122
/// Recipient of messages addressed to one remote entity (see `ProtoMessageHandlerSet`).
pub enum EntityMessageSubscriber {
    /// Messages are handed to this weakly held model; if it has been dropped they are ignored.
    Entity { handle: AnyWeakModel },
    /// No model is attached yet; incoming messages are appended here.
    /// NOTE(review): presumably replayed once the entity is attached — confirm with the caller
    /// that performs the subscription.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
127
128impl<T> From<Arc<T>> for AnyProtoClient
129where
130 T: ProtoClient + 'static,
131{
132 fn from(client: Arc<T>) -> Self {
133 Self(client)
134 }
135}
136
137impl AnyProtoClient {
138 pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
139 Self(client)
140 }
141
142 pub fn request<T: RequestMessage>(
143 &self,
144 request: T,
145 ) -> impl Future<Output = anyhow::Result<T::Response>> {
146 let envelope = request.into_envelope(0, None, None);
147 let response = self.0.request(envelope, T::NAME);
148 async move {
149 T::Response::from_envelope(response.await?)
150 .ok_or_else(|| anyhow!("received response of the wrong type"))
151 }
152 }
153
154 pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
155 let envelope = request.into_envelope(0, None, None);
156 self.0.send(envelope, T::NAME)
157 }
158
159 pub fn send_response<T: EnvelopedMessage>(
160 &self,
161 request_id: u32,
162 request: T,
163 ) -> anyhow::Result<()> {
164 let envelope = request.into_envelope(0, Some(request_id), None);
165 self.0.send(envelope, T::NAME)
166 }
167
168 pub fn add_request_handler<M, E, H, F>(&self, model: gpui::WeakModel<E>, handler: H)
169 where
170 M: RequestMessage,
171 E: 'static,
172 H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
173 F: 'static + Future<Output = anyhow::Result<M::Response>>,
174 {
175 self.0.message_handler_set().lock().add_message_handler(
176 TypeId::of::<M>(),
177 model.into(),
178 Arc::new(move |model, envelope, client, cx| {
179 let model = model.downcast::<E>().unwrap();
180 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
181 let request_id = envelope.message_id();
182 handler(model, *envelope, cx)
183 .then(move |result| async move {
184 match result {
185 Ok(response) => {
186 client.send_response(request_id, response)?;
187 Ok(())
188 }
189 Err(error) => {
190 client.send_response(request_id, error.to_proto())?;
191 Err(error)
192 }
193 }
194 })
195 .boxed_local()
196 }),
197 )
198 }
199
200 pub fn add_model_request_handler<M, E, H, F>(&self, handler: H)
201 where
202 M: EnvelopedMessage + RequestMessage + EntityMessage,
203 E: 'static,
204 H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
205 F: 'static + Future<Output = anyhow::Result<M::Response>>,
206 {
207 let message_type_id = TypeId::of::<M>();
208 let model_type_id = TypeId::of::<E>();
209 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
210 envelope
211 .as_any()
212 .downcast_ref::<TypedEnvelope<M>>()
213 .unwrap()
214 .payload
215 .remote_entity_id()
216 };
217 self.0
218 .message_handler_set()
219 .lock()
220 .add_entity_message_handler(
221 message_type_id,
222 model_type_id,
223 entity_id_extractor,
224 Arc::new(move |model, envelope, client, cx| {
225 let model = model.downcast::<E>().unwrap();
226 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
227 let request_id = envelope.message_id();
228 handler(model, *envelope, cx)
229 .then(move |result| async move {
230 match result {
231 Ok(response) => {
232 client.send_response(request_id, response)?;
233 Ok(())
234 }
235 Err(error) => {
236 client.send_response(request_id, error.to_proto())?;
237 Err(error)
238 }
239 }
240 })
241 .boxed_local()
242 }),
243 );
244 }
245
246 pub fn add_model_message_handler<M, E, H, F>(&self, handler: H)
247 where
248 M: EnvelopedMessage + EntityMessage,
249 E: 'static,
250 H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
251 F: 'static + Future<Output = anyhow::Result<()>>,
252 {
253 let message_type_id = TypeId::of::<M>();
254 let model_type_id = TypeId::of::<E>();
255 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
256 envelope
257 .as_any()
258 .downcast_ref::<TypedEnvelope<M>>()
259 .unwrap()
260 .payload
261 .remote_entity_id()
262 };
263 self.0
264 .message_handler_set()
265 .lock()
266 .add_entity_message_handler(
267 message_type_id,
268 model_type_id,
269 entity_id_extractor,
270 Arc::new(move |model, envelope, _, cx| {
271 let model = model.downcast::<E>().unwrap();
272 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
273 handler(model, *envelope, cx).boxed_local()
274 }),
275 );
276 }
277}