1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
8 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
9 },
10 executor::Executor,
11 AppState, Error, Result,
12};
13use anyhow::anyhow;
14use async_tungstenite::tungstenite::{
15 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
16};
17use axum::{
18 body::Body,
19 extract::{
20 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
21 ConnectInfo, WebSocketUpgrade,
22 },
23 headers::{Header, HeaderName},
24 http::StatusCode,
25 middleware,
26 response::IntoResponse,
27 routing::get,
28 Extension, Router, TypedHeader,
29};
30use collections::{HashMap, HashSet};
31pub use connection_pool::{ConnectionPool, ZedVersion};
32use futures::{
33 channel::oneshot,
34 future::{self, BoxFuture},
35 stream::FuturesUnordered,
36 FutureExt, SinkExt, StreamExt, TryStreamExt,
37};
38use lazy_static::lazy_static;
39use prometheus::{register_int_gauge, IntGauge};
40use rpc::{
41 proto::{
42 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
43 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
44 },
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46};
47use serde::{Serialize, Serializer};
48use std::{
49 any::TypeId,
50 fmt,
51 future::Future,
52 marker::PhantomData,
53 mem,
54 net::SocketAddr,
55 ops::{Deref, DerefMut},
56 rc::Rc,
57 sync::{
58 atomic::{AtomicBool, Ordering::SeqCst},
59 Arc,
60 },
61 time::{Duration, Instant},
62};
63use time::OffsetDateTime;
64use tokio::sync::{watch, Semaphore};
65use tower::ServiceBuilder;
66use tracing::{field, info_span, instrument, Instrument};
67use util::SemanticVersion;
68
/// How long `connection_lost` waits for a dropped client to come back before it
/// releases the session's rooms, channel buffers, and calls.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay after `Server::start` before the background task cleans up stale
/// rooms and channel buffers reported by the database.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the three constants below are not referenced in this part of the
// file; the descriptions are inferred from their names — confirm at the use sites.
/// Page size for channel message history (presumably).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel message length (unit presumably bytes — TODO confirm).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification listings (presumably).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
75
// Prometheus gauges exported on `/metrics`; `handle_metrics` refreshes both on
// every scrape. Registration panics (via `unwrap`) if a metric with the same
// name is already registered in the default registry.
lazy_static! {
    /// Number of non-admin connections in the connection pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Shared-project count as returned by `project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
85
/// Type-erased handler stored in `Server::handlers`.
///
/// It receives an envelope of any message type together with the sender's
/// `Session`, and returns a boxed future that performs the handling. The
/// concrete handlers are wrapped into this shape by `Server::add_handler`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
88
/// Handle a request handler uses to send its reply to request `R`.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`. The dispatcher in `Server::add_request_handler`
    /// reads it to detect handlers that returned `Ok` without replying.
    responded: Arc<AtomicBool>,
}
94
95impl<R: RequestMessage> Response<R> {
96 fn send(self, payload: R::Response) -> Result<()> {
97 self.responded.store(true, SeqCst);
98 self.peer.respond(self.receipt, payload)?;
99 Ok(())
100 }
101}
102
/// Per-connection context handed to every message handler.
///
/// Cloning is cheap: all shared state sits behind `Arc`s, so clones refer to the
/// same database mutex and connection pool.
#[derive(Clone)]
struct Session {
    /// The authenticated user that owns this connection.
    user_id: UserId,
    /// The peer connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    /// `handle_connection` creates a fresh mutex per connection, so this
    /// serializes database access per connection, not globally.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages to this and other connections.
    peer: Arc<Peer>,
    /// The server-wide connection pool; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, if one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    _executor: Executor,
}
113
114impl Session {
115 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
116 #[cfg(test)]
117 tokio::task::yield_now().await;
118 let guard = self.db.lock().await;
119 #[cfg(test)]
120 tokio::task::yield_now().await;
121 guard
122 }
123
124 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
125 #[cfg(test)]
126 tokio::task::yield_now().await;
127 let guard = self.connection_pool.lock();
128 ConnectionPoolGuard {
129 guard,
130 _not_send: PhantomData,
131 }
132 }
133}
134
135impl fmt::Debug for Session {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 f.debug_struct("Session")
138 .field("user_id", &self.user_id)
139 .field("connection_id", &self.connection_id)
140 .finish()
141 }
142}
143
144struct DbHandle(Arc<Database>);
145
146impl Deref for DbHandle {
147 type Target = Database;
148
149 fn deref(&self) -> &Self::Target {
150 self.0.as_ref()
151 }
152}
153
/// The collab RPC server.
///
/// It owns the peer, the registry of live connections, and the table of
/// message handlers that `handle_connection` dispatches to.
pub struct Server {
    /// This server's id. It sits behind a mutex so the test-only `reset` can
    /// replace it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer that owns the client connections.
    peer: Arc<Peer>,
    /// Users/connections currently online; shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, LiveKit client, …).
    app_state: Arc<AppState>,
    /// Executor used for timers and detached tasks.
    executor: Executor,
    /// One handler per message type, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; each `handle_connection` loop subscribes to it.
    teardown: watch::Sender<()>,
}
163
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`. A future that holds
/// the guard across an `.await` is therefore not `Send`, and the handler bounds
/// in `Server::add_handler` reject such futures. This keeps the blocking
/// `parking_lot` lock from being held while suspended.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
168
/// Serializable view of the server's peer and connection pool, produced by
/// `Server::snapshot`.
///
/// The connection pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool behind the guard rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
175
176pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
177where
178 S: Serializer,
179 T: Deref<Target = U>,
180 U: Serialize,
181{
182 Serialize::serialize(value.deref(), serializer)
183}
184
impl Server {
    /// Creates a server with the given id and registers every RPC handler it serves.
    ///
    /// # Panics
    /// Panics (via `add_handler`) if two handlers are registered for the same
    /// message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connection loops obtain receivers via `subscribe()`.
            teardown: watch::channel(()).0,
        };

        // `add_request_handler` handlers must reply through their `Response`;
        // `add_message_handler` handlers send no reply. The generic
        // `forward_*_project_request` / `broadcast_project_message_from_host`
        // helpers are defined elsewhere in this file.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }

    /// Spawns a detached background task that cleans up resources the database
    /// reports as stale for this server.
    ///
    /// After `CLEANUP_TIMEOUT` the task:
    /// 1. fetches stale room and channel ids (`stale_server_resource_ids`);
    /// 2. clears stale channel-buffer collaborators and sends the refreshed
    ///    collaborator list to the remaining connections;
    /// 3. clears stale participants from each room, broadcasts the room (and its
    ///    channel) update, sends `CallCanceled` to users whose calls were
    ///    canceled, refreshes affected users' contacts, and deletes the LiveKit
    ///    room when no participants remain;
    /// 4. deletes stale server records.
    ///
    /// Every failure inside the task is logged via `trace_err` and skipped.
    /// The method itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Work collected from the database result, acted on below.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // Only an empty room's LiveKit counterpart is deleted.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            // Scoped so the pool lock is released before the next `.await`.
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user if either lookup failed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                // Push the refreshed contact entry to every connection
                                // of each accepted contact.
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Shuts the server down.
    ///
    /// Tears down the peer, resets the connection pool, and signals every
    /// `handle_connection` loop (through the `teardown` watch channel) to return.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // A send error only means nobody is subscribed; nothing to do.
        let _ = self.teardown.send(());
    }

    /// Test-only: tears the server down and reinitializes it under a new id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
    }

    /// Test-only: returns the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers `handler` for messages of type `M`.
    ///
    /// The stored closure downcasts the type-erased envelope back to
    /// `TypedEnvelope<M>`. It then runs the handler inside a "handle message"
    /// span and logs the outcome with the elapsed time in milliseconds.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Cannot fail as long as callers look handlers up by the envelope's
                // payload type id, as `handle_connection` does.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let span = info_span!(
                    "handle message",
                    payload_type = envelope.payload_type_name()
                );
                span.in_scope(|| {
                    tracing::info!(
                        payload_type = envelope.payload_type_name(),
                        "message received"
                    );
                });
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, ?duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
                    }
                }
                .instrument(span)
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler for a one-way message `M`; it receives only the
    /// payload and sends no reply.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a handler for request `M`, which must reply through the
    /// `Response` it is given.
    ///
    /// - If the handler returns `Ok` without calling `Response::send`, the
    ///   dispatcher logs "handler did not send a response".
    /// - If it returns `Err`, the error is converted to a proto error, sent to
    ///   the requester, and then propagated so it is logged.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // NOTE(review): `responded` is not checked here, so a handler
                        // that fails after already replying also triggers
                        // `respond_with_error` for the same receipt — confirm this is
                        // harmless for clients.
                        // `Error::Internal` wraps an error converted with `to_proto`;
                        // other variants become `ErrorCode::Internal` with their
                        // display text.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Setup phase:
    /// - registers the connection with the peer and sends `Hello`;
    /// - sends the assigned id to `send_connection_id`, if given;
    /// - on the user's first connection ever, sends `ShowContacts`;
    /// - adds the connection to the pool and sends the initial contacts,
    ///   channels, and any pending incoming call.
    ///
    /// Message loop: incoming messages are dispatched to the registered handlers
    /// (at most 256 in flight). It ends when the connection closes, I/O fails,
    /// or the server is torn down.
    ///
    /// Shutdown: on close or I/O failure, `connection_lost` runs. On teardown
    /// the future returns immediately without it.
    ///
    /// # Errors
    /// Returns early if any setup step (an initial send or database query) fails.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        zed_version: ZedVersion,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // Subscribe before the future starts so a teardown issued in between is not missed.
        let mut teardown = self.teardown.subscribe();
        async move {
            // `handle_io` must be polled to drive the socket; `incoming_rx`
            // yields messages received on it.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // NOTE(review): an early `?` return above or below leaves the connection
            // registered with the peer (and, after this block, in the pool) without
            // running `connection_lost` — confirm whether that can leak entries.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin, zed_version);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                // A fresh mutex per connection: database access is serialized per session.
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers (foreground and background) at 256.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* reading, so reading pauses while
                // 256 handlers are still running; each handler releases its
                // permit when it finishes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then I/O completion, then progress on
                // in-flight foreground handlers, and only then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }

    /// Notifies the inviter that `invitee_id` redeemed their invite code.
    ///
    /// Every connection of the inviter receives the invitee as a contact
    /// (reported as not busy) plus refreshed invite-link info. Does nothing if
    /// the inviter does not exist or has no invite code.
    ///
    /// # Errors
    /// Returns on the first database or send failure.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Sends refreshed invite-link info (URL and remaining count) to every
    /// connection of `user_id`.
    ///
    /// Does nothing if the user does not exist or has no invite code.
    ///
    /// # Errors
    /// Returns on the first database or send failure.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Returns a serializable snapshot of the peer and connection pool.
    ///
    /// The pool lock is held until the returned snapshot is dropped.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
758
759impl<'a> Deref for ConnectionPoolGuard<'a> {
760 type Target = ConnectionPool;
761
762 fn deref(&self) -> &Self::Target {
763 &*self.guard
764 }
765}
766
767impl<'a> DerefMut for ConnectionPoolGuard<'a> {
768 fn deref_mut(&mut self) -> &mut Self::Target {
769 &mut *self.guard
770 }
771}
772
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, validates the pool's invariants each time a guard is
    /// released. Release builds compile this down to nothing, since the call is
    /// `#[cfg(test)]`-gated rather than a runtime `cfg!` check.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
779
780fn broadcast<F>(
781 sender_id: Option<ConnectionId>,
782 receiver_ids: impl IntoIterator<Item = ConnectionId>,
783 mut f: F,
784) where
785 F: FnMut(ConnectionId) -> anyhow::Result<()>,
786{
787 for receiver_id in receiver_ids {
788 if Some(receiver_id) != sender_id {
789 if let Err(error) = f(receiver_id) {
790 tracing::error!("failed to send to {:?} {}", receiver_id, error);
791 }
792 }
793 }
794}
795
// Names of the version headers read from `/rpc` upgrade requests by the
// `ProtocolVersion` and `AppVersionHeader` typed headers.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
    static ref ZED_APP_VERSION: HeaderName = HeaderName::from_static("x-zed-app-version");
}
800
/// Typed `x-zed-protocol-version` header; compared against `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
802
803impl Header for ProtocolVersion {
804 fn name() -> &'static HeaderName {
805 &ZED_PROTOCOL_VERSION
806 }
807
808 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
809 where
810 Self: Sized,
811 I: Iterator<Item = &'i axum::http::HeaderValue>,
812 {
813 let version = values
814 .next()
815 .ok_or_else(axum::headers::Error::invalid)?
816 .to_str()
817 .map_err(|_| axum::headers::Error::invalid())?
818 .parse()
819 .map_err(|_| axum::headers::Error::invalid())?;
820 Ok(Self(version))
821 }
822
823 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
824 values.extend([self.0.to_string().parse().unwrap()]);
825 }
826}
827
/// Typed `x-zed-app-version` header carrying the client's app version.
pub struct AppVersionHeader(SemanticVersion);
829impl Header for AppVersionHeader {
830 fn name() -> &'static HeaderName {
831 &ZED_APP_VERSION
832 }
833
834 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
835 where
836 Self: Sized,
837 I: Iterator<Item = &'i axum::http::HeaderValue>,
838 {
839 let version = values
840 .next()
841 .ok_or_else(axum::headers::Error::invalid)?
842 .to_str()
843 .map_err(|_| axum::headers::Error::invalid())?
844 .parse()
845 .map_err(|_| axum::headers::Error::invalid())?;
846 Ok(Self(version))
847 }
848
849 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
850 values.extend([self.0.to_string().parse().unwrap()]);
851 }
852}
853
854pub fn routes(server: Arc<Server>) -> Router<Body> {
855 Router::new()
856 .route("/rpc", get(handle_websocket_request))
857 .layer(
858 ServiceBuilder::new()
859 .layer(Extension(server.app_state.clone()))
860 .layer(middleware::from_fn(auth::validate_header)),
861 )
862 .route("/metrics", get(handle_metrics))
863 .layer(Extension(server))
864}
865
866pub async fn handle_websocket_request(
867 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
868 app_version_header: Option<TypedHeader<AppVersionHeader>>,
869 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
870 Extension(server): Extension<Arc<Server>>,
871 Extension(user): Extension<User>,
872 Extension(impersonator): Extension<Impersonator>,
873 ws: WebSocketUpgrade,
874) -> axum::response::Response {
875 if protocol_version != rpc::PROTOCOL_VERSION {
876 return (
877 StatusCode::UPGRADE_REQUIRED,
878 "client must be upgraded".to_string(),
879 )
880 .into_response();
881 }
882
883 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
884 return (
885 StatusCode::UPGRADE_REQUIRED,
886 "no version header found".to_string(),
887 )
888 .into_response();
889 };
890
891 if !version.is_supported() {
892 return (
893 StatusCode::UPGRADE_REQUIRED,
894 "client must be upgraded".to_string(),
895 )
896 .into_response();
897 }
898
899 let socket_address = socket_address.to_string();
900 ws.on_upgrade(move |socket| {
901 use util::ResultExt;
902 let socket = socket
903 .map_ok(to_tungstenite_message)
904 .err_into()
905 .with(|message| async move { Ok(to_axum_message(message)) });
906 let connection = Connection::new(Box::pin(socket));
907 async move {
908 server
909 .handle_connection(
910 connection,
911 socket_address,
912 user,
913 version,
914 impersonator.0,
915 None,
916 Executor::Production,
917 )
918 .await
919 .log_err();
920 }
921 })
922}
923
924pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
925 let connections = server
926 .connection_pool
927 .lock()
928 .connections()
929 .filter(|connection| !connection.admin)
930 .count();
931
932 METRIC_CONNECTIONS.set(connections as _);
933
934 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
935 METRIC_SHARED_PROJECTS.set(shared_projects as _);
936
937 let encoder = prometheus::TextEncoder::new();
938 let metric_families = prometheus::gather();
939 let encoded_metrics = encoder
940 .encode_to_string(&metric_families)
941 .map_err(|err| anyhow!("{}", err))?;
942 Ok(encoded_metrics)
943}
944
945#[instrument(err, skip(executor))]
946async fn connection_lost(
947 session: Session,
948 mut teardown: watch::Receiver<()>,
949 executor: Executor,
950) -> Result<()> {
951 session.peer.disconnect(session.connection_id);
952 session
953 .connection_pool()
954 .await
955 .remove_connection(session.connection_id)?;
956
957 session
958 .db()
959 .await
960 .connection_lost(session.connection_id)
961 .await
962 .trace_err();
963
964 futures::select_biased! {
965 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
966 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
967 leave_room_for_session(&session).await.trace_err();
968 leave_channel_buffers_for_session(&session)
969 .await
970 .trace_err();
971
972 if !session
973 .connection_pool()
974 .await
975 .is_user_online(session.user_id)
976 {
977 let db = session.db().await;
978 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
979 room_updated(&room, &session.peer);
980 }
981 }
982
983 update_user_contacts(session.user_id, &session).await?;
984 }
985 _ = teardown.changed().fuse() => {}
986 }
987
988 Ok(())
989}
990
991/// Acknowledges a ping from a client, used to keep the connection alive.
992async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
993 response.send(proto::Ack {})?;
994 Ok(())
995}
996
997/// Creates a new room for calling (outside of channels)
998async fn create_room(
999 _request: proto::CreateRoom,
1000 response: Response<proto::CreateRoom>,
1001 session: Session,
1002) -> Result<()> {
1003 let live_kit_room = nanoid::nanoid!(30);
1004
1005 let live_kit_connection_info = {
1006 let live_kit_room = live_kit_room.clone();
1007 let live_kit = session.live_kit_client.as_ref();
1008
1009 util::async_maybe!({
1010 let live_kit = live_kit?;
1011
1012 let token = live_kit
1013 .room_token(&live_kit_room, &session.user_id.to_string())
1014 .trace_err()?;
1015
1016 Some(proto::LiveKitConnectionInfo {
1017 server_url: live_kit.url().into(),
1018 token,
1019 can_publish: true,
1020 })
1021 })
1022 }
1023 .await;
1024
1025 let room = session
1026 .db()
1027 .await
1028 .create_room(session.user_id, session.connection_id, &live_kit_room)
1029 .await?;
1030
1031 response.send(proto::CreateRoomResponse {
1032 room: Some(room.clone()),
1033 live_kit_connection_info,
1034 })?;
1035
1036 update_user_contacts(session.user_id, &session).await?;
1037 Ok(())
1038}
1039
1040/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1041async fn join_room(
1042 request: proto::JoinRoom,
1043 response: Response<proto::JoinRoom>,
1044 session: Session,
1045) -> Result<()> {
1046 let room_id = RoomId::from_proto(request.id);
1047
1048 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1049
1050 if let Some(channel_id) = channel_id {
1051 return join_channel_internal(channel_id, Box::new(response), session).await;
1052 }
1053
1054 let joined_room = {
1055 let room = session
1056 .db()
1057 .await
1058 .join_room(room_id, session.user_id, session.connection_id)
1059 .await?;
1060 room_updated(&room.room, &session.peer);
1061 room.into_inner()
1062 };
1063
1064 for connection_id in session
1065 .connection_pool()
1066 .await
1067 .user_connection_ids(session.user_id)
1068 {
1069 session
1070 .peer
1071 .send(
1072 connection_id,
1073 proto::CallCanceled {
1074 room_id: room_id.to_proto(),
1075 },
1076 )
1077 .trace_err();
1078 }
1079
1080 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1081 if let Some(token) = live_kit
1082 .room_token(
1083 &joined_room.room.live_kit_room,
1084 &session.user_id.to_string(),
1085 )
1086 .trace_err()
1087 {
1088 Some(proto::LiveKitConnectionInfo {
1089 server_url: live_kit.url().into(),
1090 token,
1091 can_publish: true,
1092 })
1093 } else {
1094 None
1095 }
1096 } else {
1097 None
1098 };
1099
1100 response.send(proto::JoinRoomResponse {
1101 room: Some(joined_room.room),
1102 channel_id: None,
1103 live_kit_connection_info,
1104 })?;
1105
1106 update_user_contacts(session.user_id, &session).await?;
1107 Ok(())
1108}
1109
/// Reconnects a client to the room it was in after a connection error.
///
/// The database `rejoin_room` call receives the request together with the
/// user id and the *new* connection id. The client then receives a
/// `RejoinRoomResponse` describing the room, its reshared projects and its
/// rejoined projects, and the updated room is broadcast.
///
/// Collaborators of every reshared and rejoined project get an
/// `UpdateProjectCollaborator` mapping the old peer id to the new one.
/// Reshared projects (presumably those this user hosts — confirm) have their
/// worktree list re-broadcast to collaborators. Rejoined projects (presumably
/// those this user joined as a guest) have their worktree entries, diagnostic
/// summaries, settings files and language server state streamed back to the
/// rejoining connection.
///
/// Finally, if the room belongs to a channel, the channel is updated, and the
/// user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the scope below from `rejoined_room.into_inner()`, so the
    // channel/contacts updates at the end run after that value is consumed.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: remap this user's peer id for each collaborator,
        // then re-broadcast the project's worktree list from this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: remap this user's peer id for each collaborator.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Stream each rejoined project's state back to the rejoining
        // connection. Worktrees are moved out with `mem::take` rather than
        // cloned; any send failure aborts the handler via `?`.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Much smaller chunks under test/test-support builds
                // (presumably so tests exercise update splitting).
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Send a `DiskBasedDiagnosticsUpdated` for each of the project's
            // language servers.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    // Rooms that belong to a channel also update the channel's members.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1301
1302/// leave room disconnects from the room.
1303async fn leave_room(
1304 _: proto::LeaveRoom,
1305 response: Response<proto::LeaveRoom>,
1306 session: Session,
1307) -> Result<()> {
1308 leave_room_for_session(&session).await?;
1309 response.send(proto::Ack {})?;
1310 Ok(())
1311}
1312
1313/// Updates the permissions of someone else in the room.
1314async fn set_room_participant_role(
1315 request: proto::SetRoomParticipantRole,
1316 response: Response<proto::SetRoomParticipantRole>,
1317 session: Session,
1318) -> Result<()> {
1319 let user_id = UserId::from_proto(request.user_id);
1320 let role = ChannelRole::from(request.role());
1321
1322 if role == ChannelRole::Talker {
1323 let pool = session.connection_pool().await;
1324
1325 for connection in pool.user_connections(user_id) {
1326 if !connection.zed_version.supports_talker_role() {
1327 Err(anyhow!(
1328 "This user is on zed {} which does not support unmute",
1329 connection.zed_version
1330 ))?;
1331 }
1332 }
1333 }
1334
1335 let (live_kit_room, can_publish) = {
1336 let room = session
1337 .db()
1338 .await
1339 .set_room_participant_role(
1340 session.user_id,
1341 RoomId::from_proto(request.room_id),
1342 user_id,
1343 role,
1344 )
1345 .await?;
1346
1347 let live_kit_room = room.live_kit_room.clone();
1348 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1349 room_updated(&room, &session.peer);
1350 (live_kit_room, can_publish)
1351 };
1352
1353 if let Some(live_kit) = session.live_kit_client.as_ref() {
1354 live_kit
1355 .update_participant(
1356 live_kit_room.clone(),
1357 request.user_id.to_string(),
1358 live_kit_server::proto::ParticipantPermission {
1359 can_subscribe: true,
1360 can_publish,
1361 can_publish_data: can_publish,
1362 hidden: false,
1363 recorder: false,
1364 },
1365 )
1366 .await
1367 .trace_err();
1368 }
1369
1370 response.send(proto::Ack {})?;
1371 Ok(())
1372}
1373
1374/// Call someone else into the current room
1375async fn call(
1376 request: proto::Call,
1377 response: Response<proto::Call>,
1378 session: Session,
1379) -> Result<()> {
1380 let room_id = RoomId::from_proto(request.room_id);
1381 let calling_user_id = session.user_id;
1382 let calling_connection_id = session.connection_id;
1383 let called_user_id = UserId::from_proto(request.called_user_id);
1384 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1385 if !session
1386 .db()
1387 .await
1388 .has_contact(calling_user_id, called_user_id)
1389 .await?
1390 {
1391 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1392 }
1393
1394 let incoming_call = {
1395 let (room, incoming_call) = &mut *session
1396 .db()
1397 .await
1398 .call(
1399 room_id,
1400 calling_user_id,
1401 calling_connection_id,
1402 called_user_id,
1403 initial_project_id,
1404 )
1405 .await?;
1406 room_updated(&room, &session.peer);
1407 mem::take(incoming_call)
1408 };
1409 update_user_contacts(called_user_id, &session).await?;
1410
1411 let mut calls = session
1412 .connection_pool()
1413 .await
1414 .user_connection_ids(called_user_id)
1415 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1416 .collect::<FuturesUnordered<_>>();
1417
1418 while let Some(call_response) = calls.next().await {
1419 match call_response.as_ref() {
1420 Ok(_) => {
1421 response.send(proto::Ack {})?;
1422 return Ok(());
1423 }
1424 Err(_) => {
1425 call_response.trace_err();
1426 }
1427 }
1428 }
1429
1430 {
1431 let room = session
1432 .db()
1433 .await
1434 .call_failed(room_id, called_user_id)
1435 .await?;
1436 room_updated(&room, &session.peer);
1437 }
1438 update_user_contacts(called_user_id, &session).await?;
1439
1440 Err(anyhow!("failed to ring user"))?
1441}
1442
1443/// Cancel an outgoing call.
1444async fn cancel_call(
1445 request: proto::CancelCall,
1446 response: Response<proto::CancelCall>,
1447 session: Session,
1448) -> Result<()> {
1449 let called_user_id = UserId::from_proto(request.called_user_id);
1450 let room_id = RoomId::from_proto(request.room_id);
1451 {
1452 let room = session
1453 .db()
1454 .await
1455 .cancel_call(room_id, session.connection_id, called_user_id)
1456 .await?;
1457 room_updated(&room, &session.peer);
1458 }
1459
1460 for connection_id in session
1461 .connection_pool()
1462 .await
1463 .user_connection_ids(called_user_id)
1464 {
1465 session
1466 .peer
1467 .send(
1468 connection_id,
1469 proto::CallCanceled {
1470 room_id: room_id.to_proto(),
1471 },
1472 )
1473 .trace_err();
1474 }
1475 response.send(proto::Ack {})?;
1476
1477 update_user_contacts(called_user_id, &session).await?;
1478 Ok(())
1479}
1480
1481/// Decline an incoming call.
1482async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1483 let room_id = RoomId::from_proto(message.room_id);
1484 {
1485 let room = session
1486 .db()
1487 .await
1488 .decline_call(Some(room_id), session.user_id)
1489 .await?
1490 .ok_or_else(|| anyhow!("failed to decline call"))?;
1491 room_updated(&room, &session.peer);
1492 }
1493
1494 for connection_id in session
1495 .connection_pool()
1496 .await
1497 .user_connection_ids(session.user_id)
1498 {
1499 session
1500 .peer
1501 .send(
1502 connection_id,
1503 proto::CallCanceled {
1504 room_id: room_id.to_proto(),
1505 },
1506 )
1507 .trace_err();
1508 }
1509 update_user_contacts(session.user_id, &session).await?;
1510 Ok(())
1511}
1512
1513/// Updates other participants in the room with your current location.
1514async fn update_participant_location(
1515 request: proto::UpdateParticipantLocation,
1516 response: Response<proto::UpdateParticipantLocation>,
1517 session: Session,
1518) -> Result<()> {
1519 let room_id = RoomId::from_proto(request.room_id);
1520 let location = request
1521 .location
1522 .ok_or_else(|| anyhow!("invalid location"))?;
1523
1524 let db = session.db().await;
1525 let room = db
1526 .update_room_participant_location(room_id, session.connection_id, location)
1527 .await?;
1528
1529 room_updated(&room, &session.peer);
1530 response.send(proto::Ack {})?;
1531 Ok(())
1532}
1533
1534/// Share a project into the room.
1535async fn share_project(
1536 request: proto::ShareProject,
1537 response: Response<proto::ShareProject>,
1538 session: Session,
1539) -> Result<()> {
1540 let (project_id, room) = &*session
1541 .db()
1542 .await
1543 .share_project(
1544 RoomId::from_proto(request.room_id),
1545 session.connection_id,
1546 &request.worktrees,
1547 )
1548 .await?;
1549 response.send(proto::ShareProjectResponse {
1550 project_id: project_id.to_proto(),
1551 })?;
1552 room_updated(&room, &session.peer);
1553
1554 Ok(())
1555}
1556
1557/// Unshare a project from the room.
1558async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1559 let project_id = ProjectId::from_proto(message.project_id);
1560
1561 let (room, guest_connection_ids) = &*session
1562 .db()
1563 .await
1564 .unshare_project(project_id, session.connection_id)
1565 .await?;
1566
1567 broadcast(
1568 Some(session.connection_id),
1569 guest_connection_ids.iter().copied(),
1570 |conn_id| session.peer.send(conn_id, message.clone()),
1571 );
1572 room_updated(&room, &session.peer);
1573
1574 Ok(())
1575}
1576
1577/// Join someone elses shared project.
1578async fn join_project(
1579 request: proto::JoinProject,
1580 response: Response<proto::JoinProject>,
1581 session: Session,
1582) -> Result<()> {
1583 let project_id = ProjectId::from_proto(request.project_id);
1584 let guest_user_id = session.user_id;
1585
1586 tracing::info!(%project_id, "join project");
1587
1588 let (project, replica_id) = &mut *session
1589 .db()
1590 .await
1591 .join_project(project_id, session.connection_id)
1592 .await?;
1593
1594 let collaborators = project
1595 .collaborators
1596 .iter()
1597 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1598 .map(|collaborator| collaborator.to_proto())
1599 .collect::<Vec<_>>();
1600
1601 let worktrees = project
1602 .worktrees
1603 .iter()
1604 .map(|(id, worktree)| proto::WorktreeMetadata {
1605 id: *id,
1606 root_name: worktree.root_name.clone(),
1607 visible: worktree.visible,
1608 abs_path: worktree.abs_path.clone(),
1609 })
1610 .collect::<Vec<_>>();
1611
1612 for collaborator in &collaborators {
1613 session
1614 .peer
1615 .send(
1616 collaborator.peer_id.unwrap().into(),
1617 proto::AddProjectCollaborator {
1618 project_id: project_id.to_proto(),
1619 collaborator: Some(proto::Collaborator {
1620 peer_id: Some(session.connection_id.into()),
1621 replica_id: replica_id.0 as u32,
1622 user_id: guest_user_id.to_proto(),
1623 }),
1624 },
1625 )
1626 .trace_err();
1627 }
1628
1629 // First, we send the metadata associated with each worktree.
1630 response.send(proto::JoinProjectResponse {
1631 worktrees: worktrees.clone(),
1632 replica_id: replica_id.0 as u32,
1633 collaborators: collaborators.clone(),
1634 language_servers: project.language_servers.clone(),
1635 })?;
1636
1637 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1638 #[cfg(any(test, feature = "test-support"))]
1639 const MAX_CHUNK_SIZE: usize = 2;
1640 #[cfg(not(any(test, feature = "test-support")))]
1641 const MAX_CHUNK_SIZE: usize = 256;
1642
1643 // Stream this worktree's entries.
1644 let message = proto::UpdateWorktree {
1645 project_id: project_id.to_proto(),
1646 worktree_id,
1647 abs_path: worktree.abs_path.clone(),
1648 root_name: worktree.root_name,
1649 updated_entries: worktree.entries,
1650 removed_entries: Default::default(),
1651 scan_id: worktree.scan_id,
1652 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1653 updated_repositories: worktree.repository_entries.into_values().collect(),
1654 removed_repositories: Default::default(),
1655 };
1656 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1657 session.peer.send(session.connection_id, update.clone())?;
1658 }
1659
1660 // Stream this worktree's diagnostics.
1661 for summary in worktree.diagnostic_summaries {
1662 session.peer.send(
1663 session.connection_id,
1664 proto::UpdateDiagnosticSummary {
1665 project_id: project_id.to_proto(),
1666 worktree_id: worktree.id,
1667 summary: Some(summary),
1668 },
1669 )?;
1670 }
1671
1672 for settings_file in worktree.settings_files {
1673 session.peer.send(
1674 session.connection_id,
1675 proto::UpdateWorktreeSettings {
1676 project_id: project_id.to_proto(),
1677 worktree_id: worktree.id,
1678 path: settings_file.path,
1679 content: Some(settings_file.content),
1680 },
1681 )?;
1682 }
1683 }
1684
1685 for language_server in &project.language_servers {
1686 session.peer.send(
1687 session.connection_id,
1688 proto::UpdateLanguageServer {
1689 project_id: project_id.to_proto(),
1690 language_server_id: language_server.id,
1691 variant: Some(
1692 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1693 proto::LspDiskBasedDiagnosticsUpdated {},
1694 ),
1695 ),
1696 },
1697 )?;
1698 }
1699
1700 Ok(())
1701}
1702
1703/// Leave someone elses shared project.
1704async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1705 let sender_id = session.connection_id;
1706 let project_id = ProjectId::from_proto(request.project_id);
1707
1708 let (room, project) = &*session
1709 .db()
1710 .await
1711 .leave_project(project_id, sender_id)
1712 .await?;
1713 tracing::info!(
1714 %project_id,
1715 host_user_id = %project.host_user_id,
1716 host_connection_id = ?project.host_connection_id,
1717 "leave project"
1718 );
1719
1720 project_left(&project, &session);
1721 room_updated(&room, &session.peer);
1722
1723 Ok(())
1724}
1725
1726/// Updates other participants with changes to the project
1727async fn update_project(
1728 request: proto::UpdateProject,
1729 response: Response<proto::UpdateProject>,
1730 session: Session,
1731) -> Result<()> {
1732 let project_id = ProjectId::from_proto(request.project_id);
1733 let (room, guest_connection_ids) = &*session
1734 .db()
1735 .await
1736 .update_project(project_id, session.connection_id, &request.worktrees)
1737 .await?;
1738 broadcast(
1739 Some(session.connection_id),
1740 guest_connection_ids.iter().copied(),
1741 |connection_id| {
1742 session
1743 .peer
1744 .forward_send(session.connection_id, connection_id, request.clone())
1745 },
1746 );
1747 room_updated(&room, &session.peer);
1748 response.send(proto::Ack {})?;
1749
1750 Ok(())
1751}
1752
1753/// Updates other participants with changes to the worktree
1754async fn update_worktree(
1755 request: proto::UpdateWorktree,
1756 response: Response<proto::UpdateWorktree>,
1757 session: Session,
1758) -> Result<()> {
1759 let guest_connection_ids = session
1760 .db()
1761 .await
1762 .update_worktree(&request, session.connection_id)
1763 .await?;
1764
1765 broadcast(
1766 Some(session.connection_id),
1767 guest_connection_ids.iter().copied(),
1768 |connection_id| {
1769 session
1770 .peer
1771 .forward_send(session.connection_id, connection_id, request.clone())
1772 },
1773 );
1774 response.send(proto::Ack {})?;
1775 Ok(())
1776}
1777
1778/// Updates other participants with changes to the diagnostics
1779async fn update_diagnostic_summary(
1780 message: proto::UpdateDiagnosticSummary,
1781 session: Session,
1782) -> Result<()> {
1783 let guest_connection_ids = session
1784 .db()
1785 .await
1786 .update_diagnostic_summary(&message, session.connection_id)
1787 .await?;
1788
1789 broadcast(
1790 Some(session.connection_id),
1791 guest_connection_ids.iter().copied(),
1792 |connection_id| {
1793 session
1794 .peer
1795 .forward_send(session.connection_id, connection_id, message.clone())
1796 },
1797 );
1798
1799 Ok(())
1800}
1801
1802/// Updates other participants with changes to the worktree settings
1803async fn update_worktree_settings(
1804 message: proto::UpdateWorktreeSettings,
1805 session: Session,
1806) -> Result<()> {
1807 let guest_connection_ids = session
1808 .db()
1809 .await
1810 .update_worktree_settings(&message, session.connection_id)
1811 .await?;
1812
1813 broadcast(
1814 Some(session.connection_id),
1815 guest_connection_ids.iter().copied(),
1816 |connection_id| {
1817 session
1818 .peer
1819 .forward_send(session.connection_id, connection_id, message.clone())
1820 },
1821 );
1822
1823 Ok(())
1824}
1825
1826/// Notify other participants that a language server has started.
1827async fn start_language_server(
1828 request: proto::StartLanguageServer,
1829 session: Session,
1830) -> Result<()> {
1831 let guest_connection_ids = session
1832 .db()
1833 .await
1834 .start_language_server(&request, session.connection_id)
1835 .await?;
1836
1837 broadcast(
1838 Some(session.connection_id),
1839 guest_connection_ids.iter().copied(),
1840 |connection_id| {
1841 session
1842 .peer
1843 .forward_send(session.connection_id, connection_id, request.clone())
1844 },
1845 );
1846 Ok(())
1847}
1848
1849/// Notify other participants that a language server has changed.
1850async fn update_language_server(
1851 request: proto::UpdateLanguageServer,
1852 session: Session,
1853) -> Result<()> {
1854 let project_id = ProjectId::from_proto(request.project_id);
1855 let project_connection_ids = session
1856 .db()
1857 .await
1858 .project_connection_ids(project_id, session.connection_id)
1859 .await?;
1860 broadcast(
1861 Some(session.connection_id),
1862 project_connection_ids.iter().copied(),
1863 |connection_id| {
1864 session
1865 .peer
1866 .forward_send(session.connection_id, connection_id, request.clone())
1867 },
1868 );
1869 Ok(())
1870}
1871
1872/// forward a project request to the host. These requests should be read only
1873/// as guests are allowed to send them.
1874async fn forward_read_only_project_request<T>(
1875 request: T,
1876 response: Response<T>,
1877 session: Session,
1878) -> Result<()>
1879where
1880 T: EntityMessage + RequestMessage,
1881{
1882 let project_id = ProjectId::from_proto(request.remote_entity_id());
1883 let host_connection_id = session
1884 .db()
1885 .await
1886 .host_for_read_only_project_request(project_id, session.connection_id)
1887 .await?;
1888 let payload = session
1889 .peer
1890 .forward_request(session.connection_id, host_connection_id, request)
1891 .await?;
1892 response.send(payload)?;
1893 Ok(())
1894}
1895
1896/// forward a project request to the host. These requests are disallowed
1897/// for guests.
1898async fn forward_mutating_project_request<T>(
1899 request: T,
1900 response: Response<T>,
1901 session: Session,
1902) -> Result<()>
1903where
1904 T: EntityMessage + RequestMessage,
1905{
1906 let project_id = ProjectId::from_proto(request.remote_entity_id());
1907 let host_connection_id = session
1908 .db()
1909 .await
1910 .host_for_mutating_project_request(project_id, session.connection_id)
1911 .await?;
1912 let payload = session
1913 .peer
1914 .forward_request(session.connection_id, host_connection_id, request)
1915 .await?;
1916 response.send(payload)?;
1917 Ok(())
1918}
1919
1920/// Notify other participants that a new buffer has been created
1921async fn create_buffer_for_peer(
1922 request: proto::CreateBufferForPeer,
1923 session: Session,
1924) -> Result<()> {
1925 session
1926 .db()
1927 .await
1928 .check_user_is_project_host(
1929 ProjectId::from_proto(request.project_id),
1930 session.connection_id,
1931 )
1932 .await?;
1933 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1934 session
1935 .peer
1936 .forward_send(session.connection_id, peer_id.into(), request)?;
1937 Ok(())
1938}
1939
1940/// Notify other participants that a buffer has been updated. This is
1941/// allowed for guests as long as the update is limited to selections.
1942async fn update_buffer(
1943 request: proto::UpdateBuffer,
1944 response: Response<proto::UpdateBuffer>,
1945 session: Session,
1946) -> Result<()> {
1947 let project_id = ProjectId::from_proto(request.project_id);
1948 let mut guest_connection_ids;
1949 let mut host_connection_id = None;
1950
1951 let mut requires_write_permission = false;
1952
1953 for op in request.operations.iter() {
1954 match op.variant {
1955 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1956 Some(_) => requires_write_permission = true,
1957 }
1958 }
1959
1960 {
1961 let collaborators = session
1962 .db()
1963 .await
1964 .project_collaborators_for_buffer_update(
1965 project_id,
1966 session.connection_id,
1967 requires_write_permission,
1968 )
1969 .await?;
1970 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1971 for collaborator in collaborators.iter() {
1972 if collaborator.is_host {
1973 host_connection_id = Some(collaborator.connection_id);
1974 } else {
1975 guest_connection_ids.push(collaborator.connection_id);
1976 }
1977 }
1978 }
1979 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1980
1981 broadcast(
1982 Some(session.connection_id),
1983 guest_connection_ids,
1984 |connection_id| {
1985 session
1986 .peer
1987 .forward_send(session.connection_id, connection_id, request.clone())
1988 },
1989 );
1990 if host_connection_id != session.connection_id {
1991 session
1992 .peer
1993 .forward_request(session.connection_id, host_connection_id, request.clone())
1994 .await?;
1995 }
1996
1997 response.send(proto::Ack {})?;
1998 Ok(())
1999}
2000
2001/// Notify other participants that a project has been updated.
2002async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2003 request: T,
2004 session: Session,
2005) -> Result<()> {
2006 let project_id = ProjectId::from_proto(request.remote_entity_id());
2007 let project_connection_ids = session
2008 .db()
2009 .await
2010 .project_connection_ids(project_id, session.connection_id)
2011 .await?;
2012
2013 broadcast(
2014 Some(session.connection_id),
2015 project_connection_ids.iter().copied(),
2016 |connection_id| {
2017 session
2018 .peer
2019 .forward_send(session.connection_id, connection_id, request.clone())
2020 },
2021 );
2022 Ok(())
2023}
2024
2025/// Start following another user in a call.
2026async fn follow(
2027 request: proto::Follow,
2028 response: Response<proto::Follow>,
2029 session: Session,
2030) -> Result<()> {
2031 let room_id = RoomId::from_proto(request.room_id);
2032 let project_id = request.project_id.map(ProjectId::from_proto);
2033 let leader_id = request
2034 .leader_id
2035 .ok_or_else(|| anyhow!("invalid leader id"))?
2036 .into();
2037 let follower_id = session.connection_id;
2038
2039 session
2040 .db()
2041 .await
2042 .check_room_participants(room_id, leader_id, session.connection_id)
2043 .await?;
2044
2045 let response_payload = session
2046 .peer
2047 .forward_request(session.connection_id, leader_id, request)
2048 .await?;
2049 response.send(response_payload)?;
2050
2051 if let Some(project_id) = project_id {
2052 let room = session
2053 .db()
2054 .await
2055 .follow(room_id, project_id, leader_id, follower_id)
2056 .await?;
2057 room_updated(&room, &session.peer);
2058 }
2059
2060 Ok(())
2061}
2062
2063/// Stop following another user in a call.
2064async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2065 let room_id = RoomId::from_proto(request.room_id);
2066 let project_id = request.project_id.map(ProjectId::from_proto);
2067 let leader_id = request
2068 .leader_id
2069 .ok_or_else(|| anyhow!("invalid leader id"))?
2070 .into();
2071 let follower_id = session.connection_id;
2072
2073 session
2074 .db()
2075 .await
2076 .check_room_participants(room_id, leader_id, session.connection_id)
2077 .await?;
2078
2079 session
2080 .peer
2081 .forward_send(session.connection_id, leader_id, request)?;
2082
2083 if let Some(project_id) = project_id {
2084 let room = session
2085 .db()
2086 .await
2087 .unfollow(room_id, project_id, leader_id, follower_id)
2088 .await?;
2089 room_updated(&room, &session.peer);
2090 }
2091
2092 Ok(())
2093}
2094
2095/// Notify everyone following you of your current location.
2096async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2097 let room_id = RoomId::from_proto(request.room_id);
2098 let database = session.db.lock().await;
2099
2100 let connection_ids = if let Some(project_id) = request.project_id {
2101 let project_id = ProjectId::from_proto(project_id);
2102 database
2103 .project_connection_ids(project_id, session.connection_id)
2104 .await?
2105 } else {
2106 database
2107 .room_connection_ids(room_id, session.connection_id)
2108 .await?
2109 };
2110
2111 // For now, don't send view update messages back to that view's current leader.
2112 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2113 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2114 _ => None,
2115 });
2116
2117 for connection_id in connection_ids.iter().cloned() {
2118 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2119 session
2120 .peer
2121 .forward_send(session.connection_id, connection_id, request.clone())?;
2122 }
2123 }
2124 Ok(())
2125}
2126
2127/// Get public data about users.
2128async fn get_users(
2129 request: proto::GetUsers,
2130 response: Response<proto::GetUsers>,
2131 session: Session,
2132) -> Result<()> {
2133 let user_ids = request
2134 .user_ids
2135 .into_iter()
2136 .map(UserId::from_proto)
2137 .collect();
2138 let users = session
2139 .db()
2140 .await
2141 .get_users_by_ids(user_ids)
2142 .await?
2143 .into_iter()
2144 .map(|user| proto::User {
2145 id: user.id.to_proto(),
2146 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2147 github_login: user.github_login,
2148 })
2149 .collect();
2150 response.send(proto::UsersResponse { users })?;
2151 Ok(())
2152}
2153
2154/// Search for users (to invite) buy Github login
2155async fn fuzzy_search_users(
2156 request: proto::FuzzySearchUsers,
2157 response: Response<proto::FuzzySearchUsers>,
2158 session: Session,
2159) -> Result<()> {
2160 let query = request.query;
2161 let users = match query.len() {
2162 0 => vec![],
2163 1 | 2 => session
2164 .db()
2165 .await
2166 .get_user_by_github_login(&query)
2167 .await?
2168 .into_iter()
2169 .collect(),
2170 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2171 };
2172 let users = users
2173 .into_iter()
2174 .filter(|user| user.id != session.user_id)
2175 .map(|user| proto::User {
2176 id: user.id.to_proto(),
2177 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2178 github_login: user.github_login,
2179 })
2180 .collect();
2181 response.send(proto::UsersResponse { users })?;
2182 Ok(())
2183}
2184
2185/// Send a contact request to another user.
2186async fn request_contact(
2187 request: proto::RequestContact,
2188 response: Response<proto::RequestContact>,
2189 session: Session,
2190) -> Result<()> {
2191 let requester_id = session.user_id;
2192 let responder_id = UserId::from_proto(request.responder_id);
2193 if requester_id == responder_id {
2194 return Err(anyhow!("cannot add yourself as a contact"))?;
2195 }
2196
2197 let notifications = session
2198 .db()
2199 .await
2200 .send_contact_request(requester_id, responder_id)
2201 .await?;
2202
2203 // Update outgoing contact requests of requester
2204 let mut update = proto::UpdateContacts::default();
2205 update.outgoing_requests.push(responder_id.to_proto());
2206 for connection_id in session
2207 .connection_pool()
2208 .await
2209 .user_connection_ids(requester_id)
2210 {
2211 session.peer.send(connection_id, update.clone())?;
2212 }
2213
2214 // Update incoming contact requests of responder
2215 let mut update = proto::UpdateContacts::default();
2216 update
2217 .incoming_requests
2218 .push(proto::IncomingContactRequest {
2219 requester_id: requester_id.to_proto(),
2220 });
2221 let connection_pool = session.connection_pool().await;
2222 for connection_id in connection_pool.user_connection_ids(responder_id) {
2223 session.peer.send(connection_id, update.clone())?;
2224 }
2225
2226 send_notifications(&*connection_pool, &session.peer, notifications);
2227
2228 response.send(proto::Ack {})?;
2229 Ok(())
2230}
2231
2232/// Accept or decline a contact request
2233async fn respond_to_contact_request(
2234 request: proto::RespondToContactRequest,
2235 response: Response<proto::RespondToContactRequest>,
2236 session: Session,
2237) -> Result<()> {
2238 let responder_id = session.user_id;
2239 let requester_id = UserId::from_proto(request.requester_id);
2240 let db = session.db().await;
2241 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2242 db.dismiss_contact_notification(responder_id, requester_id)
2243 .await?;
2244 } else {
2245 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2246
2247 let notifications = db
2248 .respond_to_contact_request(responder_id, requester_id, accept)
2249 .await?;
2250 let requester_busy = db.is_user_busy(requester_id).await?;
2251 let responder_busy = db.is_user_busy(responder_id).await?;
2252
2253 let pool = session.connection_pool().await;
2254 // Update responder with new contact
2255 let mut update = proto::UpdateContacts::default();
2256 if accept {
2257 update
2258 .contacts
2259 .push(contact_for_user(requester_id, requester_busy, &pool));
2260 }
2261 update
2262 .remove_incoming_requests
2263 .push(requester_id.to_proto());
2264 for connection_id in pool.user_connection_ids(responder_id) {
2265 session.peer.send(connection_id, update.clone())?;
2266 }
2267
2268 // Update requester with new contact
2269 let mut update = proto::UpdateContacts::default();
2270 if accept {
2271 update
2272 .contacts
2273 .push(contact_for_user(responder_id, responder_busy, &pool));
2274 }
2275 update
2276 .remove_outgoing_requests
2277 .push(responder_id.to_proto());
2278
2279 for connection_id in pool.user_connection_ids(requester_id) {
2280 session.peer.send(connection_id, update.clone())?;
2281 }
2282
2283 send_notifications(&*pool, &session.peer, notifications);
2284 }
2285
2286 response.send(proto::Ack {})?;
2287 Ok(())
2288}
2289
2290/// Remove a contact.
2291async fn remove_contact(
2292 request: proto::RemoveContact,
2293 response: Response<proto::RemoveContact>,
2294 session: Session,
2295) -> Result<()> {
2296 let requester_id = session.user_id;
2297 let responder_id = UserId::from_proto(request.user_id);
2298 let db = session.db().await;
2299 let (contact_accepted, deleted_notification_id) =
2300 db.remove_contact(requester_id, responder_id).await?;
2301
2302 let pool = session.connection_pool().await;
2303 // Update outgoing contact requests of requester
2304 let mut update = proto::UpdateContacts::default();
2305 if contact_accepted {
2306 update.remove_contacts.push(responder_id.to_proto());
2307 } else {
2308 update
2309 .remove_outgoing_requests
2310 .push(responder_id.to_proto());
2311 }
2312 for connection_id in pool.user_connection_ids(requester_id) {
2313 session.peer.send(connection_id, update.clone())?;
2314 }
2315
2316 // Update incoming contact requests of responder
2317 let mut update = proto::UpdateContacts::default();
2318 if contact_accepted {
2319 update.remove_contacts.push(requester_id.to_proto());
2320 } else {
2321 update
2322 .remove_incoming_requests
2323 .push(requester_id.to_proto());
2324 }
2325 for connection_id in pool.user_connection_ids(responder_id) {
2326 session.peer.send(connection_id, update.clone())?;
2327 if let Some(notification_id) = deleted_notification_id {
2328 session.peer.send(
2329 connection_id,
2330 proto::DeleteNotification {
2331 notification_id: notification_id.to_proto(),
2332 },
2333 )?;
2334 }
2335 }
2336
2337 response.send(proto::Ack {})?;
2338 Ok(())
2339}
2340
2341/// Creates a new channel.
2342async fn create_channel(
2343 request: proto::CreateChannel,
2344 response: Response<proto::CreateChannel>,
2345 session: Session,
2346) -> Result<()> {
2347 let db = session.db().await;
2348
2349 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2350 let (channel, owner, channel_members) = db
2351 .create_channel(&request.name, parent_id, session.user_id)
2352 .await?;
2353
2354 response.send(proto::CreateChannelResponse {
2355 channel: Some(channel.to_proto()),
2356 parent_id: request.parent_id,
2357 })?;
2358
2359 let connection_pool = session.connection_pool().await;
2360 if let Some(owner) = owner {
2361 let update = proto::UpdateUserChannels {
2362 channel_memberships: vec![proto::ChannelMembership {
2363 channel_id: owner.channel_id.to_proto(),
2364 role: owner.role.into(),
2365 }],
2366 ..Default::default()
2367 };
2368 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2369 session.peer.send(connection_id, update.clone())?;
2370 }
2371 }
2372
2373 for channel_member in channel_members {
2374 if !channel_member.role.can_see_channel(channel.visibility) {
2375 continue;
2376 }
2377
2378 let update = proto::UpdateChannels {
2379 channels: vec![channel.to_proto()],
2380 ..Default::default()
2381 };
2382 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2383 session.peer.send(connection_id, update.clone())?;
2384 }
2385 }
2386
2387 Ok(())
2388}
2389
2390/// Delete a channel
2391async fn delete_channel(
2392 request: proto::DeleteChannel,
2393 response: Response<proto::DeleteChannel>,
2394 session: Session,
2395) -> Result<()> {
2396 let db = session.db().await;
2397
2398 let channel_id = request.channel_id;
2399 let (removed_channels, member_ids) = db
2400 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2401 .await?;
2402 response.send(proto::Ack {})?;
2403
2404 // Notify members of removed channels
2405 let mut update = proto::UpdateChannels::default();
2406 update
2407 .delete_channels
2408 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2409
2410 let connection_pool = session.connection_pool().await;
2411 for member_id in member_ids {
2412 for connection_id in connection_pool.user_connection_ids(member_id) {
2413 session.peer.send(connection_id, update.clone())?;
2414 }
2415 }
2416
2417 Ok(())
2418}
2419
2420/// Invite someone to join a channel.
2421async fn invite_channel_member(
2422 request: proto::InviteChannelMember,
2423 response: Response<proto::InviteChannelMember>,
2424 session: Session,
2425) -> Result<()> {
2426 let db = session.db().await;
2427 let channel_id = ChannelId::from_proto(request.channel_id);
2428 let invitee_id = UserId::from_proto(request.user_id);
2429 let InviteMemberResult {
2430 channel,
2431 notifications,
2432 } = db
2433 .invite_channel_member(
2434 channel_id,
2435 invitee_id,
2436 session.user_id,
2437 request.role().into(),
2438 )
2439 .await?;
2440
2441 let update = proto::UpdateChannels {
2442 channel_invitations: vec![channel.to_proto()],
2443 ..Default::default()
2444 };
2445
2446 let connection_pool = session.connection_pool().await;
2447 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2448 session.peer.send(connection_id, update.clone())?;
2449 }
2450
2451 send_notifications(&*connection_pool, &session.peer, notifications);
2452
2453 response.send(proto::Ack {})?;
2454 Ok(())
2455}
2456
2457/// remove someone from a channel
2458async fn remove_channel_member(
2459 request: proto::RemoveChannelMember,
2460 response: Response<proto::RemoveChannelMember>,
2461 session: Session,
2462) -> Result<()> {
2463 let db = session.db().await;
2464 let channel_id = ChannelId::from_proto(request.channel_id);
2465 let member_id = UserId::from_proto(request.user_id);
2466
2467 let RemoveChannelMemberResult {
2468 membership_update,
2469 notification_id,
2470 } = db
2471 .remove_channel_member(channel_id, member_id, session.user_id)
2472 .await?;
2473
2474 let connection_pool = &session.connection_pool().await;
2475 notify_membership_updated(
2476 &connection_pool,
2477 membership_update,
2478 member_id,
2479 &session.peer,
2480 );
2481 for connection_id in connection_pool.user_connection_ids(member_id) {
2482 if let Some(notification_id) = notification_id {
2483 session
2484 .peer
2485 .send(
2486 connection_id,
2487 proto::DeleteNotification {
2488 notification_id: notification_id.to_proto(),
2489 },
2490 )
2491 .trace_err();
2492 }
2493 }
2494
2495 response.send(proto::Ack {})?;
2496 Ok(())
2497}
2498
2499/// Toggle the channel between public and private.
2500/// Care is taken to maintain the invariant that public channels only descend from public channels,
2501/// (though members-only channels can appear at any point in the hierarchy).
2502async fn set_channel_visibility(
2503 request: proto::SetChannelVisibility,
2504 response: Response<proto::SetChannelVisibility>,
2505 session: Session,
2506) -> Result<()> {
2507 let db = session.db().await;
2508 let channel_id = ChannelId::from_proto(request.channel_id);
2509 let visibility = request.visibility().into();
2510
2511 let (channel, channel_members) = db
2512 .set_channel_visibility(channel_id, visibility, session.user_id)
2513 .await?;
2514
2515 let connection_pool = session.connection_pool().await;
2516 for member in channel_members {
2517 let update = if member.role.can_see_channel(channel.visibility) {
2518 proto::UpdateChannels {
2519 channels: vec![channel.to_proto()],
2520 ..Default::default()
2521 }
2522 } else {
2523 proto::UpdateChannels {
2524 delete_channels: vec![channel.id.to_proto()],
2525 ..Default::default()
2526 }
2527 };
2528
2529 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2530 session.peer.send(connection_id, update.clone())?;
2531 }
2532 }
2533
2534 response.send(proto::Ack {})?;
2535 Ok(())
2536}
2537
2538/// Alter the role for a user in the channel.
2539async fn set_channel_member_role(
2540 request: proto::SetChannelMemberRole,
2541 response: Response<proto::SetChannelMemberRole>,
2542 session: Session,
2543) -> Result<()> {
2544 let db = session.db().await;
2545 let channel_id = ChannelId::from_proto(request.channel_id);
2546 let member_id = UserId::from_proto(request.user_id);
2547 let result = db
2548 .set_channel_member_role(
2549 channel_id,
2550 session.user_id,
2551 member_id,
2552 request.role().into(),
2553 )
2554 .await?;
2555
2556 match result {
2557 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2558 let connection_pool = session.connection_pool().await;
2559 notify_membership_updated(
2560 &connection_pool,
2561 membership_update,
2562 member_id,
2563 &session.peer,
2564 )
2565 }
2566 db::SetMemberRoleResult::InviteUpdated(channel) => {
2567 let update = proto::UpdateChannels {
2568 channel_invitations: vec![channel.to_proto()],
2569 ..Default::default()
2570 };
2571
2572 for connection_id in session
2573 .connection_pool()
2574 .await
2575 .user_connection_ids(member_id)
2576 {
2577 session.peer.send(connection_id, update.clone())?;
2578 }
2579 }
2580 }
2581
2582 response.send(proto::Ack {})?;
2583 Ok(())
2584}
2585
2586/// Change the name of a channel
2587async fn rename_channel(
2588 request: proto::RenameChannel,
2589 response: Response<proto::RenameChannel>,
2590 session: Session,
2591) -> Result<()> {
2592 let db = session.db().await;
2593 let channel_id = ChannelId::from_proto(request.channel_id);
2594 let (channel, channel_members) = db
2595 .rename_channel(channel_id, session.user_id, &request.name)
2596 .await?;
2597
2598 response.send(proto::RenameChannelResponse {
2599 channel: Some(channel.to_proto()),
2600 })?;
2601
2602 let connection_pool = session.connection_pool().await;
2603 for channel_member in channel_members {
2604 if !channel_member.role.can_see_channel(channel.visibility) {
2605 continue;
2606 }
2607 let update = proto::UpdateChannels {
2608 channels: vec![channel.to_proto()],
2609 ..Default::default()
2610 };
2611 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2612 session.peer.send(connection_id, update.clone())?;
2613 }
2614 }
2615
2616 Ok(())
2617}
2618
2619/// Move a channel to a new parent.
2620async fn move_channel(
2621 request: proto::MoveChannel,
2622 response: Response<proto::MoveChannel>,
2623 session: Session,
2624) -> Result<()> {
2625 let channel_id = ChannelId::from_proto(request.channel_id);
2626 let to = ChannelId::from_proto(request.to);
2627
2628 let (channels, channel_members) = session
2629 .db()
2630 .await
2631 .move_channel(channel_id, to, session.user_id)
2632 .await?;
2633
2634 let connection_pool = session.connection_pool().await;
2635 for member in channel_members {
2636 let channels = channels
2637 .iter()
2638 .filter_map(|channel| {
2639 if member.role.can_see_channel(channel.visibility) {
2640 Some(channel.to_proto())
2641 } else {
2642 None
2643 }
2644 })
2645 .collect::<Vec<_>>();
2646 if channels.is_empty() {
2647 continue;
2648 }
2649
2650 let update = proto::UpdateChannels {
2651 channels,
2652 ..Default::default()
2653 };
2654
2655 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2656 session.peer.send(connection_id, update.clone())?;
2657 }
2658 }
2659
2660 response.send(Ack {})?;
2661 Ok(())
2662}
2663
2664/// Get the list of channel members
2665async fn get_channel_members(
2666 request: proto::GetChannelMembers,
2667 response: Response<proto::GetChannelMembers>,
2668 session: Session,
2669) -> Result<()> {
2670 let db = session.db().await;
2671 let channel_id = ChannelId::from_proto(request.channel_id);
2672 let members = db
2673 .get_channel_participant_details(channel_id, session.user_id)
2674 .await?;
2675 response.send(proto::GetChannelMembersResponse { members })?;
2676 Ok(())
2677}
2678
2679/// Accept or decline a channel invitation.
2680async fn respond_to_channel_invite(
2681 request: proto::RespondToChannelInvite,
2682 response: Response<proto::RespondToChannelInvite>,
2683 session: Session,
2684) -> Result<()> {
2685 let db = session.db().await;
2686 let channel_id = ChannelId::from_proto(request.channel_id);
2687 let RespondToChannelInvite {
2688 membership_update,
2689 notifications,
2690 } = db
2691 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2692 .await?;
2693
2694 let connection_pool = session.connection_pool().await;
2695 if let Some(membership_update) = membership_update {
2696 notify_membership_updated(
2697 &connection_pool,
2698 membership_update,
2699 session.user_id,
2700 &session.peer,
2701 );
2702 } else {
2703 let update = proto::UpdateChannels {
2704 remove_channel_invitations: vec![channel_id.to_proto()],
2705 ..Default::default()
2706 };
2707
2708 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2709 session.peer.send(connection_id, update.clone())?;
2710 }
2711 };
2712
2713 send_notifications(&*connection_pool, &session.peer, notifications);
2714
2715 response.send(proto::Ack {})?;
2716
2717 Ok(())
2718}
2719
2720/// Join the channels' room
2721async fn join_channel(
2722 request: proto::JoinChannel,
2723 response: Response<proto::JoinChannel>,
2724 session: Session,
2725) -> Result<()> {
2726 let channel_id = ChannelId::from_proto(request.channel_id);
2727 join_channel_internal(channel_id, Box::new(response), session).await
2728}
2729
2730trait JoinChannelInternalResponse {
2731 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2732}
2733impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2734 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2735 Response::<proto::JoinChannel>::send(self, result)
2736 }
2737}
2738impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2739 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2740 Response::<proto::JoinRoom>::send(self, result)
2741 }
2742}
2743
/// Shared implementation for joining a channel's room.
///
/// Steps:
/// 1. Leave whatever room this session is currently in.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, its channel id, and LiveKit connection info when a
///    LiveKit client is configured.
/// 4. Broadcast the membership change (if any), the room update, the channel
///    update, and the user's contacts.
///
/// Returns an error if leaving the old room, joining in the database, sending
/// the reply, or updating contacts fails.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The `db` guard and the first `connection_pool` guard are acquired inside
    // this block, so both are dropped before the pool is re-acquired for
    // `channel_updated` and before `update_user_contacts` runs.
    // NOTE(review): presumably this avoids holding them across those calls — confirm.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // Guests get a `guest_token` and may not publish. Every other role gets
        // a `room_token` and may publish. If token generation fails, `trace_err`
        // yields `None` and the `?` makes the whole connection info `None`, so
        // the reply simply omits LiveKit details.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        // Only sent when joining changed the user's channel memberships.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Broadcast the channel's new room state to the channel members returned by the join.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2819
2820/// Start editing the channel notes
2821async fn join_channel_buffer(
2822 request: proto::JoinChannelBuffer,
2823 response: Response<proto::JoinChannelBuffer>,
2824 session: Session,
2825) -> Result<()> {
2826 let db = session.db().await;
2827 let channel_id = ChannelId::from_proto(request.channel_id);
2828
2829 let open_response = db
2830 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2831 .await?;
2832
2833 let collaborators = open_response.collaborators.clone();
2834 response.send(open_response)?;
2835
2836 let update = UpdateChannelBufferCollaborators {
2837 channel_id: channel_id.to_proto(),
2838 collaborators: collaborators.clone(),
2839 };
2840 channel_buffer_updated(
2841 session.connection_id,
2842 collaborators
2843 .iter()
2844 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2845 &update,
2846 &session.peer,
2847 );
2848
2849 Ok(())
2850}
2851
2852/// Edit the channel notes
2853async fn update_channel_buffer(
2854 request: proto::UpdateChannelBuffer,
2855 session: Session,
2856) -> Result<()> {
2857 let db = session.db().await;
2858 let channel_id = ChannelId::from_proto(request.channel_id);
2859
2860 let (collaborators, non_collaborators, epoch, version) = db
2861 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2862 .await?;
2863
2864 channel_buffer_updated(
2865 session.connection_id,
2866 collaborators,
2867 &proto::UpdateChannelBuffer {
2868 channel_id: channel_id.to_proto(),
2869 operations: request.operations,
2870 },
2871 &session.peer,
2872 );
2873
2874 let pool = &*session.connection_pool().await;
2875
2876 broadcast(
2877 None,
2878 non_collaborators
2879 .iter()
2880 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2881 |peer_id| {
2882 session.peer.send(
2883 peer_id.into(),
2884 proto::UpdateChannels {
2885 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2886 channel_id: channel_id.to_proto(),
2887 epoch: epoch as u64,
2888 version: version.clone(),
2889 }],
2890 ..Default::default()
2891 },
2892 )
2893 },
2894 );
2895
2896 Ok(())
2897}
2898
2899/// Rejoin the channel notes after a connection blip
2900async fn rejoin_channel_buffers(
2901 request: proto::RejoinChannelBuffers,
2902 response: Response<proto::RejoinChannelBuffers>,
2903 session: Session,
2904) -> Result<()> {
2905 let db = session.db().await;
2906 let buffers = db
2907 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2908 .await?;
2909
2910 for rejoined_buffer in &buffers {
2911 let collaborators_to_notify = rejoined_buffer
2912 .buffer
2913 .collaborators
2914 .iter()
2915 .filter_map(|c| Some(c.peer_id?.into()));
2916 channel_buffer_updated(
2917 session.connection_id,
2918 collaborators_to_notify,
2919 &proto::UpdateChannelBufferCollaborators {
2920 channel_id: rejoined_buffer.buffer.channel_id,
2921 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2922 },
2923 &session.peer,
2924 );
2925 }
2926
2927 response.send(proto::RejoinChannelBuffersResponse {
2928 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2929 })?;
2930
2931 Ok(())
2932}
2933
2934/// Stop editing the channel notes
2935async fn leave_channel_buffer(
2936 request: proto::LeaveChannelBuffer,
2937 response: Response<proto::LeaveChannelBuffer>,
2938 session: Session,
2939) -> Result<()> {
2940 let db = session.db().await;
2941 let channel_id = ChannelId::from_proto(request.channel_id);
2942
2943 let left_buffer = db
2944 .leave_channel_buffer(channel_id, session.connection_id)
2945 .await?;
2946
2947 response.send(Ack {})?;
2948
2949 channel_buffer_updated(
2950 session.connection_id,
2951 left_buffer.connections,
2952 &proto::UpdateChannelBufferCollaborators {
2953 channel_id: channel_id.to_proto(),
2954 collaborators: left_buffer.collaborators,
2955 },
2956 &session.peer,
2957 );
2958
2959 Ok(())
2960}
2961
2962fn channel_buffer_updated<T: EnvelopedMessage>(
2963 sender_id: ConnectionId,
2964 collaborators: impl IntoIterator<Item = ConnectionId>,
2965 message: &T,
2966 peer: &Peer,
2967) {
2968 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2969 peer.send(peer_id.into(), message.clone())
2970 });
2971}
2972
2973fn send_notifications(
2974 connection_pool: &ConnectionPool,
2975 peer: &Peer,
2976 notifications: db::NotificationBatch,
2977) {
2978 for (user_id, notification) in notifications {
2979 for connection_id in connection_pool.user_connection_ids(user_id) {
2980 if let Err(error) = peer.send(
2981 connection_id,
2982 proto::AddNotification {
2983 notification: Some(notification.clone()),
2984 },
2985 ) {
2986 tracing::error!(
2987 "failed to send notification to {:?} {}",
2988 connection_id,
2989 error
2990 );
2991 }
2992 }
2993 }
2994}
2995
2996/// Send a message to the channel
2997async fn send_channel_message(
2998 request: proto::SendChannelMessage,
2999 response: Response<proto::SendChannelMessage>,
3000 session: Session,
3001) -> Result<()> {
3002 // Validate the message body.
3003 let body = request.body.trim().to_string();
3004 if body.len() > MAX_MESSAGE_LEN {
3005 return Err(anyhow!("message is too long"))?;
3006 }
3007 if body.is_empty() {
3008 return Err(anyhow!("message can't be blank"))?;
3009 }
3010
3011 // TODO: adjust mentions if body is trimmed
3012
3013 let timestamp = OffsetDateTime::now_utc();
3014 let nonce = request
3015 .nonce
3016 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3017
3018 let channel_id = ChannelId::from_proto(request.channel_id);
3019 let CreatedChannelMessage {
3020 message_id,
3021 participant_connection_ids,
3022 channel_members,
3023 notifications,
3024 } = session
3025 .db()
3026 .await
3027 .create_channel_message(
3028 channel_id,
3029 session.user_id,
3030 &body,
3031 &request.mentions,
3032 timestamp,
3033 nonce.clone().into(),
3034 match request.reply_to_message_id {
3035 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3036 None => None,
3037 },
3038 )
3039 .await?;
3040 let message = proto::ChannelMessage {
3041 sender_id: session.user_id.to_proto(),
3042 id: message_id.to_proto(),
3043 body,
3044 mentions: request.mentions,
3045 timestamp: timestamp.unix_timestamp() as u64,
3046 nonce: Some(nonce),
3047 reply_to_message_id: request.reply_to_message_id,
3048 };
3049 broadcast(
3050 Some(session.connection_id),
3051 participant_connection_ids,
3052 |connection| {
3053 session.peer.send(
3054 connection,
3055 proto::ChannelMessageSent {
3056 channel_id: channel_id.to_proto(),
3057 message: Some(message.clone()),
3058 },
3059 )
3060 },
3061 );
3062 response.send(proto::SendChannelMessageResponse {
3063 message: Some(message),
3064 })?;
3065
3066 let pool = &*session.connection_pool().await;
3067 broadcast(
3068 None,
3069 channel_members
3070 .iter()
3071 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3072 |peer_id| {
3073 session.peer.send(
3074 peer_id.into(),
3075 proto::UpdateChannels {
3076 latest_channel_message_ids: vec![proto::ChannelMessageId {
3077 channel_id: channel_id.to_proto(),
3078 message_id: message_id.to_proto(),
3079 }],
3080 ..Default::default()
3081 },
3082 )
3083 },
3084 );
3085 send_notifications(pool, &session.peer, notifications);
3086
3087 Ok(())
3088}
3089
3090/// Delete a channel message
3091async fn remove_channel_message(
3092 request: proto::RemoveChannelMessage,
3093 response: Response<proto::RemoveChannelMessage>,
3094 session: Session,
3095) -> Result<()> {
3096 let channel_id = ChannelId::from_proto(request.channel_id);
3097 let message_id = MessageId::from_proto(request.message_id);
3098 let connection_ids = session
3099 .db()
3100 .await
3101 .remove_channel_message(channel_id, message_id, session.user_id)
3102 .await?;
3103 broadcast(Some(session.connection_id), connection_ids, |connection| {
3104 session.peer.send(connection, request.clone())
3105 });
3106 response.send(proto::Ack {})?;
3107 Ok(())
3108}
3109
3110/// Mark a channel message as read
3111async fn acknowledge_channel_message(
3112 request: proto::AckChannelMessage,
3113 session: Session,
3114) -> Result<()> {
3115 let channel_id = ChannelId::from_proto(request.channel_id);
3116 let message_id = MessageId::from_proto(request.message_id);
3117 let notifications = session
3118 .db()
3119 .await
3120 .observe_channel_message(channel_id, session.user_id, message_id)
3121 .await?;
3122 send_notifications(
3123 &*session.connection_pool().await,
3124 &session.peer,
3125 notifications,
3126 );
3127 Ok(())
3128}
3129
3130/// Mark a buffer version as synced
3131async fn acknowledge_buffer_version(
3132 request: proto::AckBufferOperation,
3133 session: Session,
3134) -> Result<()> {
3135 let buffer_id = BufferId::from_proto(request.buffer_id);
3136 session
3137 .db()
3138 .await
3139 .observe_buffer_version(
3140 buffer_id,
3141 session.user_id,
3142 request.epoch as i32,
3143 &request.version,
3144 )
3145 .await?;
3146 Ok(())
3147}
3148
3149/// Start receiving chat updates for a channel
3150async fn join_channel_chat(
3151 request: proto::JoinChannelChat,
3152 response: Response<proto::JoinChannelChat>,
3153 session: Session,
3154) -> Result<()> {
3155 let channel_id = ChannelId::from_proto(request.channel_id);
3156
3157 let db = session.db().await;
3158 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3159 .await?;
3160 let messages = db
3161 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3162 .await?;
3163 response.send(proto::JoinChannelChatResponse {
3164 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3165 messages,
3166 })?;
3167 Ok(())
3168}
3169
3170/// Stop receiving chat updates for a channel
3171async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3172 let channel_id = ChannelId::from_proto(request.channel_id);
3173 session
3174 .db()
3175 .await
3176 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3177 .await?;
3178 Ok(())
3179}
3180
3181/// Retrieve the chat history for a channel
3182async fn get_channel_messages(
3183 request: proto::GetChannelMessages,
3184 response: Response<proto::GetChannelMessages>,
3185 session: Session,
3186) -> Result<()> {
3187 let channel_id = ChannelId::from_proto(request.channel_id);
3188 let messages = session
3189 .db()
3190 .await
3191 .get_channel_messages(
3192 channel_id,
3193 session.user_id,
3194 MESSAGE_COUNT_PER_PAGE,
3195 Some(MessageId::from_proto(request.before_message_id)),
3196 )
3197 .await?;
3198 response.send(proto::GetChannelMessagesResponse {
3199 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3200 messages,
3201 })?;
3202 Ok(())
3203}
3204
3205/// Retrieve specific chat messages
3206async fn get_channel_messages_by_id(
3207 request: proto::GetChannelMessagesById,
3208 response: Response<proto::GetChannelMessagesById>,
3209 session: Session,
3210) -> Result<()> {
3211 let message_ids = request
3212 .message_ids
3213 .iter()
3214 .map(|id| MessageId::from_proto(*id))
3215 .collect::<Vec<_>>();
3216 let messages = session
3217 .db()
3218 .await
3219 .get_channel_messages_by_id(session.user_id, &message_ids)
3220 .await?;
3221 response.send(proto::GetChannelMessagesResponse {
3222 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3223 messages,
3224 })?;
3225 Ok(())
3226}
3227
3228/// Retrieve the current users notifications
3229async fn get_notifications(
3230 request: proto::GetNotifications,
3231 response: Response<proto::GetNotifications>,
3232 session: Session,
3233) -> Result<()> {
3234 let notifications = session
3235 .db()
3236 .await
3237 .get_notifications(
3238 session.user_id,
3239 NOTIFICATION_COUNT_PER_PAGE,
3240 request
3241 .before_id
3242 .map(|id| db::NotificationId::from_proto(id)),
3243 )
3244 .await?;
3245 response.send(proto::GetNotificationsResponse {
3246 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3247 notifications,
3248 })?;
3249 Ok(())
3250}
3251
3252/// Mark notifications as read
3253async fn mark_notification_as_read(
3254 request: proto::MarkNotificationRead,
3255 response: Response<proto::MarkNotificationRead>,
3256 session: Session,
3257) -> Result<()> {
3258 let database = &session.db().await;
3259 let notifications = database
3260 .mark_notification_as_read_by_id(
3261 session.user_id,
3262 NotificationId::from_proto(request.notification_id),
3263 )
3264 .await?;
3265 send_notifications(
3266 &*session.connection_pool().await,
3267 &session.peer,
3268 notifications,
3269 );
3270 response.send(proto::Ack {})?;
3271 Ok(())
3272}
3273
3274/// Get the current users information
3275async fn get_private_user_info(
3276 _request: proto::GetPrivateUserInfo,
3277 response: Response<proto::GetPrivateUserInfo>,
3278 session: Session,
3279) -> Result<()> {
3280 let db = session.db().await;
3281
3282 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3283 let user = db
3284 .get_user_by_id(session.user_id)
3285 .await?
3286 .ok_or_else(|| anyhow!("user not found"))?;
3287 let flags = db.get_user_flags(session.user_id).await?;
3288
3289 response.send(proto::GetPrivateUserInfoResponse {
3290 metrics_id,
3291 staff: user.admin,
3292 flags,
3293 })?;
3294 Ok(())
3295}
3296
3297fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3298 match message {
3299 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3300 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3301 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3302 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3303 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3304 code: frame.code.into(),
3305 reason: frame.reason,
3306 })),
3307 }
3308}
3309
3310fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3311 match message {
3312 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3313 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3314 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3315 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3316 AxumMessage::Close(frame) => {
3317 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3318 code: frame.code.into(),
3319 reason: frame.reason,
3320 }))
3321 }
3322 }
3323}
3324
3325fn notify_membership_updated(
3326 connection_pool: &ConnectionPool,
3327 result: MembershipUpdated,
3328 user_id: UserId,
3329 peer: &Peer,
3330) {
3331 let user_channels_update = proto::UpdateUserChannels {
3332 channel_memberships: result
3333 .new_channels
3334 .channel_memberships
3335 .iter()
3336 .map(|cm| proto::ChannelMembership {
3337 channel_id: cm.channel_id.to_proto(),
3338 role: cm.role.into(),
3339 })
3340 .collect(),
3341 ..Default::default()
3342 };
3343 let mut update = build_channels_update(result.new_channels, vec![]);
3344 update.delete_channels = result
3345 .removed_channels
3346 .into_iter()
3347 .map(|id| id.to_proto())
3348 .collect();
3349 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3350
3351 for connection_id in connection_pool.user_connection_ids(user_id) {
3352 peer.send(connection_id, user_channels_update.clone())
3353 .trace_err();
3354 peer.send(connection_id, update.clone()).trace_err();
3355 }
3356}
3357
3358fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3359 proto::UpdateUserChannels {
3360 channel_memberships: channels
3361 .channel_memberships
3362 .iter()
3363 .map(|m| proto::ChannelMembership {
3364 channel_id: m.channel_id.to_proto(),
3365 role: m.role.into(),
3366 })
3367 .collect(),
3368 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3369 observed_channel_message_id: channels.observed_channel_messages.clone(),
3370 ..Default::default()
3371 }
3372}
3373
3374fn build_channels_update(
3375 channels: ChannelsForUser,
3376 channel_invites: Vec<db::Channel>,
3377) -> proto::UpdateChannels {
3378 let mut update = proto::UpdateChannels::default();
3379
3380 for channel in channels.channels {
3381 update.channels.push(channel.to_proto());
3382 }
3383
3384 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3385 update.latest_channel_message_ids = channels.latest_channel_messages;
3386
3387 for (channel_id, participants) in channels.channel_participants {
3388 update
3389 .channel_participants
3390 .push(proto::ChannelParticipants {
3391 channel_id: channel_id.to_proto(),
3392 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3393 });
3394 }
3395
3396 for channel in channel_invites {
3397 update.channel_invitations.push(channel.to_proto());
3398 }
3399 for project in channels.hosted_projects {
3400 update.hosted_projects.push(project);
3401 }
3402
3403 update
3404}
3405
3406fn build_initial_contacts_update(
3407 contacts: Vec<db::Contact>,
3408 pool: &ConnectionPool,
3409) -> proto::UpdateContacts {
3410 let mut update = proto::UpdateContacts::default();
3411
3412 for contact in contacts {
3413 match contact {
3414 db::Contact::Accepted { user_id, busy } => {
3415 update.contacts.push(contact_for_user(user_id, busy, &pool));
3416 }
3417 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3418 db::Contact::Incoming { user_id } => {
3419 update
3420 .incoming_requests
3421 .push(proto::IncomingContactRequest {
3422 requester_id: user_id.to_proto(),
3423 })
3424 }
3425 }
3426 }
3427
3428 update
3429}
3430
3431fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3432 proto::Contact {
3433 user_id: user_id.to_proto(),
3434 online: pool.is_user_online(user_id),
3435 busy,
3436 }
3437}
3438
3439fn room_updated(room: &proto::Room, peer: &Peer) {
3440 broadcast(
3441 None,
3442 room.participants
3443 .iter()
3444 .filter_map(|participant| Some(participant.peer_id?.into())),
3445 |peer_id| {
3446 peer.send(
3447 peer_id.into(),
3448 proto::RoomUpdated {
3449 room: Some(room.clone()),
3450 },
3451 )
3452 },
3453 );
3454}
3455
3456fn channel_updated(
3457 channel_id: ChannelId,
3458 room: &proto::Room,
3459 channel_members: &[UserId],
3460 peer: &Peer,
3461 pool: &ConnectionPool,
3462) {
3463 let participants = room
3464 .participants
3465 .iter()
3466 .map(|p| p.user_id)
3467 .collect::<Vec<_>>();
3468
3469 broadcast(
3470 None,
3471 channel_members
3472 .iter()
3473 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3474 |peer_id| {
3475 peer.send(
3476 peer_id.into(),
3477 proto::UpdateChannels {
3478 channel_participants: vec![proto::ChannelParticipants {
3479 channel_id: channel_id.to_proto(),
3480 participant_user_ids: participants.clone(),
3481 }],
3482 ..Default::default()
3483 },
3484 )
3485 },
3486 );
3487}
3488
3489async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3490 let db = session.db().await;
3491
3492 let contacts = db.get_contacts(user_id).await?;
3493 let busy = db.is_user_busy(user_id).await?;
3494
3495 let pool = session.connection_pool().await;
3496 let updated_contact = contact_for_user(user_id, busy, &pool);
3497 for contact in contacts {
3498 if let db::Contact::Accepted {
3499 user_id: contact_user_id,
3500 ..
3501 } = contact
3502 {
3503 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3504 session
3505 .peer
3506 .send(
3507 contact_conn_id,
3508 proto::UpdateContacts {
3509 contacts: vec![updated_contact.clone()],
3510 remove_contacts: Default::default(),
3511 incoming_requests: Default::default(),
3512 remove_incoming_requests: Default::default(),
3513 outgoing_requests: Default::default(),
3514 remove_outgoing_requests: Default::default(),
3515 },
3516 )
3517 .trace_err();
3518 }
3519 }
3520 }
3521 Ok(())
3522}
3523
3524async fn leave_room_for_session(session: &Session) -> Result<()> {
3525 let mut contacts_to_update = HashSet::default();
3526
3527 let room_id;
3528 let canceled_calls_to_user_ids;
3529 let live_kit_room;
3530 let delete_live_kit_room;
3531 let room;
3532 let channel_members;
3533 let channel_id;
3534
3535 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3536 contacts_to_update.insert(session.user_id);
3537
3538 for project in left_room.left_projects.values() {
3539 project_left(project, session);
3540 }
3541
3542 room_id = RoomId::from_proto(left_room.room.id);
3543 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3544 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3545 delete_live_kit_room = left_room.deleted;
3546 room = mem::take(&mut left_room.room);
3547 channel_members = mem::take(&mut left_room.channel_members);
3548 channel_id = left_room.channel_id;
3549
3550 room_updated(&room, &session.peer);
3551 } else {
3552 return Ok(());
3553 }
3554
3555 if let Some(channel_id) = channel_id {
3556 channel_updated(
3557 channel_id,
3558 &room,
3559 &channel_members,
3560 &session.peer,
3561 &*session.connection_pool().await,
3562 );
3563 }
3564
3565 {
3566 let pool = session.connection_pool().await;
3567 for canceled_user_id in canceled_calls_to_user_ids {
3568 for connection_id in pool.user_connection_ids(canceled_user_id) {
3569 session
3570 .peer
3571 .send(
3572 connection_id,
3573 proto::CallCanceled {
3574 room_id: room_id.to_proto(),
3575 },
3576 )
3577 .trace_err();
3578 }
3579 contacts_to_update.insert(canceled_user_id);
3580 }
3581 }
3582
3583 for contact_user_id in contacts_to_update {
3584 update_user_contacts(contact_user_id, &session).await?;
3585 }
3586
3587 if let Some(live_kit) = session.live_kit_client.as_ref() {
3588 live_kit
3589 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3590 .await
3591 .trace_err();
3592
3593 if delete_live_kit_room {
3594 live_kit.delete_room(live_kit_room).await.trace_err();
3595 }
3596 }
3597
3598 Ok(())
3599}
3600
3601async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3602 let left_channel_buffers = session
3603 .db()
3604 .await
3605 .leave_channel_buffers(session.connection_id)
3606 .await?;
3607
3608 for left_buffer in left_channel_buffers {
3609 channel_buffer_updated(
3610 session.connection_id,
3611 left_buffer.connections,
3612 &proto::UpdateChannelBufferCollaborators {
3613 channel_id: left_buffer.channel_id.to_proto(),
3614 collaborators: left_buffer.collaborators,
3615 },
3616 &session.peer,
3617 );
3618 }
3619
3620 Ok(())
3621}
3622
3623fn project_left(project: &db::LeftProject, session: &Session) {
3624 for connection_id in &project.connection_ids {
3625 if project.host_user_id == session.user_id {
3626 session
3627 .peer
3628 .send(
3629 *connection_id,
3630 proto::UnshareProject {
3631 project_id: project.id.to_proto(),
3632 },
3633 )
3634 .trace_err();
3635 } else {
3636 session
3637 .peer
3638 .send(
3639 *connection_id,
3640 proto::RemoveProjectCollaborator {
3641 project_id: project.id.to_proto(),
3642 peer_id: Some(session.connection_id.into()),
3643 },
3644 )
3645 .trace_err();
3646 }
3647 }
3648}
3649
3650pub trait ResultExt {
3651 type Ok;
3652
3653 fn trace_err(self) -> Option<Self::Ok>;
3654}
3655
3656impl<T, E> ResultExt for Result<T, E>
3657where
3658 E: std::fmt::Debug,
3659{
3660 type Ok = T;
3661
3662 fn trace_err(self) -> Option<T> {
3663 match self {
3664 Ok(value) => Some(value),
3665 Err(error) => {
3666 tracing::error!("{:?}", error);
3667 None
3668 }
3669 }
3670 }
3671}